In [72]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mosei_dataset import MOSEIDataset
from tqdm import tqdm
from tbje import TBJENew
import numpy as np
import torch.optim as optim
from torch.nn import init
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
data_path = "tensors.pkl"
# Both splits are built from the same file; the second argument presumably
# selects the partition inside it (TODO confirm against MOSEIDataset).
# NOTE(review): if MOSEIDataset unpickles this file, only load trusted data.
train_dataset, val_dataset = (MOSEIDataset(data_path, split) for split in ("train", "val"))

In [147]:
# Size of dim 1 of the sample elements at positions 1, 3 and 4 — the same
# positions general_epoch feeds to the model as its three inputs.
VIDEO_DIM, TEXT_DIM, AUDIO_DIM = (train_dataset[0][i].shape[1] for i in (1, 3, 4))

In [152]:
def r_squared(y_pred, y_true):
    """
    Coefficient of determination (R^2) of ``y_pred`` against ``y_true``.

    Both tensors are detached and copied to host NumPy arrays first.
    Returns 0.0 when the targets have no spread (total sum of squares is not
    positive), where R^2 is undefined.
    """
    truth = y_true.cpu().detach().numpy()
    guess = y_pred.cpu().detach().numpy()
    residual_sum = np.sum((truth - guess) ** 2)
    total_sum = np.sum((truth - np.mean(truth)) ** 2)
    if not total_sum > 0:
        return 0.0
    return 1 - residual_sum / total_sum

def mae(preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Mean Absolute Error between two 1D tensors.

    Args:
        preds (torch.Tensor): 1D tensor of predictions.
        labels (torch.Tensor): 1D tensor of ground-truth values.

    Returns:
        torch.Tensor: 0-dim tensor holding mean(|preds - labels|).

    Raises:
        ValueError: if either input is not one-dimensional.
    """
    both_1d = preds.ndim == 1 and labels.ndim == 1
    if not both_1d:
        raise ValueError("Both preds and labels must be 1D tensors.")
    return (preds - labels).abs().mean()

def pearson_corr(pred, label):
    """
    Pearson correlation coefficient between two 1D tensors.

    The tensors are detached and copied to host NumPy arrays. ``np.corrcoef``
    then builds the 2x2 correlation matrix, and its off-diagonal entry ``[0, 1]``
    is returned. A constant input may yield NaN.
    """
    pred_np, label_np = (t.detach().cpu().numpy() for t in (pred, label))
    matrix = np.corrcoef(pred_np, label_np)
    return matrix[0, 1]

def class_from_string(class_path: str, *args, **kwargs):
    """
    Resolve a dotted ``"module_name.ClassName"`` path to the named attribute.

    The class object itself is returned, not an instance; ``*args`` and
    ``**kwargs`` are accepted but ignored.
    """
    module_path, attr_name = class_path.rsplit(".", 1)
    owner_module = __import__(module_path, fromlist=[attr_name])
    return getattr(owner_module, attr_name)

def general_epoch(model, dataloader, loss_fn, optimizer=None, device='cuda', target_index=0):
    """
    Run one epoch of training (``optimizer`` given) or evaluation (``optimizer`` None).

    Each batch from ``dataloader`` is unpacked as ``(*inputs, labels)``. Elements
    1, 3 and 4 of ``inputs`` are cast to float64 and passed to ``model`` in that
    order. Column ``target_index`` of ``labels`` is the regression target.
    NOTE(review): these positions presumably hold the video/text/audio features
    (cf. VIDEO_DIM/TEXT_DIM/AUDIO_DIM) — confirm against MOSEIDataset.

    Args:
        model (torch.nn.Module): model to train/evaluate. It is cast to float64
            and moved to ``device`` on every call.
        dataloader (iterable): yields batches as described above.
        loss_fn (callable): called as ``loss_fn(outputs, labels.unsqueeze(1))``.
            Assumed to return the batch *mean* (the default ``reduction='mean'``).
        optimizer (torch.optim.Optimizer or None): if given, the model is trained;
            otherwise it runs in eval mode with autograd disabled.
        device (str): device to run the computation on.
        target_index (int): label column to regress on (default 0).

    Returns:
        tuple: ``(epoch_loss, epoch_r, epoch_mae, epoch_r2)``. These are the
        per-sample average loss (float), Pearson r, MAE (0-dim tensor) and R^2,
        all computed over every sample of the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is eval mode
    model = model.to(device=device, dtype=torch.float64)

    running_loss = 0.0
    all_preds = []
    all_labels = []
    total_samples = 0

    # Only build autograd graphs when we are going to backpropagate.
    with torch.set_grad_enabled(training):
        for *inputs, labels in tqdm(dataloader):
            inputs = [inputs[i].to(device, dtype=torch.float64) for i in (1, 3, 4)]
            labels = labels[:, target_index].to(device, dtype=torch.float64)  # regression targets

            outputs = model(*inputs)
            loss = loss_fn(outputs, labels.unsqueeze(1))

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # loss is a batch mean: weight it by batch size so the epoch average
            # is per sample regardless of how the batches are sized.
            running_loss += loss.item() * batch_size
            # detach() so stored predictions don't keep each batch's graph alive;
            # reshape(-1) (not squeeze()) keeps a size-1 batch 1D for torch.cat.
            all_preds.append(outputs.detach().reshape(-1))
            all_labels.append(labels)
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")

    epoch_loss = running_loss / total_samples
    # Zero out NaN predictions so a diverged batch doesn't turn every metric into NaN.
    all_preds = torch.nan_to_num(torch.cat(all_preds), nan=0.0)
    all_labels = torch.cat(all_labels)
    epoch_r = pearson_corr(all_preds, all_labels)
    epoch_mae = mae(all_preds, all_labels)
    epoch_r2 = r_squared(all_preds, all_labels)

    return epoch_loss, epoch_r, epoch_mae, epoch_r2

def train_model(model, train_loader, valid_loader, loss_fn, optimizer, workdir,
                device='cuda', num_epochs=25, patience=5):
    """
    Train a regression model with per-epoch validation and early stopping.

    Every epoch writes a checkpoint ``epoch{N}_valloss{L}_valmae{M}.pt`` into
    ``workdir``. Whenever the validation loss improves, the weights are also
    written to ``workdir/best.pt`` and snapshotted in memory. Before returning,
    the model is reloaded with the best snapshot.

    Args:
        model (torch.nn.Module): the model to train.
        train_loader (DataLoader): DataLoader for the training set.
        valid_loader (DataLoader): DataLoader for the validation set.
        loss_fn (callable): loss function.
        optimizer (torch.optim.Optimizer): optimizer.
        workdir (str): directory for checkpoints; created if it does not exist.
        device (str): device to use.
        num_epochs (int): maximum number of epochs.
        patience (int): epochs without validation improvement before stopping early.

    Returns:
        model: the same model object, holding the weights of the epoch with the
        lowest validation loss. If no epoch ever improved (e.g. every validation
        loss was NaN, or ``num_epochs`` is 0), it holds its final weights.
    """
    import os  # local import keeps this cell self-contained

    os.makedirs(workdir, exist_ok=True)  # torch.save does not create directories

    best_val_loss = float('inf')
    best_state = None
    epochs_without_improve = 0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Training phase
        train_loss, train_r, train_mae, train_r2 = general_epoch(model, train_loader, loss_fn, optimizer, device)
        print(f"Train Loss: {train_loss:.4f} r: {train_r:.4f} MAE: {train_mae:.4f} R^2: {train_r2}")

        # Validation phase
        val_loss, val_r, val_mae, val_r2 = general_epoch(model, valid_loader, loss_fn, optimizer=None, device=device)
        print(f"Val   Loss: {val_loss:.4f} r: {val_r:.4f} MAE: {val_mae:.4f} R^2: {val_r2}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improve = 0
            # Snapshot on CPU so later epochs cannot overwrite the best weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            print("Validation loss improved, saving model...")
            torch.save(best_state, os.path.join(workdir, "best.pt"))
        else:
            epochs_without_improve += 1
            print(f"No improvement for {epochs_without_improve} epoch(s).")
        torch.save(model.state_dict(),
                   os.path.join(workdir, f"epoch{epoch}_valloss{val_loss:.2f}_valmae{val_mae:.2f}.pt"))

        if epochs_without_improve >= patience:
            print("Early stopping triggered.")
            break

        print("-" * 30)

    if best_state is not None:
        # load_state_dict copies into the existing parameters (device is preserved).
        model.load_state_dict(best_state)

    return model

def initialize_weights(m):
    """
    Visitor for ``nn.Module.apply`` that re-initialises every ``nn.Linear``.

    Weights get Kaiming-uniform init (fan-in mode, ReLU gain). Biases, when
    present, are set to zero. Any other module type is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
    if m.bias is not None:
        init.constant_(m.bias, 0)

In [161]:
# --- Model hyper-parameters (passed positionally to TBJENew) ---
INTERNAL_DIM = 256
N_HEADS = 8
MLP_DIM = 1024
N_LAYERS = 2

# --- Optimisation ---
BATCH_SIZE = 32
LR = 1e-6            # Adam learning rate
WEIGHT_DECAY = 0
EPOCHS = 50          # maximum epochs
PATIENCE = 5         # epochs without validation improvement before early stop

# --- Output ---
WORKDIR = "tbje-test-round"  # checkpoint directory

# --- Reproducibility ---
# Weight init and DataLoader shuffling draw from the global RNGs; seed them
# before the model and loaders are built.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [162]:
model = TBJENew(VIDEO_DIM, TEXT_DIM, AUDIO_DIM, INTERNAL_DIM, N_HEADS, MLP_DIM, N_LAYERS)
model.apply(initialize_weights)  # re-initialise every nn.Linear submodule
train_dataloader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)
# Validation metrics are computed over the whole split, so shuffling buys
# nothing. Shuffling would also consume the global RNG and perturb the
# training shuffle order.
val_dataloader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False)
loss_fn = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

In [163]:
# Anomaly detection adds heavy per-op overhead and is meant only for debugging.
# Switch it on temporarily when hunting NaN/Inf gradients.
DETECT_ANOMALY = False
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
with torch.autograd.set_detect_anomaly(DETECT_ANOMALY):
    model = train_model(model, train_dataloader, val_dataloader, loss_fn, optimizer, WORKDIR,
                        device=DEVICE, num_epochs=EPOCHS, patience=PATIENCE)

Epoch 1/50


 19%|█▉        | 6/32 [00:16<01:13,  2.83s/it]


KeyboardInterrupt: 