### _Setup_

In [None]:
# Reset memory: IPython magic that clears all user-defined names from the
# interactive namespace; -f skips the confirmation prompt.
%reset -f

In [None]:
# Packages
from typing import Union, List, Tuple, Dict, Any, Optional
import math
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import StratifiedShuffleSplit
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers, initializers, optimizers, callbacks
import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import optuna

In [None]:
# Data: load the full dataset from 'data.csv' (path is relative to the
# notebook's working directory) into the module-level DataFrame `df`.
df = pd.read_csv('data.csv')

### _Functions_

In [None]:
def find_col_index_of_spectra(
    df: pd.DataFrame
) -> int:
    """
    Locate the first column whose label parses as a float.

    Spectral columns are expected to be named after their wavelength
    (e.g., "730.5", "731.0"), so the first float-convertible label marks
    the start of the spectral block.

    Parameters:
        df : Input DataFrame

    Returns:
        Zero-based position of the first float-convertible column label,
        or -1 when no label can be converted.
    """
    def _parses_as_float(label) -> bool:
        # Labels such as "Variety" raise ValueError; None raises TypeError.
        try:
            float(label)
        except (ValueError, TypeError):
            return False
        return True

    numeric_positions = (
        position
        for position, label in enumerate(df.columns)
        if _parses_as_float(label)
    )
    return next(numeric_positions, -1)

def split_train_test(
    df: pd.DataFrame,
    test_variety: str,
    test_season: int,
    test_variety_year: int = 2024
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split a DataFrame into one training set and two test sets:

    - Variety test set: Variety == test_variety AND Year == test_variety_year
    - Season test set : Year == test_season

    The training set excludes every row of the test variety (in all years,
    not only test_variety_year) and every row of the test season, so it never
    overlaps with either test set.
    The season test set only includes varieties that are present in the training set.

    Parameters:
        df                : Full pandas DataFrame with 'Variety' and
                            'Scan Date Year' columns
        test_variety      : Variety used for the variety test set
        test_season       : Year used for the season test set
        test_variety_year : Year from which the variety test set is drawn
                            (default: 2024, the previously hard-coded value)

    Returns:
        df_train        : Training set
        df_test_variety : Test set for the specified variety and test_variety_year
        df_test_season  : Test set for the specified season (filtered by train varieties)
    """
    # Boolean masks are computed once and reused for all three selections
    is_test_variety = df["Variety"] == test_variety
    is_test_season = df["Scan Date Year"] == test_season

    # Variety test set: held-out variety, restricted to the chosen year
    df_test_variety = df[
        is_test_variety &
        (df["Scan Date Year"] == test_variety_year)
    ]

    # Training set: neither the held-out variety nor the held-out season
    df_train = df[~is_test_variety & ~is_test_season]

    # Season test set: held-out season, limited to varieties the model has
    # seen during training (this also drops the held-out variety)
    train_varieties = df_train["Variety"].unique()
    df_test_season = df[
        is_test_season &
        df["Variety"].isin(train_varieties)
    ]

    return df_train, df_test_variety, df_test_season

def take_subset(
    df: pd.DataFrame, 
    n_subset: int,
    random_state: int
) -> pd.DataFrame:
    """
    Draw a Brix-stratified random subset of the rows of a DataFrame.

    Rows are grouped into (up to) 10 quantile bins of 'Brix (Position)' and
    the subset is sampled so that the bins keep their relative proportions.
    When the requested size is at least the number of rows, a copy of the
    full DataFrame is returned with its index left untouched.

    Parameters:
        df           : Input DataFrame with 'Brix (Position)' column
        n_subset     : Desired subset size
        random_state : Random seed for reproducibility

    Returns:
        Stratified subset of df (index reset), or a copy of df when
        n_subset >= len(df)
    """
    # Nothing to sample when the whole dataset is requested (or more)
    if len(df) <= n_subset:
        return df.copy()

    # Quantile bin labels act as the stratification classes
    brix_bins = pd.qcut(
        df["Brix (Position)"],
        q=10,
        labels=False,
        duplicates='drop'
    )

    sampler = StratifiedShuffleSplit(
        n_splits=1,
        train_size=n_subset,
        random_state=random_state
    )

    # The "train" side of the single split holds the positions to keep
    keep_positions = next(sampler.split(df, brix_bins))[0]

    subset = df.iloc[keep_positions]
    return subset.reset_index(drop=True)

def create_train_val_split(
    df: pd.DataFrame,
    validation_size: float,
    random_state: int
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Produce a Brix-stratified train/validation partition of a DataFrame.

    The 'Brix (Position)' column is cut into (up to) 10 quantile bins and the
    bin labels are used as stratification classes, so both partitions cover
    the Brix range in similar proportions.

    Parameters:
        df              : Input DataFrame
        validation_size : Proportion of validation samples (0 < float < 1)
        random_state    : Seed for reproducibility

    Returns:
        df_train, df_val : Training and validation DataFrames, each with a
                           fresh 0..n-1 index
    """
    # Quantile bin labels serve as the stratification classes
    brix_bins = pd.qcut(
        df["Brix (Position)"],
        q=10,
        labels=False,
        duplicates="drop"
    )

    train_part, val_part = train_test_split(
        df,
        test_size=validation_size,
        random_state=random_state,
        stratify=brix_bins
    )

    # Drop the shuffled original index on both partitions
    train_part = train_part.reset_index(drop=True)
    val_part = val_part.reset_index(drop=True)
    return train_part, val_part

def split_x_y(
    df: pd.DataFrame,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Split a DataFrame into X (spectral features) and y (target) arrays.

    Relies on find_col_index_of_spectra() to locate the first spectral column;
    that column and every column after it are treated as spectral features.

    Parameters:
        df : Input DataFrame containing both metadata and spectral data.

    Returns:
        x : NumPy array of shape (n_samples, n_spectral_features)
        y : NumPy array of shape (n_samples, 1) containing Brix values

    Raises:
        ValueError : If no column name can be converted to float, i.e. the
                     DataFrame contains no spectral columns.
        KeyError   : If the 'Brix (Position)' column is missing.
    """
    # Position of the first spectral column (-1 means "none found")
    first_spectral = find_col_index_of_spectra(df)

    # Without this guard, the -1 sentinel would slice df.columns[-1:] and
    # silently use the LAST column (whatever it is) as the only feature.
    if first_spectral < 0:
        raise ValueError(
            "No spectral columns found: no column name can be converted to float."
        )

    # Everything from the first spectral column onward is a feature
    spectra_cols = list(df.columns[first_spectral:])

    # Define the target column (kept as a list so y is 2-D: (n_samples, 1))
    target_cols = ['Brix (Position)']

    # Extract feature and target arrays
    x = df[spectra_cols].values
    y = df[target_cols].values

    return x, y

def make_loader(
    x: np.ndarray,
    y: np.ndarray,
    batch_size: int,
    shuffle: bool = False
) -> DataLoader:
    """
    Wrap NumPy feature/target arrays in a PyTorch DataLoader for ViT-style models.

    Features are cast to float32 and given a singleton axis after the batch
    axis and another at the end; targets are cast to float32 and reshaped to
    a single column.

    Parameters:
        x         : np.ndarray
            Feature array of shape (N, L), where N is the number of samples and L is the sequence length.
        y         : np.ndarray
            Target array of shape (N,) or (N, 1).
        batch_size: int
            Batch size for DataLoader.
        shuffle   : bool
            Whether to shuffle the dataset during iteration (default: False).

    Returns:
        DataLoader : PyTorch DataLoader yielding (x, y) batches with shapes
                     (B, 1, L, 1) and (B, 1).
    """
    # Features: (N, L) -> (N, 1, L) -> (N, 1, L, 1)
    features = torch.from_numpy(x).float()
    features = features.unsqueeze(1)
    features = features.unsqueeze(-1)

    # Targets: flatten to a single column, shape (N, 1)
    targets = torch.from_numpy(y).float().view(-1, 1)

    return DataLoader(
        TensorDataset(features, targets),
        batch_size=batch_size,
        shuffle=shuffle
    )

def model_spectransformer(
    spectrum_length: int,
    patch_size: int,
    dim: int,
    depth: int,
    heads: int,
    mlp_dim: int,
    pool: str,
    dropout: float,
    emb_dropout: float,
    device: str = None
) -> nn.Module:
    """
    Construct a ViT regressor for 1D spectra and move it to the target device.

    The spectrum is presented to ViT as an image of size (spectrum_length, 1)
    cut into patches of size (patch_size, 1), with a single output value.

    Parameters:
        spectrum_length : int
            Number of input features in the spectrum (1D input length).
        patch_size      : int
            Patch length along the spectral axis.
        dim             : int
            Dimensionality of the token embeddings.
        depth           : int
            Number of Transformer encoder blocks.
        heads           : int
            Number of attention heads in each multi-head attention layer.
        mlp_dim         : int
            Hidden dimension of the MLP in each Transformer block.
        pool            : str
            Pooling strategy before final head ('cls' or 'mean').
        dropout         : float
            Dropout rate within attention and MLP blocks.
        emb_dropout     : float
            Dropout rate applied after patch embedding and positional encoding.
        device          : str, optional
            Device to place the model on. When None, 'cuda' is used if
            available, otherwise 'cpu'.

    Returns:
        nn.Module
            The ViT instance, already moved to `device`.
    """
    # Resolve the target device: prefer the GPU when the caller gave none
    if device is None:
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'

    # Height carries the spectrum; width is a dummy axis of size 1
    vit_config = dict(
        image_size=(spectrum_length, 1),
        patch_size=(patch_size, 1),
        num_classes=1,                       # single regression output
        dim=dim,
        depth=depth,
        heads=heads,
        mlp_dim=mlp_dim,
        pool=pool,
        dropout=dropout,
        emb_dropout=emb_dropout,
    )

    network = ViT(**vit_config)
    return network.to(device)

def _run_spectransformer_epoch(model, loader, criterion, device, optimizer=None) -> float:
    """
    Run one full pass over `loader` and return the sample-weighted mean loss.

    With an optimizer the model is put in train mode and updated after every
    batch; without one it is put in eval mode and gradients are disabled.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    n_seen = 0

    with torch.set_grad_enabled(is_training):
        for x_batch, y_batch in loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            if is_training:
                optimizer.zero_grad()

            loss = criterion(model(x_batch), y_batch)

            if is_training:
                loss.backward()
                optimizer.step()

            # Weight each batch mean by its size so a smaller last batch
            # does not distort the epoch average.
            batch_n = x_batch.size(0)
            loss_sum += loss.item() * batch_n
            n_seen += batch_n

    if n_seen == 0:
        raise ValueError("DataLoader yielded no samples")

    # Divide by the samples actually seen (len(loader.dataset) is wrong when
    # the loader drops the last batch or uses a sub-sampling sampler).
    return loss_sum / n_seen


def train_spectransformer(
    model,
    train_loader,
    val_loader,
    epochs: int,
    learning_rate: float,
    lr_scheduler_patience: int,
    weight_decay: float,
    early_stopping_patience: int
) -> Tuple[float, nn.Module]:
    """
    Train a Spectral Vision Transformer (SpecTransformer) model with MSE loss,
    a reduce-on-plateau learning-rate schedule and early stopping.

    The monitored metric is the validation RMSE, or the training RMSE when
    `val_loader` is None. Whenever the monitored RMSE improves by more than
    1e-5 the model weights are snapshotted; the best snapshot is loaded back
    into `model` before returning.

    Parameters:
        model                  : ViT model to be trained (updated in place).
        train_loader           : PyTorch DataLoader with training data.
        val_loader             : PyTorch DataLoader with validation data (can be None).
        epochs                 : Maximum number of training epochs.
        learning_rate          : Initial learning rate for the optimizer.
        lr_scheduler_patience  : Patience for learning rate scheduler.
        weight_decay           : Weight decay for the AdamW optimizer.
        early_stopping_patience: Epochs to wait before early stopping without improvement.

    Returns:
        best_rmse : Best observed RMSE on the monitored set (validation or training);
                    float('inf') if no epoch was run.
        model     : The same model object, holding the weights from the best epoch.

    Raises:
        ValueError : If a DataLoader yields no samples.
    """
    device = next(model.parameters()).device

    # Loss and optimizer setup
    criterion = nn.MSELoss()
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Learning rate scheduler: halve the LR when the monitored loss plateaus
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=lr_scheduler_patience,
        min_lr=1e-6,
    )

    best_rmse = float('inf')
    best_state = None          # weights of the best epoch so far
    patience_counter = 0

    for epoch in range(epochs):
        # === Train Phase ===
        avg_train_loss = _run_spectransformer_epoch(
            model, train_loader, criterion, device, optimizer=optimizer
        )
        train_rmse = avg_train_loss ** 0.5

        # === Validation Phase ===
        if val_loader is not None:
            avg_val_loss = _run_spectransformer_epoch(
                model, val_loader, criterion, device
            )
            val_rmse = avg_val_loss ** 0.5

            print(f"Epoch {epoch+1}/{epochs} | Train RMSE: {train_rmse:.4f} | Val RMSE: {val_rmse:.4f}")

            scheduler.step(avg_val_loss)
            current_rmse = val_rmse

        else:
            # No validation set → only monitor training RMSE
            print(f"Epoch {epoch+1}/{epochs} | Train RMSE: {train_rmse:.4f} (no val)")
            scheduler.step(avg_train_loss)
            current_rmse = train_rmse

        # === Early Stopping Check ===
        if current_rmse < best_rmse - 1e-5:
            best_rmse = current_rmse
            patience_counter = 0
            # Clone the tensors: state_dict() returns references that keep
            # changing as training continues.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    # Roll back to the best epoch so the returned model matches best_rmse
    if best_state is not None:
        model.load_state_dict(best_state)

    return best_rmse, model

def perform_optuna_hyperparameter_optimization(
    train_loader,
    val_loader,
    spectrum_length: int,
    patch_size_options: list,
    dim_options: list,
    depth_options: list,
    heads_options: list,
    mlp_dim_options: list,
    learning_rate_range: Tuple[float, float],
    dropout_range: Tuple[float, float],
    emb_dropout_range: Tuple[float, float],
    weight_decay_range: Tuple[float, float],
    epochs: int,
    lr_scheduler_patience: int,
    timeout_time: int,
    pooling_type: str = "mean",
    device: torch.device = None,
    early_stopping_patience: int = 50
) -> Tuple[optuna.Study, float, Dict[str, Any]]:
    """
    Runs Optuna hyperparameter optimization for SpectraTr model.

    Each trial samples one configuration, builds a model with
    model_spectransformer(), trains it with train_spectransformer() and
    reports the best monitored RMSE, which the study minimizes.

    Arguments:
        train_loader, val_loader: prepared DataLoader objects
        spectrum_length: int, length of the spectrum (x_train_data.shape[1])
        patch_size_options: list of ints, possible patch sizes to try
        dim_options, depth_options, heads_options, mlp_dim_options: lists of options
        learning_rate_range, weight_decay_range: (low, high) bounds, sampled on a log scale
        dropout_range, emb_dropout_range: (low, high) bounds, sampled uniformly
        epochs, lr_scheduler_patience: training settings
        timeout_time: max time for Optuna study (seconds)
        pooling_type: 'mean' or 'cls'
        device: torch.device (optional); defaults to CUDA when available, else CPU
        early_stopping_patience: epochs without improvement before a trial's
            training stops (default: 50)

    Returns:
        study, best_val_rmse, best_params

    Raises:
        ValueError: raised by Optuna when reading the best value if no trial
            completed (trial exceptions themselves are caught and recorded
            as failed trials).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def objective(trial):
        patch_size = trial.suggest_categorical("patch_size", patch_size_options)
        dim = trial.suggest_categorical("dim", dim_options)
        depth = trial.suggest_categorical("depth", depth_options)
        heads = trial.suggest_categorical("heads", heads_options)
        # The sampled mlp_dim is used as-is, so the value recorded in
        # best_params is the one the model was actually trained with.
        mlp_dim = trial.suggest_categorical("mlp_dim", mlp_dim_options)
        lr = trial.suggest_float("learning_rate", learning_rate_range[0], learning_rate_range[1], log=True)
        dropout = trial.suggest_float("dropout", dropout_range[0], dropout_range[1])
        emb_dropout = trial.suggest_float("emb_dropout", emb_dropout_range[0], emb_dropout_range[1])
        weight_decay = trial.suggest_float("weight_decay", weight_decay_range[0], weight_decay_range[1], log=True)

        print(f"""
        --- Trial {trial.number} ---
        patch_size   = {patch_size}
        dim          = {dim}
        depth        = {depth}
        heads        = {heads}
        mlp_dim      = {mlp_dim}
        learning_rate= {lr:.2e}
        dropout      = {dropout:.2f}
        emb_dropout  = {emb_dropout:.2f}
        weight_decay = {weight_decay:.2e}
        ---------------------------
        """)

        # Model construction on the resolved device
        model = model_spectransformer(
            spectrum_length=spectrum_length,
            patch_size=patch_size,
            dim=dim,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim,
            pool=pooling_type,
            dropout=dropout,
            emb_dropout=emb_dropout,
            device=device,
        )

        # Training, returns best monitored RMSE (validation when val_loader is given)
        best_val_rmse, _ = train_spectransformer(
            model,
            train_loader,
            val_loader,
            epochs=epochs,
            learning_rate=lr,
            lr_scheduler_patience=lr_scheduler_patience,
            early_stopping_patience=early_stopping_patience,
            weight_decay=weight_decay
        )

        return best_val_rmse

    study = optuna.create_study(direction="minimize")
    # catch=(Exception,): a configuration that fails (e.g. incompatible sizes)
    # is recorded as a failed trial instead of aborting the whole study.
    study.optimize(objective, timeout=timeout_time, catch=(Exception,))

    return study, float(study.best_value), study.best_params

def test_spectransformer(
    trained_model: nn.Module,
    test_loader: torch.utils.data.DataLoader
) -> Tuple[pd.DataFrame, float, float, float]:
    """
    Run a trained SpectraTransformer over a test DataLoader and score it.

    The model is switched to eval mode, predictions are collected without
    gradient tracking, and three metrics are computed and printed.

    Parameters:
        trained_model : A fully trained SpectraTransformer model.
        test_loader   : DataLoader containing test data.

    Returns:
        df_results            : DataFrame with columns "True" and "Predicted".
        rmse                  : Root Mean Squared Error.
        r2                    : R² score.
        practical_accuracy    : % of predictions whose absolute error is at
                                most 20% of the true value.
    """
    device = next(trained_model.parameters()).device
    trained_model.eval()

    pred_chunks = []
    true_chunks = []

    # Inference only: no gradients needed
    with torch.no_grad():
        for features, targets in test_loader:
            outputs = trained_model(features.to(device))
            pred_chunks.append(outputs.cpu())
            true_chunks.append(targets.cpu())

    # Flatten the per-batch (B, 1) tensors into 1-D arrays
    y_true = torch.cat(true_chunks).numpy().flatten()
    y_pred = torch.cat(pred_chunks).numpy().flatten()

    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)

    # Share of samples whose absolute error stays within 20% of the target
    within_tolerance = np.abs(y_pred - y_true) <= 0.2 * y_true
    practical_accuracy = np.mean(within_tolerance) * 100

    print(f"Test RMSE: {rmse:.4f}")
    print(f"Test R²:   {r2:.4f}")
    print(f"% within 20% of true Brix: {practical_accuracy:.2f}%")

    df_results = pd.DataFrame({"True": y_true, "Predicted": y_pred})

    return df_results, rmse, r2, practical_accuracy

In [None]:
def pair(t):
    """
    Normalize a value to tuple form.

    Tuples pass through unchanged (whatever their length); any other value
    is duplicated into a 2-tuple.

    Parameters:
        t : Any
            Input value or tuple.

    Returns:
        tuple : `t` itself when it is a tuple, otherwise (t, t).
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

def get_sinusoidal_encoding(
    seq_len,
    dim
):
    """
    Generate a sinusoidal positional encoding matrix.

    Even feature indices hold sin(position * frequency) and odd feature
    indices hold cos(position * frequency), with frequencies
    exp(-2i * ln(10000) / dim) for i = 0, 1, ...

    Both even and odd `dim` are supported: for odd `dim` there is one more
    sine column than cosine columns, so the last frequency is used by the
    sine part only.

    Parameters:
        seq_len : int
            Length of the sequence (number of positions).
        dim     : int
            Dimensionality of the embedding space.

    Returns:
        torch.Tensor : Positional encoding of shape (1, seq_len, dim)
    """
    # Position indices (shape: [seq_len, 1])
    position = torch.arange(seq_len).unsqueeze(1)

    # Frequencies, one per sine column (shape: [ceil(dim / 2)])
    div_term = torch.exp(
        torch.arange(0, dim, 2) * -(np.log(10000.0) / dim)
    )

    # Angle for every (position, frequency) pair (shape: [seq_len, ceil(dim / 2)])
    angles = position * div_term

    # Initialize the positional encoding matrix (shape: [seq_len, dim])
    pe = torch.zeros(seq_len, dim)

    # Sine fills even indices (0, 2, 4, ...): ceil(dim / 2) columns
    pe[:, 0::2] = torch.sin(angles)

    # Cosine fills odd indices (1, 3, 5, ...): only dim // 2 columns, so the
    # angles are trimmed to match (a no-op when dim is even)
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])

    # Add batch dimension -> shape: (1, seq_len, dim)
    return pe.unsqueeze(0)

def _build_mlp_head(
    self,
    dim,
    mlp_dim,
    dropout,
    n_layers
):
    """
    Assemble the regression head that maps a feature vector to one value.

    The head always starts with a LayerNorm over `dim` features, followed by:
    - n_layers == 0: a single Linear(dim, 1)
    - n_layers == 1: Linear(dim, mlp_dim) -> GELU -> Dropout -> Linear(mlp_dim, 1)

    Parameters:
        self     : Unused; kept so the function can be called in method style.
        dim      : int
            Input dimensionality of the head.
        mlp_dim  : int
            Hidden layer size used when n_layers == 1.
        dropout  : float
            Dropout rate applied after the activation (n_layers == 1 only).
        n_layers : int
            Number of hidden layers in the head. Must be 0 or 1.

    Returns:
        nn.Sequential : The assembled head.

    Raises:
        ValueError : If n_layers is neither 0 nor 1.
    """
    # Every variant normalizes its input first
    input_norm = nn.LayerNorm(dim)

    if n_layers == 0:
        # Plain linear readout
        return nn.Sequential(input_norm, nn.Linear(dim, 1))

    if n_layers == 1:
        # One hidden layer with GELU activation and dropout before the readout
        hidden = nn.Linear(dim, mlp_dim)
        activation = nn.GELU()
        regularizer = nn.Dropout(dropout)
        readout = nn.Linear(mlp_dim, 1)
        return nn.Sequential(input_norm, hidden, activation, regularizer, readout)

    raise ValueError("mlp_head_layers must be 0 or 1")

class PreNorm(nn.Module):
    """
    Wrapper that layer-normalizes its input before calling another module.

    Used for the "pre-norm" Transformer layout, where normalization is
    applied to the input of each sub-layer (attention or feedforward)
    rather than to its output.

    Parameters:
        dim : int
            Size of the last input dimension, normalized by LayerNorm.
        fn : nn.Module
            Module that receives the normalized input.

    Forward Arguments:
        x : torch.Tensor
            Input tensor whose last dimension has size `dim`.
        **kwargs : dict
            Extra keyword arguments forwarded unchanged to `fn`.

    Returns:
        torch.Tensor : Whatever `fn` returns for the normalized input.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """
    Two-layer position-wise MLP used inside each Transformer block.

    Pipeline: Linear(dim -> hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim -> dim) -> Dropout. Input and output widths match so
    the block can sit inside a residual connection.

    Parameters:
        dim        : int
            Input and output feature dimension.
        hidden_dim : int
            Width of the hidden layer.
        dropout    : float
            Dropout probability used after the activation and after the
            output projection.

    Forward Arguments:
        x : torch.Tensor
            Input tensor whose last dimension has size `dim`.

    Returns:
        torch.Tensor : Tensor with the same shape as the input.
    """
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()

        # Layers are created in pipeline order so parameter initialization
        # draws random numbers in the same sequence as the pipeline reads.
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_dropout = nn.Dropout(dropout)
        contract = nn.Linear(hidden_dim, dim)
        output_dropout = nn.Dropout(dropout)

        self.net = nn.Sequential(
            expand,
            activation,
            hidden_dropout,
            contract,
            output_dropout,
        )

    def forward(self, x):
        out = self.net(x)
        return out

class Attention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    Queries, keys and values come from a single bias-free linear projection
    of the input and are split into `heads` groups of width `dim_head`. The
    per-head results are concatenated and, unless there is exactly one head
    whose width already equals `dim`, projected back to `dim` with dropout.

    Parameters:
        dim       : int
            Input and output embedding dimension.
        heads     : int
            Number of attention heads.
        dim_head  : int
            Dimensionality of each head's projection.
        dropout   : float
            Dropout probability applied after the output projection.

    Forward Arguments:
        x : torch.Tensor
            Input tensor of shape [batch_size, sequence_length, dim].

    Returns:
        torch.Tensor : Output tensor of shape [batch_size, sequence_length, dim].
    """
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()

        inner_dim = heads * dim_head   # combined width of all heads

        self.heads = heads
        self.scale = dim_head ** -0.5  # 1 / sqrt(dim_head), applied to the logits

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # With a single head of width `dim` the concatenated output already
        # has the right shape, so no projection is needed.
        single_full_width_head = heads == 1 and dim_head == dim
        if single_full_width_head:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x):
        # Unpacking enforces a 3-D input: (batch, tokens, features)
        _batch, _tokens, _features = x.shape
        num_heads = self.heads

        # One projection, then split the last axis into Q, K and V
        query, key, value = (
            rearrange(part, 'b n (h d) -> b h n d', h=num_heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )  # each: [batch, heads, tokens, dim_head]

        # Scaled similarity of every query with every key: [b, h, tokens, tokens]
        scores = einsum('b h i d, b h j d -> b h i j', query, key) * self.scale
        weights = self.attend(scores)

        # Attention-weighted mix of the values: [b, h, tokens, dim_head]
        mixed = einsum('b h i j, b h j d -> b h i d', weights, value)

        # Merge the heads back into one feature axis
        merged = rearrange(mixed, 'b h n d -> b n (h d)')

        return self.to_out(merged)

class Transformer(nn.Module):
    """
    Stack of pre-norm Transformer encoder layers.

    Every layer applies, in order:
        - LayerNorm -> multi-head self-attention, added back to its input.
        - LayerNorm -> feedforward network, added back to its input.

    Parameters:
        dim       : int
            Input and output feature dimension.
        depth     : int
            Number of encoder layers in the stack.
        heads     : int
            Number of attention heads.
        dim_head  : int
            Dimensionality of each attention head.
        mlp_dim   : int
            Hidden layer size in the feedforward network.
        dropout   : float
            Dropout probability passed to the attention and feedforward layers.

    Forward Arguments:
        x : torch.Tensor
            Input tensor of shape [batch_size, seq_len, dim].

    Returns:
        torch.Tensor : Output tensor of shape [batch_size, seq_len, dim].
    """
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()

        self.layers = nn.ModuleList()
        for _ in range(depth):
            # Attention sub-layer first, then feedforward, each behind a LayerNorm
            attention_block = PreNorm(
                dim,
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            )
            feedforward_block = PreNorm(
                dim,
                FeedForward(dim, mlp_dim, dropout=dropout)
            )
            self.layers.append(nn.ModuleList([attention_block, feedforward_block]))

    def forward(self, x):
        for attention_block, feedforward_block in self.layers:
            # Residual connection around each sub-layer
            x = attention_block(x) + x
            x = feedforward_block(x) + x
        return x

class ViT(nn.Module):
    """
    Vision Transformer (ViT) encoder with a single-output regression head.

    The input is cut into non-overlapping patches, every patch is linearly projected
    to a token, a learnable CLS token is prepended, positional encodings (stored as a
    non-trainable buffer) are added, and the sequence is passed through a Transformer
    encoder. The pooled representation is mapped to one scalar per sample.

    Parameters:
        image_size   : int or tuple
            Height and width of the input image. If int, assumes square image.
        patch_size   : int or tuple
            Height and width of each patch. If int, assumes square patches.
        num_classes  : int
            Accepted for compatibility only; the constructor never reads it.
        dim          : int
            Dimension of the token embeddings and transformer hidden states.
        depth        : int
            Number of transformer encoder layers.
        heads        : int
            Number of attention heads in multi-head attention.
        mlp_dim      : int
            Hidden dimension in the transformer feedforward layers.
        pool         : str
            Pooling strategy: 'cls' (use CLS token) or 'mean' (average all tokens,
            CLS token included).
        channels     : int, default=1
            Number of input channels.
        dropout      : float
            Dropout rate in transformer layers.
        emb_dropout  : float
            Dropout rate applied after patch embedding + positional encoding.
        dim_head     : int or None, default=None
            Dimension of each attention head. If None, defaults to dim // heads
            (dim must then be divisible by heads).
    """
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool,
        channels=1,
        dropout,
        emb_dropout,
        dim_head=None
    ):
        super().__init__()

        # --- Geometry: image and patch sizes as (height, width) ---
        img_h, img_w = pair(image_size)
        patch_h, patch_w = pair(patch_size)

        # A non-zero remainder on either axis means the patches do not tile the image
        assert not (img_h % patch_h or img_w % patch_w), \
            'Image dimensions must be divisible by the patch size.'

        grid_h, grid_w = img_h // patch_h, img_w // patch_w
        num_patches = grid_h * grid_w
        patch_dim = channels * patch_h * patch_w  # length of one flattened patch

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # --- Patch embedding: (B, C, H, W) -> (B, num_patches, patch_dim) -> (B, num_patches, dim) ---
        patchify = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_h, p2=patch_w)
        self.to_patch_embedding = nn.Sequential(patchify, nn.Linear(patch_dim, dim))

        # --- Positional encoding (buffer, one extra slot for the CLS token) and CLS token ---
        self.register_buffer("pos_embedding", get_sinusoidal_encoding(num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # --- Per-head width: split dim evenly across heads unless given explicitly ---
        if dim_head is None:
            assert dim % heads == 0, f"dim ({dim}) must be divisible by heads ({heads})"
            dim_head = dim // heads

        # --- Encoder ---
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # --- Pooling configuration and regression head ---
        self.pool = pool
        self.to_latent = nn.Identity()  # pass-through hook between pooling and head
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, 1))

    def forward(self, img):
        """
        Run a batch through the model and return one regression value per sample.

        Parameters:
            img : torch.Tensor
                Input tensor of shape (B, C, H, W): batch size, channels,
                image height, image width.

        Returns:
            torch.Tensor : Regression output of shape (B, 1)
        """
        # Patches -> token embeddings, shape (B, num_patches, dim)
        tokens = self.to_patch_embedding(img)
        batch_size, num_tokens, _ = tokens.shape

        # Prepend one copy of the CLS token per sample -> (B, num_patches + 1, dim)
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch_size)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        # Add positional encodings (sliced to the sequence length), then dropout
        tokens += self.pos_embedding[:, :(num_tokens + 1)]
        tokens = self.dropout(tokens)

        # Encoder keeps the shape (B, num_patches + 1, dim)
        tokens = self.transformer(tokens)

        # Reduce the sequence to one vector per sample, shape (B, dim):
        # 'mean' averages every token (CLS included); otherwise take the CLS token
        pooled = tokens.mean(dim=1) if self.pool == 'mean' else tokens[:, 0]

        # Identity hook, then LayerNorm + Linear down to a single output
        return self.mlp_head(self.to_latent(pooled))



### _Parameters_

In [None]:
# --- Data source and held-out test sets ---
DF = df
TEST_VARIETY = "TestVariety"
TEST_SEASON = 2025

# --- Sampling and splitting ---
RANDOM_STATE = 27
N_SUBSET = 23690
VALIDATION_SIZE = 0.1

# --- Data loading ---
BATCH_SIZE = 64
SPECTRUM_LENGTH = 1026

# --- Hyperparameter search space ---
# Candidate patch sizes: every divisor of the spectrum length from 10 upward
# (the full length itself is excluded by the range bound).
PATCH_SIZE_OPTIONS = [size for size in range(10, SPECTRUM_LENGTH) if not SPECTRUM_LENGTH % size]
DIM_OPTIONS = [128, 256, 512, 1024]
DEPTH_RANGE = list(range(2, 17))  # 2, 3, ..., 16
HEADS_OPTIONS = [4, 8, 16]
MLP_DIM_OPTIONS = [256, 512, 1024]
DROPOUT_RANGE = (0.01, 0.4)
EMB_DROPOUT_RANGE = (0.01, 0.4)
WEIGHT_DECAY_RANGE = (0, 1e-2)

# --- Training schedule ---
EARLY_STOPPING_PATIENCE = 50
LR_SCHEDULER_PATIENCE = 25
TRAIN_EPOCHS = 250
TEST_EPOCHS = 1000
TIMEOUT_TIME = 72 * 60 * 60  # = 259200; presumably seconds (72 h) - confirm against the Optuna wrapper

# --- Compute device: GPU when available, CPU otherwise ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### _Run_

In [None]:
# === Shared settings for this run ===
# Single source of truth so the Optuna search and the final model are built
# with the same pooling strategy.
POOLING_TYPE = "mean"

# === Split into train and test sets ===
df_train_all, df_test_variety, df_test_season = split_train_test(
    df,
    test_variety=TEST_VARIETY,
    test_season=TEST_SEASON,
)

# === Take subset ===
# The hyperparameter search runs on a subset; the final model is retrained on
# df_train_all further below.
df_subset = take_subset(
    df_train_all,
    n_subset=N_SUBSET,
    random_state=RANDOM_STATE
)

# === Make train/validation split ===
df_train, df_val = create_train_val_split(
    df=df_subset,
    validation_size=VALIDATION_SIZE,
    random_state=RANDOM_STATE
)

# === Convert to x and y arrays ===
x_train_all, y_train_all = split_x_y(
    df_train_all,
)
x_train, y_train = split_x_y(
    df_train,
)
x_val, y_val = split_x_y(
    df_val,
)
x_test_variety, y_test_variety = split_x_y(
    df_test_variety,
)
x_test_season, y_test_season = split_x_y(
    df_test_season,
)

# Backward-compatible aliases: the test targets used to be stored under
# misleading "y_train_*" names. Keep them so any later cell still works.
y_train_variety = y_test_variety
y_train_season = y_test_season

# === Define DataLoaders ===
train_all_loader = make_loader(
    x_train_all,
    y_train_all,
    batch_size=BATCH_SIZE,
    shuffle=True
)
train_loader = make_loader(
    x_train,
    y_train,
    batch_size=BATCH_SIZE,
    shuffle=True
)
# Evaluation loaders are not shuffled: order does not matter for a validation
# metric, and shuffling would only consume global RNG state every epoch.
val_loader = make_loader(
    x_val,
    y_val,
    batch_size=BATCH_SIZE,
    shuffle=False
)
test_loader_variety = make_loader(
    x_test_variety,
    y_test_variety,
    batch_size=BATCH_SIZE,
    shuffle=False
)
test_loader_season = make_loader(
    x_test_season,
    y_test_season,
    batch_size=BATCH_SIZE,
    shuffle=False
)

# === Run Optuna Hyperparameter Optimization ===
# NOTE: the learning-rate range is degenerate (low == high), i.e. the learning
# rate is effectively fixed at 1e-4 and not searched.
study, best_val_rmse, best_params = perform_optuna_hyperparameter_optimization(
    train_loader=train_loader,
    val_loader=val_loader,
    spectrum_length=SPECTRUM_LENGTH,
    patch_size_options=PATCH_SIZE_OPTIONS,
    dim_options=DIM_OPTIONS,
    depth_options=DEPTH_RANGE,
    heads_options=HEADS_OPTIONS,
    mlp_dim_options=MLP_DIM_OPTIONS,
    learning_rate_range=(0.0001, 0.0001),
    dropout_range=DROPOUT_RANGE,
    emb_dropout_range=EMB_DROPOUT_RANGE,
    weight_decay_range=WEIGHT_DECAY_RANGE,
    epochs=TRAIN_EPOCHS,
    lr_scheduler_patience=LR_SCHEDULER_PATIENCE,
    timeout_time=TIMEOUT_TIME,
    pooling_type=POOLING_TYPE
)

# === Define the model with best parameters ===
model = model_spectransformer(
    spectrum_length = SPECTRUM_LENGTH,
    patch_size      = best_params["patch_size"],
    dim             = best_params["dim"],
    depth           = best_params["depth"],
    heads           = best_params["heads"],
    mlp_dim         = best_params["mlp_dim"],
    pool            = POOLING_TYPE,
    dropout         = best_params["dropout"],
    emb_dropout     = best_params["emb_dropout"],
)

# === Retrain the model on all data ===
# No validation loader is passed here: all non-test rows are used for training.
_, trained_model = train_spectransformer(
    model=model,
    train_loader=train_all_loader,
    val_loader=None,
    epochs=TEST_EPOCHS,
    learning_rate=best_params["learning_rate"],
    lr_scheduler_patience=LR_SCHEDULER_PATIENCE,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    weight_decay=best_params["weight_decay"]
)


# === Variety test ===
df_results_variety, rmse_variety, r2_variety, acc_variety = test_spectransformer(
    trained_model=trained_model,
    test_loader=test_loader_variety
)

# === Season test ===
df_results_season, rmse_season, r2_season, acc_season = test_spectransformer(
    trained_model=trained_model,
    test_loader=test_loader_season
)

# === Side-by-side summary of both held-out test sets ===
summary = pd.DataFrame({
    "Test Set": ["Variety", "Season"],
    "RMSE": [rmse_variety, rmse_season],
    "R²": [r2_variety, r2_season],
    "% Within 20%": [acc_variety, acc_season]
})

print(summary)


### _Inference Time Analysis_

In [None]:
def get_inference_sample_set(
    df_variety: pd.DataFrame,
    df_season: pd.DataFrame,
    random_state: int,
    sample_size: int = 1000
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Combine two test sets (variety and season), sample rows randomly, and return X and y arrays.

    Sampling is done without replacement. If the combined test set holds fewer than
    `sample_size` rows, all available rows are used instead of raising, so the
    returned arrays may be shorter than requested.

    Parameters:
        df_variety   : DataFrame for variety-based test set
        df_season    : DataFrame for season-based test set
        random_state : Random seed for reproducibility
        sample_size  : Maximum number of rows to sample from the combined test set

    Returns:
        x_sample : Feature array for the sampled rows, as produced by split_x_y
        y_sample : Target array for the sampled rows, as produced by split_x_y
    """
    # Stack the two test sets row-wise
    df_combined = pd.concat([df_variety, df_season], axis=0)

    # DataFrame.sample(n=...) raises ValueError when n exceeds the number of rows
    # (replace=False), so never ask for more rows than exist.
    n_rows = min(sample_size, len(df_combined))

    # Randomly sample rows from the combined test set
    df_sample = df_combined.sample(
        n=n_rows,
        random_state=random_state
    )

    # Split into X and y arrays
    x_sample, y_sample = split_x_y(df_sample)

    return x_sample, y_sample

def test_spectransformer_inference_time(
    model: nn.Module,
    x_test: np.ndarray
) -> float:
    """
    Measure average one-by-one inference time of a SpectraTransformer model in milliseconds.

    Each row of `x_test` is sent through the model as its own batch of size one.
    The timed region covers the forward pass plus copying the prediction back to
    the host; moving the input to the model's device is not timed.

    Side effects: switches `model` to eval mode and prints the average time.

    Parameters:
        model   : Trained SpectraTransformer model
        x_test  : Test feature matrix of shape (n_samples, n_spectral_features)

    Returns:
        avg_inference_time_ms : Average inference time per sample in milliseconds
    """
    model.eval()
    device = next(model.parameters()).device
    on_cuda = device.type == "cuda"
    times = []

    with torch.no_grad():
        for x in x_test:
            # Single sample -> (1, 1, n_features, 1): batch, channel, height, width
            x_input = torch.from_numpy(x).float().unsqueeze(0).unsqueeze(0).unsqueeze(-1)
            x_input = x_input.to(device)

            # CUDA work is asynchronous: wait for the input copy to finish so it
            # does not leak into the timed region.
            if on_cuda:
                torch.cuda.synchronize(device)

            # perf_counter is monotonic and high-resolution, unlike time.time()
            start = time.perf_counter()
            # .cpu() blocks until the forward pass is done, so `end` is accurate
            _ = model(x_input).cpu().numpy().flatten()[0]
            end = time.perf_counter()

            times.append(end - start)

    avg_inference_time_ms = np.mean(times) * 1000
    print(f"Average inference time: {avg_inference_time_ms:.3f} ms/sample")

    return avg_inference_time_ms


# Despite the "test_" prefix this is a benchmark helper, not a unit test: opt out
# of pytest collection in case this code is ever imported by a test module.
test_spectransformer_inference_time.__test__ = False

In [None]:
# === Draw a random sample from both held-out test sets for timing ===
x_inference_time, y_inference_time = get_inference_sample_set(
    df_variety=df_test_variety,
    df_season=df_test_season,
    random_state=RANDOM_STATE,
)

# === Time per-sample inference of the retrained model ===
inference_time = test_spectransformer_inference_time(
    model=trained_model,
    x_test=x_inference_time,
)
