In [None]:
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [None]:
# Useful probing dataset class definitions

class LogisticRegressionCLSDataset(Dataset):
    """Map-style dataset pairing one embedding per sample with an integer class label.

    Args:
        embeddings: Indexable collection of per-sample embeddings (e.g. a tensor
            whose first axis is samples). Items are returned unchanged.
        labels: Class labels as a numpy array, torch tensor, or any sequence
            accepted by ``torch.as_tensor`` (e.g. a list); stored as an int64 tensor.

    ``__getitem__`` returns ``(embedding, label)`` with the label as a Python int;
    ``len()`` is the number of labels.
    """

    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        # as_tensor accepts ndarrays (sharing memory, like from_numpy), tensors
        # and plain Python sequences, so all label containers take one path.
        self.labels = torch.as_tensor(labels).long()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.embeddings[idx]
        y = self.labels[idx].item()
        return x, y

class LogisticRegressionPatchDataset(Dataset):
    """Map-style dataset pairing per-sample patch-token embeddings with integer labels.

    Args:
        embeddings: Indexable collection of per-sample embeddings (tensor or
            numpy array whose first axis is samples).
        labels: Class labels as a numpy array, torch tensor, or any sequence
            accepted by ``torch.as_tensor`` (e.g. a list); stored as an int64 tensor.
        flatten: If True, each item is reshaped to 1-D before being returned.

    ``__getitem__`` returns ``(embedding, label)`` with the label as a Python int;
    ``len()`` is the number of labels.
    """

    def __init__(self, embeddings, labels, flatten=False):
        self.embeddings = embeddings
        # One normalisation path for ndarrays, tensors and plain sequences.
        self.labels = torch.as_tensor(labels).long()
        self.flatten = flatten

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.embeddings[idx]
        if self.flatten:
            # reshape (not view): works for non-contiguous tensor slices and for
            # numpy items, where ndarray.view(-1) would treat -1 as a dtype.
            x = x.reshape(-1)
        y = self.labels[idx].item()
        return x, y


In [None]:
# The Linear Probe Classes
class NumerosityLinearProbeCLS(nn.Module):
    """Single linear-layer probe applied to one embedding vector per sample.

    The class name suggests the vector is a CLS-token embedding; the layer only
    requires the last input dimension to equal ``input_dim``.

    Args:
        input_dim: Size of the last dimension of the input.
        num_classes: Number of output classes.

    ``forward`` returns ``(logits, probs)`` where ``probs`` is the softmax of
    ``logits`` over the last dimension.
    """

    def __init__(self, input_dim=4096, num_classes=10):
        super().__init__()
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        scores = self.classifier(x)
        return scores, scores.softmax(dim=-1)
class NumerosityLinearProbeFlattenPatch(nn.Module):
    """Linear probe over all token embeddings of a sample, flattened into one vector.

    Args:
        input_tokens: Number of tokens per sample.
        input_dim: Embedding size per token.
        num_classes: Number of output classes.

    ``forward`` accepts ``(batch, input_tokens, input_dim)`` or an already
    flattened ``(batch, input_tokens * input_dim)`` tensor and returns
    ``(logits, probs)`` where ``probs`` is the softmax of ``logits`` over the
    last dimension.
    """

    def __init__(self, input_tokens=576, input_dim=4096, num_classes=10):
        super(NumerosityLinearProbeFlattenPatch, self).__init__()
        self.input_tokens = input_tokens
        self.input_dim = input_dim
        self.classifier = nn.Linear(input_tokens * input_dim, num_classes)

    def forward(self, x):
        # x: (batch_size, input_tokens, input_dim) -> (batch_size, input_tokens * input_dim).
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or sliced tensors, copying only when it must.
        x_flat = x.reshape(x.size(0), -1)
        logits = self.classifier(x_flat)
        probs = F.softmax(logits, dim=-1)
        return logits, probs

In [None]:
def _probe_to_numpy(values, as_float=False):
    """Return ``values`` in a form numpy/sklearn can index; non-tensors pass through.

    Torch tensors are detached (``.numpy()`` refuses tensors that require grad)
    and moved to CPU. With ``as_float`` they are also upcast to float32, which
    is required for dtypes numpy cannot represent, such as bfloat16.
    """
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu()
        if as_float:
            values = values.float()
        return values.numpy()
    if hasattr(values, 'cpu'):
        return values.cpu().numpy()
    return values


def create_stratified_datasets(embeddings, labels, patch=False, flatten=False,
                               test_size=0.2, val_size=0.25, random_state=42):
    """Split embeddings/labels into stratified train/val/test probing datasets.

    The data is first split into train+val and test (``test_size`` of the whole),
    then train+val is split again into train and val (``val_size`` is a fraction
    of train+val, not of the whole). With the defaults this is a 60/20/20
    train/val/test split. Both splits are stratified on the labels and seeded
    with ``random_state``. Split sizes and per-split class counts are printed.

    Args:
        embeddings: Torch tensor or array-like of per-sample embeddings; the first
            axis indexes samples. Converted to float32.
        labels: Torch tensor or array-like of integer class labels, one per sample.
        patch: If True, wrap splits in LogisticRegressionPatchDataset, otherwise
            in LogisticRegressionCLSDataset.
        flatten: Forwarded to LogisticRegressionPatchDataset; ignored when
            ``patch`` is False.
        test_size: Fraction of all samples held out for the test split.
        val_size: Fraction of the remaining train+val samples used for validation.
        random_state: Seed passed to both ``train_test_split`` calls.

    Returns:
        dict with keys ``train_dataset``, ``val_dataset``, ``test_dataset``,
        ``train_embeddings``, ``val_embeddings``, ``test_embeddings`` (float32
        tensors) and ``train_labels``, ``val_labels``, ``test_labels`` (int64 tensors).
    """
    embeddings_np = _probe_to_numpy(embeddings, as_float=True)
    labels_np = _probe_to_numpy(labels)

    X_trainval, X_test, y_trainval, y_test = train_test_split(
        embeddings_np,
        labels_np,
        test_size=test_size,
        stratify=labels_np,
        random_state=random_state,
    )

    X_train, X_val, y_train, y_val = train_test_split(
        X_trainval,
        y_trainval,
        test_size=val_size,
        stratify=y_trainval,
        random_state=random_state,
    )

    # The split arrays are fresh copies, so as_tensor can reuse their memory
    # instead of copying once more (torch.tensor always copies).
    X_train = torch.as_tensor(X_train, dtype=torch.float32)
    X_val   = torch.as_tensor(X_val, dtype=torch.float32)
    X_test  = torch.as_tensor(X_test, dtype=torch.float32)
    y_train = torch.as_tensor(y_train, dtype=torch.long)
    y_val   = torch.as_tensor(y_val, dtype=torch.long)
    y_test  = torch.as_tensor(y_test, dtype=torch.long)

    if patch:
        train_dataset = LogisticRegressionPatchDataset(X_train, y_train, flatten=flatten)
        val_dataset   = LogisticRegressionPatchDataset(X_val, y_val, flatten=flatten)
        test_dataset  = LogisticRegressionPatchDataset(X_test, y_test, flatten=flatten)
    else:
        train_dataset = LogisticRegressionCLSDataset(X_train, y_train)
        val_dataset   = LogisticRegressionCLSDataset(X_val, y_val)
        test_dataset  = LogisticRegressionCLSDataset(X_test, y_test)

    print("Dataset splits:")
    print(f"  Train: {len(train_dataset)} samples")
    print(f"  Val:   {len(val_dataset)} samples")
    print(f"  Test:  {len(test_dataset)} samples")
    print(f"  Total: {len(train_dataset) + len(val_dataset) + len(test_dataset)} samples")

    print("\nClass distribution:")
    for split_name, split_labels in (("Train:", y_train), ("Val:  ", y_val), ("Test: ", y_test)):
        classes, counts = np.unique(split_labels.numpy(), return_counts=True)
        # tolist() gives plain ints so the printed dict is readable.
        print(split_name, dict(zip(classes.tolist(), counts.tolist())))

    return {
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'test_dataset': test_dataset,
        'train_embeddings': X_train,
        'val_embeddings': X_val,
        'test_embeddings': X_test,
        'train_labels': y_train,
        'val_labels': y_val,
        'test_labels': y_test
    }

def _probe_train_epoch(model, loader, criterion, optimizer, device, label_offset, max_grad_norm):
    """Run one training pass over ``loader``; return (summed loss, samples counted).

    Batches whose loss is NaN/inf, or that raise RuntimeError, are skipped and
    do not contribute to the returned totals.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0

    for batch_idx, (x_batch, y_batch) in enumerate(loader):
        x_batch = x_batch.to(device, dtype=torch.float32, non_blocking=True)
        y_batch = y_batch.to(device, dtype=torch.long, non_blocking=True) - label_offset

        optimizer.zero_grad()

        try:
            logits, _ = model(x_batch)
            loss = criterion(logits, y_batch)

            if not torch.isfinite(loss):
                continue

            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            loss_sum += loss.item() * x_batch.size(0)
            n_seen += x_batch.size(0)

        except RuntimeError as e:
            print(f"⚠️ RuntimeError in training batch {batch_idx}: {e}")

    return loss_sum, n_seen


def _probe_eval_epoch(model, loader, criterion, device, label_offset):
    """Return the sample-weighted mean loss over ``loader`` computed without gradients.

    Batches whose loss is NaN/inf, or that raise RuntimeError, are skipped;
    returns ``float('inf')`` when no batch could be scored.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0

    with torch.no_grad():
        for batch_idx, (x_batch, y_batch) in enumerate(loader):
            try:
                x_batch = x_batch.to(device, dtype=torch.float32, non_blocking=True)
                y_batch = y_batch.to(device, dtype=torch.long, non_blocking=True) - label_offset

                logits, _ = model(x_batch)
                loss = criterion(logits, y_batch)

                if not torch.isfinite(loss):
                    continue

                loss_sum += loss.item() * x_batch.size(0)
                n_seen += x_batch.size(0)

            except RuntimeError as e:
                print(f"⚠️ RuntimeError in validation batch {batch_idx}: {e}")

    return loss_sum / n_seen if n_seen > 0 else float('inf')


def simple_train_model(model, train_loader, val_loader, optimizer, scheduler, device,
                       max_epochs=50, patience=10, label_offset=1, max_grad_norm=1.0,
                       checkpoint_path='best_model.pt'):
    """Train a probe with cross-entropy, early stopping, and best-epoch restoration.

    ``model(x)`` must return a tuple whose first element is the logits. Labels
    from the loaders are shifted by ``label_offset`` before the loss (the default
    of 1 maps 1-indexed labels to the 0-indexed targets CrossEntropyLoss needs).

    Args:
        model: Module to train in place.
        train_loader: Iterable of ``(x, y)`` training batches.
        val_loader: Iterable of ``(x, y)`` validation batches.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Object whose ``step(val_loss)`` is called each epoch with a
            finite validation loss (the ReduceLROnPlateau API), or None.
        device: Device the batches are moved to.
        max_epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a lower
            validation loss.
        label_offset: Subtracted from every label before computing the loss.
        max_grad_norm: Gradient-norm clipping threshold; None disables clipping.
        checkpoint_path: If not None, the best state_dict is also saved here
            whenever validation loss improves. Restoration always uses the
            in-memory copy, so a stale file from another run is never loaded.

    Returns:
        ``(model, train_losses, val_losses)``. ``model`` holds the weights from
        the epoch with the lowest validation loss, or its last weights if no
        epoch produced a finite improvement. The lists hold per-epoch mean losses.
    """
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    best_state = None
    epochs_no_improve = 0

    train_losses = []
    val_losses = []

    for epoch in range(max_epochs):
        loss_sum, n_seen = _probe_train_epoch(
            model, train_loader, criterion, optimizer, device, label_offset, max_grad_norm
        )
        if n_seen == 0:
            print(f"⚠️ No usable training batches in epoch {epoch+1}; stopping.")
            break
        train_loss = loss_sum / n_seen

        val_loss = _probe_eval_epoch(model, val_loader, criterion, device, label_offset)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1:2d}/{max_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Only feed the scheduler real measurements (val_loss is inf when no
        # validation batch could be scored).
        if scheduler is not None and math.isfinite(val_loss):
            scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Clone so later optimizer steps don't mutate the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            if checkpoint_path is not None:
                torch.save(best_state, checkpoint_path)
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("⚠️ No epoch improved the validation loss. Returning last model state.")

    return model, train_losses, val_losses