In [1]:
import numpy as np
import polars as pl
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight


import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import os
import joblib
import random
import warnings
warnings.filterwarnings("ignore")

In [2]:
IMU_LEN = 14
THM_TOF_SUMMARIES_LEN = 30
BATCH_SIZE = 64
PAD_PERCENTILE = 95
PAD_LEN = 127
LR_INIT = 5e-4
WD = 3e-3
MIXUP_ALPHA = 0.4
EPOCHS = 200
PATIENCE = 15
SEED = 3126
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
from sklearn.metrics import f1_score

class ParticipantVisibleError(Exception):
    """Errors raised here will be shown directly to the competitor."""
    pass


class CompetitionMetric:
    """Hierarchical macro F1 for the CMI 2025 challenge."""
    def __init__(self):
        self.target_gestures = [
            'Above ear - pull hair',
            'Cheek - pinch skin',
            'Eyebrow - pull hair',
            'Eyelash - pull hair',
            'Forehead - pull hairline',
            'Forehead - scratch',
            'Neck - pinch skin',
            'Neck - scratch',
        ]
        self.non_target_gestures = [
            'Write name on leg',
            'Wave hello',
            'Glasses on/off',
            'Text on phone',
            'Write name in air',
            'Feel around in tray and pull out an object',
            'Scratch knee/leg skin',
            'Pull air toward your face',
            'Drink from bottle/cup',
            'Pinch knee/leg skin'
        ]
        self.all_classes = self.target_gestures + self.non_target_gestures

    def calculate_hierarchical_f1(
        self,
        sol: pd.DataFrame,
        sub: pd.DataFrame
    ) -> float:

        # Validate gestures
        invalid_types = {i for i in sub['gesture'].unique() if i not in self.all_classes}
        if invalid_types:
            raise ParticipantVisibleError(
                f"Invalid gesture values in submission: {invalid_types}"
            )

        # Compute binary F1 (Target vs Non-Target)
        y_true_bin = sol['gesture'].isin(self.target_gestures).values
        y_pred_bin = sub['gesture'].isin(self.target_gestures).values
        
        f1_binary = f1_score(y_true_bin, y_pred_bin, pos_label=True, zero_division=0, average='binary')

        # Build multi-class labels for gestures
        y_true_mc = sol['gesture'].apply(lambda x: x if x in self.target_gestures else 'non_target')
        y_pred_mc = sub['gesture'].apply(lambda x: x if x in self.target_gestures else 'non_target')

        f1_macro = f1_score(y_true_mc, y_pred_mc, average='macro', zero_division=0)

        return f1_binary, f1_macro, (f1_binary+f1_macro)/2.0

In [4]:
def F1_score(y_val, y_pred, lbl_encoder, choice="weighted"):
    metric = CompetitionMetric()
    y_val  = pd.DataFrame({'id':range(len(y_val)), 
                           'gesture':y_val})
    y_pred = pd.DataFrame({'id':range(len(y_pred)), 
                           'gesture':y_pred})

    ## Convert numeric labels to original descriptions
    y_val["gesture"]  = lbl_encoder.inverse_transform(y_val["gesture"])
    y_pred["gesture"] = lbl_encoder.inverse_transform(y_pred["gesture"])

    ## Computes score
    binary, macro, weighted = metric.calculate_hierarchical_f1(y_val, y_pred)

    ## Returns result
    if choice=="binary": return binary
    elif choice=="macro": return macro
    elif choice=="weighted": return weighted
    else: return (binary, macro, weighted)

In [5]:
def seed_all(seed=3126):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def worker_init_fn(worker_id):
    np.random.seed(SEED + worker_id)
    random.seed(SEED + worker_id)

seed_all()

In [6]:
RAW_DIR = "/kaggle/input/cmi-detect-behavior-with-sensor-data"

label_encoder = joblib.load("/kaggle/input/cmi-label-encoder/label_encoder.joblib")
standard_scaler = joblib.load("/kaggle/input/cmi-custom-tensor-data-v2/StandardScaler.joblib")
X = torch.load("/kaggle/input/cmi-custom-tensor-data-v2/X.pt")
y_int = np.load("/kaggle/input/cmi-custom-tensor-data-v2/y_int.npy")
y_ohe = torch.load("/kaggle/input/cmi-custom-tensor-data-v2/y_ohe.pt")

imu_cols = joblib.load("/kaggle/input/cmi-custom-tensor-data-v2/imu_cols.joblib")
thm_tof_cols = joblib.load("/kaggle/input/cmi-custom-tensor-data-v2/thm_tof_cols.joblib")
final_feature_cols = joblib.load("/kaggle/input/cmi-custom-tensor-data-v2/final_feature_cols.joblib")

In [7]:
print("IMU Features:")
print(imu_cols)
print("\n\nThm+TOF Features:")
print(thm_tof_cols)

IMU Features:
['acc_x', 'acc_y', 'acc_z', 'acc_x_diff', 'acc_y_diff', 'acc_z_diff', 'rot_w', 'rot_x', 'rot_y', 'rot_z', 'acc_mag', 'rot_angle', 'acc_mag_diff', 'rot_angle_diff']


Thm+TOF Features:
['thm_1', 'thm_2', 'thm_3', 'thm_4', 'thm_5', 'thm_1_diff', 'thm_2_diff', 'thm_3_diff', 'thm_4_diff', 'thm_5_diff', 'tof_1_mean', 'tof_1_std', 'tof_1_min', 'tof_1_max', 'tof_2_mean', 'tof_2_std', 'tof_2_min', 'tof_2_max', 'tof_3_mean', 'tof_3_std', 'tof_3_min', 'tof_3_max', 'tof_4_mean', 'tof_4_std', 'tof_4_min', 'tof_4_max', 'tof_5_mean', 'tof_5_std', 'tof_5_min', 'tof_5_max', 'tof_1_v0', 'tof_1_v1', 'tof_1_v2', 'tof_1_v3', 'tof_1_v4', 'tof_1_v5', 'tof_1_v6', 'tof_1_v7', 'tof_1_v8', 'tof_1_v9', 'tof_1_v10', 'tof_1_v11', 'tof_1_v12', 'tof_1_v13', 'tof_1_v14', 'tof_1_v15', 'tof_1_v16', 'tof_1_v17', 'tof_1_v18', 'tof_1_v19', 'tof_1_v20', 'tof_1_v21', 'tof_1_v22', 'tof_1_v23', 'tof_1_v24', 'tof_1_v25', 'tof_1_v26', 'tof_1_v27', 'tof_1_v28', 'tof_1_v29', 'tof_1_v30', 'tof_1_v31', 'tof_1_v32', 't

# Model Setup

In [8]:
class ResidualCoordinateAttention(nn.Module):
    """
    Coordinate Attention adapted for 1D temporal sequences.
    Input: (B, T, C)
    Output: (B, T, C)
    """
    def __init__(self, channels, reduction=8):
        super(ResidualCoordinateAttention, self).__init__()
        self.mid_channels = max(8, channels // reduction)
        self.pos_embed_T = nn.Parameter(torch.randn(1, self.mid_channels, PAD_LEN))
        self.gamma = nn.Parameter(torch.tensor(1.0))  # learnable scaling

        # Temporal pooling to preserve time-dimension structure
        self.conv1 = nn.Conv1d(channels, self.mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(self.mid_channels)
        self.act = nn.SiLU()
        
        # Learn attention over time (coordinate)
        self.attn_T = nn.Conv1d(self.mid_channels, 1, kernel_size=1, bias=False)
        self.attn_C = nn.Conv1d(self.mid_channels, channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (B, T, C) → (B, C, T)
        x_perm = x.permute(0, 2, 1)
        # (B, rC, T)
        f = self.act(self.bn1(self.conv1(x_perm))) 
        f = f + self.pos_embed_T[:, :, :f.shape[-1]]
        # (B, 1, T)
        time_attn = self.sigmoid(self.attn_T(f)) 
        # (B, C, 1)
        channel_attn = self.sigmoid(self.attn_C(f.mean(dim=2, keepdim=True))) 
        
        out = x_perm + self.gamma * (x_perm * time_attn * channel_attn)  # (B, C, T)
        return out.permute(0, 2, 1)  # Back to (B, T, C)

In [9]:
class ResidualCNNBlock(nn.Module):
    """
    Residual CNN Block with Squeeze-and-Excitation (SE)
    Input expected in (batch_size, timesteps, channels) format.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, pool_size=2, drop=0.3):
        super(ResidualCNNBlock, self).__init__()
        # PyTorch Conv1D expects (batch_size, channels, timesteps)

        ## CNN model
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding='same', bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_channels, kernel_size, padding='same', bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True)
        )

        self.attention = ResidualCoordinateAttention(out_channels)

        self.shortcut_proj = None
        if in_channels != out_channels:
            self.shortcut_proj = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, padding='same', bias=False),
                nn.BatchNorm1d(out_channels)
            )

        self.relu_final = nn.ReLU(inplace=True)
        if pool_size is not None: 
            self.max_pool = nn.MaxPool1d(pool_size)
        else: 
            self.max_pool = None
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
        shortcut = x                                      # (B, T, C_in)
        x_permuted = self.cnn(x.permute(0, 2, 1))         # (B, C_out, T)
        x_attn = self.attention(x_permuted.permute(0, 2, 1)) # (B, T, C_out)

        # Handle shortcut connection
        if self.shortcut_proj:
            shortcut = self.shortcut_proj(shortcut.permute(0, 2, 1)).permute(0, 2, 1)

        # Residual connection
        x = self.relu_final(x_attn + shortcut)
        if self.max_pool is not None:
            x = self.max_pool(x.permute(0, 2, 1)).permute(0, 2, 1) # (B, T, C_out) -> (B, T//pool_size, C_out)
        x = self.dropout(x)
        
        return x

In [10]:
class TOFCompressor(nn.Module):
    def __init__(self):
        super(TOFCompressor, self).__init__()
        self.sensor_cnn = nn.Sequential(
            nn.Conv2d(5, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.SiLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  
            nn.BatchNorm2d(32),
            nn.AdaptiveAvgPool2d((1, 1))  
        )

    def forward(self, x):
        B, T, _ = x.shape
        x = x.view(B*T, 5, 8, 8)  # (B,T,320) --> (B*T, 5, 8, 8)
        x = self.sensor_cnn(x)  # (B*T, 32, 1, 1)
        x = x.view(B, T, 32) # (B, T, 32)
        return x  

In [11]:
class MLPAttention(nn.Module):
    """
    Attention mechanism to weigh the importance of different timesteps.
    Input expected in (batch_size, timesteps, features) format.
    """
    def __init__(self, feature_dim):
        super(MLPAttention, self).__init__()
        self.attn = nn.Sequential(
            nn.Linear(feature_dim, feature_dim//8),
            nn.SiLU(inplace=True),
            nn.Linear(feature_dim//8, 1)
        )

    def forward(self, inputs):
        # inputs shape: (B, T, C)
        score = self.attn(inputs).squeeze(-1) # (B, T)
        weights = F.softmax(score, dim=-1).unsqueeze(-1) # (B, T, 1)
        context = (inputs * weights).sum(dim=1) # (B, T, C) -> (B, C)
        return context

In [12]:
class TriSenseNet(nn.Module):
    def __init__(self, pad_len, imu_dim, thm_tof_summaries_dim, n_classes=18):
        super(TriSenseNet, self).__init__()
        self.imu_dim = imu_dim
        self.thm_tof_summaries_dim = thm_tof_summaries_dim
        
        # --- IMU Branch ---
        # (B, T, IMU dim) --> --> (B, T/4, 128)
        self.imu_branch = nn.Sequential(
            ResidualCNNBlock(imu_dim, 64, kernel_size=3, pool_size=2, drop=0.1), # Output shape: (B, T/2, 64)
            ResidualCNNBlock(64, 64, kernel_size=3, pool_size=None, drop=0.1), # Output shape: (B, T/2, 64)
            ResidualCNNBlock(64, 128, kernel_size=5, pool_size=2, drop=0.1) # Output shape: (B, T/4, 128)
        )

        # --- Thm/TOF Branch ---
        # (B, T, Thm+TOF summaries dim) --> (B, T/4, 128)
        self.thm_tof_branch = nn.Sequential(
            nn.Conv1d(thm_tof_summaries_dim, 64, 3, padding='same', bias=False),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
            nn.Dropout(0.2),

            nn.Conv1d(64, 64, 3, padding='same', bias=False),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            
            nn.Conv1d(64, 128, 3, padding='same', bias=False),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
            nn.Dropout(0.2)  
        )

        # --- Raw TOF Branch ---
        # (B, T, TOF raw dim) --> (B, T/4, 128)
        self.tof_raw_branch = nn.Sequential(
            TOFCompressor(), # Output shape: (B, T, 32)
            ResidualCNNBlock(32, 64, kernel_size=3, pool_size=2, drop=0.1), # Output shape: (B, T/2, 64)
            ResidualCNNBlock(64, 64, kernel_size=3, pool_size=None, drop=0.1), # Output shape: (B, T/2, 64)
            ResidualCNNBlock(64, 128, kernel_size=5, pool_size=2, drop=0.1) # Output shape: (B, T/4, 128)
        )


        # --- Merged Branch and Recurrent Layers ---
        # Merged dimension: 128 (IMU) + 128 (Thm + TOF summaries) + 128 (TOF raw features) = 384
        merged_feature_dim = 128 + 128 + 128
        self.lstm = nn.LSTM(merged_feature_dim, hidden_size=128, bidirectional=True, batch_first=True)
        self.gru  = nn.GRU(merged_feature_dim, hidden_size=128, bidirectional=True, batch_first=True)
        
        # Output of bidirectional LSTM/GRU will be 2 * hidden_size
        # (batch_size, timesteps_after_pooling, 2 * 128) = (batch_size, pad_len/4, 256)

        # For x_merged path
        self.gaussian_noise_std = 0.09 
        self.dense = nn.Linear(merged_feature_dim, 16)
        self.elu   = nn.ELU()
        
        # Concatenated features 
        # x_gru: (B, T, 256)
        # x_lstm: (B, T, 256)
        # x_merged: (B, T, 16)
        self.concat_dropout = nn.Dropout(0.4)
        self.attention_layer = MLPAttention(528)

        # --- Classification Head ---
        # After attention, shape is (B, 512)
        self.classifier = nn.Sequential(
            nn.Linear(528, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            
            nn.Linear(256, 128, bias=False),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )
        self.output_layer = nn.Linear(128, n_classes)
        
    def forward(self, inp):
        x_imu     = inp[:, :, :self.imu_dim] # (B, T, IMU dim)
        x_thm_tof = inp[:, :, self.imu_dim : self.thm_tof_summaries_dim+self.imu_dim] # (B, T, Thm + TOF summaries dim)
        x_tof_raw = inp[:, :, self.thm_tof_summaries_dim+self.imu_dim :] # (B, T, TOF raw dim)
        
        # --- IMU Branch ---
        x_imu = self.imu_branch(x_imu) # (B, T/4, 128)

        # --- Thm/TOF summaries Branch ---
        x_thm_tof = self.thm_tof_branch(x_thm_tof.permute(0, 2, 1)) 
        x_thm_tof = x_thm_tof.permute(0, 2, 1) # (B, T/4, 128)

        # --- TOF raw Branch ---
        x_tof_raw = self.tof_raw_branch(x_tof_raw) # (B, T/4, 128)
        
        # --- Merge Branches ---
        merged = torch.cat([x_imu, x_thm_tof, x_tof_raw], dim=-1) # (B, T/4, 384)

        # --- Recurrent Layers ---
        x_lstm, _ = self.lstm(merged) # (B, T/4, 256)
        x_gru, _  = self.gru(merged)  # (B, T/4, 256)
        
        # x_merged path (gaussian noise)
        if self.training: 
            x_merged = merged + torch.randn_like(merged)*self.gaussian_noise_std
        else:
            x_merged = merged
        x_merged = self.elu(self.dense(x_merged)) # (B, T/4, 16)

        # Concatenate outputs of all three paths
        x = torch.cat([x_lstm, x_gru, x_merged], dim=-1) # (B, T/4, 384*2 + 16) = (B, T/4, 528)
        x = self.concat_dropout(x)

        # Attention layer
        x = self.attention_layer(x) # Output: (B, 528)

        # --- Classification Head ---
        x = self.classifier(x)
        out = self.output_layer(x) # (B, 18)
        
        return out

In [13]:
def init_model_weights(model:nn.Module):
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
            
        elif isinstance(module, (nn.LSTM, nn.GRU)):
            for name, param in module.named_parameters():
                if 'weight_ih' in name: 
                    nn.init.xavier_uniform_(param.data)
                elif 'weight_hh' in name: 
                    nn.init.orthogonal_(param.data) 
                elif 'bias_ih' in name or 'bias_hh' in name: 
                    nn.init.constant_(param.data, 0)
                    if 'bias_ih' in name and isinstance(module, nn.LSTM):
                        nn.init.constant_(param.data[module.hidden_size : 2 * module.hidden_size], 1.0)
        
        elif isinstance(module, nn.BatchNorm1d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

In [14]:
class MixupDataset(Dataset):
    def __init__(self, X: np.ndarray, y: np.ndarray, alpha: float = 0.2):
        """
        Args:
            X (np.ndarray): Features (e.g., padded time series data).
                            Expected shape (num_samples, timesteps, features).
            y (np.ndarray): Labels (e.g., one-hot encoded or class indices).
                            Expected shape (num_samples, num_classes) for one-hot,
                            or (num_samples,) for class indices.
            alpha (float): Alpha parameter for the Beta distribution used in Mixup.
        """
        # Convert X and y to PyTorch tensors once
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32 if (y.ndim>1) else (torch.long)) # Use long for class indices
        self.alpha = alpha

    def __len__(self): return len(self.X)

    def __getitem__(self, idx):
        """
        Generates one sample of data for Mixup.
        """
        x, y = self.X[idx], self.y[idx]
        
        if self.alpha > 0:
            lam = np.random.beta(self.alpha, self.alpha)
            rand_idx = np.random.randint(0, len(self.X))
            x_rand, y_rand = self.X[rand_idx], self.y[rand_idx]
            
            x = lam * x + (1 - lam) * x_rand
            y = lam * y + (1 - lam) * y_rand
            
        return x, y

In [15]:
def label_smoothing_loss(pred, target, smoothing=0.1):
    """Label smoothing loss"""
    confidence = 1.0 - smoothing
    log_probs = F.log_softmax(pred, dim=-1)
    nll_loss = -log_probs.gather(dim=-1, index=target.unsqueeze(1))
    nll_loss = nll_loss.squeeze(1)
    smooth_loss = -log_probs.mean(dim=-1)
    loss = confidence * nll_loss + smoothing * smooth_loss
    return loss.mean()

# Data Preparation

In [16]:
from sklearn.model_selection import train_test_split

In [17]:
X_train, X_val, y_train, y_val = train_test_split(
        X.numpy(), y_ohe.numpy(), 
        test_size=0.2, random_state=SEED, stratify=y_int
)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
X_val = torch.tensor(X_val, dtype=torch.float32)
y_val = torch.tensor(y_val, dtype=torch.float32)


cw_vals = compute_class_weight('balanced', 
                               classes=np.arange(len(label_encoder.classes_)),
                               y=y_int)
class_weights = torch.FloatTensor(cw_vals).to(DEVICE)

In [18]:
# Data loaders
train_dataset = MixupDataset(X_train, y_train, alpha=MIXUP_ALPHA)
val_dataset = MixupDataset(X_val, y_val, alpha=0.0)  
    
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, worker_init_fn=worker_init_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

# Model Training

In [19]:
model = TriSenseNet(PAD_LEN, IMU_LEN, THM_TOF_SUMMARIES_LEN).to(DEVICE)
init_model_weights(model)

optimizer = AdamW(model.parameters(), lr=LR_INIT, weight_decay=WD)

steps_per_epoch = len(train_loader)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5*steps_per_epoch)

In [20]:
class F1EarlyStopping:
    def __init__(self, patience=10, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_f1 = 0
        self.early_stop = False

    def __call__(self, model, current_f1):
        if current_f1 > self.best_f1:
            self.best_f1 = current_f1
            self.counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            self.counter += 1
            if self.verbose:
                print(f"F1 EarlyStopping: {self.counter}/{self.patience}\n")
            if self.counter >= self.patience:
                self.early_stop = True

In [21]:
def train(model, train_loader, optimizer, class_weights, lr_scheduler=None):
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0
        
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
                    
        logits = model(batch_x)
                    
        # Handle mixup targets
        if batch_y.ndim == 2 and batch_y.shape[1] > 1:  # MixUp or one-hot
            sample_weights = torch.sum(batch_y * class_weights.unsqueeze(0), dim=1)
            log_probs = F.log_softmax(logits, dim=1)
            loss_vec = -torch.sum(log_probs * batch_y, dim=1)  # (B,)
            loss = (loss_vec * sample_weights).mean()
            targets = batch_y.argmax(dim=1)
        else:
            targets = batch_y.long()
            loss = label_smoothing_loss(logits, targets, smoothing=0.1)            
        
        ## Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ## Accmulate loss and accuracy
        total_loss += loss.item() * batch_x.size(0)
        total_correct += (logits.argmax(dim=1) == targets).sum().item()
        total_samples += batch_x.size(0)
        
    if lr_scheduler is not None:
        scheduler.step()

    ## Normalize loss and accuracy
    train_loss = total_loss/total_samples
    train_acc  = total_correct/total_samples

    return train_loss, train_acc

In [22]:
def evaluate(model, val_loader, lbl_encoder):
    model.eval()
    val_loss = 0.0
    total_correct, total_samples = 0, 0

    all_preds = []
    all_targets = []

    with torch.inference_mode():
        for batch_x, batch_y in val_loader:
            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
            logits = model(batch_x)

            # For both loss and metrics, assume batch_y is one-hot or integer labels
            if batch_y.ndim == 2:
                targets = batch_y.argmax(dim=1)
            else:
                targets = batch_y

            # Computes loss and predictions
            loss = F.cross_entropy(logits, targets)
            preds = logits.argmax(dim=1)

            # Track batch metrics
            val_loss += loss.item() * batch_x.size(0)
            total_correct += (preds == targets).sum().item()
            total_samples += batch_x.size(0)

            all_preds.append(preds.cpu())
            all_targets.append(targets.cpu())

    # Stack all predictions and targets
    y_pred_all = torch.cat(all_preds).numpy()
    y_val_all  = torch.cat(all_targets).numpy()

    # Compute custom hierarchical F1
    val_f1 = F1_score(y_val_all, y_pred_all, lbl_encoder, choice="weighted")
    val_loss = val_loss / total_samples
    val_acc  = total_correct / total_samples

    return val_loss, val_acc, val_f1

In [23]:
def train_eval(epochs, model, train_loader, val_loader, optimizer, class_weights, lbl_encoder, lr_scheduler=None):
    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []
    val_F1 = []
    history = {}
    
    for epoch in range(epochs):
        ## Trains model
        train_loss, train_acc = train(model, train_loader, optimizer, class_weights, lr_scheduler)

        ## Evaluates model
        val_loss, val_acc, val_f1 = evaluate(model, val_loader, lbl_encoder)

        ## Append metrics
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        val_F1.append(val_f1)

        ## Checks early stopping
        if early_stopper is not None:
            early_stopper(model, val_f1)
            if early_stopper.early_stop:
                print(f"\nEarly stopping triggered at epoch {epoch+1}\n")
                break
                
        ## Displays any result
        print(f"Epoch [{epoch+1}/{epochs}]")
        print(f"Train Accuracy: {train_acc:.3f}     Train Loss: {train_loss:.3f}")
        print(f"Val Accuracy:   {val_acc:.3f}       Val F1:     {val_f1:.2f}     Val Loss:   {val_loss:.3f}\n")

    ## Save results to history dictionary
    history["train_losses"] = train_losses
    history["train_accuracies"] = train_accuracies
    history["val_losses"] = val_losses
    history["val_accuracies"] = val_accuracies
    history["val_F1"] = val_F1
    return history

In [24]:
early_stopper = F1EarlyStopping(PATIENCE, verbose=True)

train_history = train_eval(
    EPOCHS, 
    model, 
    train_loader, 
    val_loader, 
    optimizer, 
    class_weights, 
    label_encoder, 
    scheduler
)

Epoch [1/200]
Train Accuracy: 0.148     Train Loss: 2.823
Val Accuracy:   0.324       Val F1:     0.55     Val Loss:   2.132

Epoch [2/200]
Train Accuracy: 0.257     Train Loss: 2.337
Val Accuracy:   0.408       Val F1:     0.62     Val Loss:   1.824

Epoch [3/200]
Train Accuracy: 0.336     Train Loss: 2.113
Val Accuracy:   0.454       Val F1:     0.65     Val Loss:   1.600

Epoch [4/200]
Train Accuracy: 0.382     Train Loss: 1.965
Val Accuracy:   0.515       Val F1:     0.69     Val Loss:   1.427

Epoch [5/200]
Train Accuracy: 0.434     Train Loss: 1.876
Val Accuracy:   0.521       Val F1:     0.70     Val Loss:   1.332

Epoch [6/200]
Train Accuracy: 0.478     Train Loss: 1.743
Val Accuracy:   0.553       Val F1:     0.72     Val Loss:   1.250

Epoch [7/200]
Train Accuracy: 0.489     Train Loss: 1.724
Val Accuracy:   0.571       Val F1:     0.72     Val Loss:   1.223

Epoch [8/200]
Train Accuracy: 0.517     Train Loss: 1.657
Val Accuracy:   0.609       Val F1:     0.75     Val Loss:  

In [25]:
joblib.dump(train_history, "train_history.joblib")

['train_history.joblib']