# New start

In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np
import pandas as pd
import wandb
import gc
from sklearn.model_selection import train_test_split
# Create dataset classes (using your BalancedDataset approach) and training function
class BalancedDataset(Dataset):
    def __init__(self, X, y, limit_per_label=1600):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
# Custom Dataset for validation with limit per class
class BalancedValidationDataset(Dataset):
    def __init__(self, X, y, limit_per_label=400):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
# Training function (similar to your ConvNet setup but using WandB)
def train_model_vit(model, train_loader, val_loader, test_loader, num_epochs=10, lr=1e-4, patience=5, device='cuda'):
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Re-sample training data at the start of each epoch
        train_loader.dataset.re_sample()
        model.train()
        train_loss = 0.0
        
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
            train_accuracy = (outputs.argmax(dim=1) == y_batch).float().mean()
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val, y_val = X_val.to(device), y_val.to(device)
                outputs = model(X_val)
                loss = criterion(outputs, y_val)
                val_loss += loss.item() * X_val.size(0)
                val_accuracy = (outputs.argmax(dim=1) == y_val).float().mean()
        
        # Test phase
        test_loss = 0.0
        test_accuracy = 0.0
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                outputs = model(X_test)
                loss = criterion(outputs, y_test)
                test_loss += loss.item() * X_test.size(0)
                test_accuracy = (outputs.argmax(dim=1) == y_test).float().mean()


        # Log metrics to WandB
        train_loss /= len(train_loader.dataset)
        val_loss /= len(val_loader.dataset)
        test_loss /= len(test_loader.dataset)
        wandb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch, 
                   "train_accuracy": train_accuracy.item(), "val_accuracy": val_accuracy.item(), 
                   "test_accuracy": test_accuracy.item(), "test_loss": test_loss})
        
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    return model
class VisionTransformer1D(nn.Module):
    def __init__(self, input_size=3748, num_classes=4, patch_size=5, dim=128, depth=12, heads=16, mlp_dim=256, dropout=0.2):
        super(VisionTransformer1D, self).__init__()

        # Store patch size and dimensionality for embedding
        self.patch_size = patch_size
        self.dim = dim

        # Patch Embedding layer
        self.patch_embed = nn.Linear(patch_size, dim)

        # Positional Encoding (initialize to a reasonable size, but we’ll adjust it dynamically)
        max_patches = (input_size + patch_size - 1) // patch_size  # Approximate max patches
        self.pos_embedding = nn.Parameter(torch.randn(1, max_patches, dim))

        # Transformer blocks
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout),
            depth
        )

        # MLP Head
        self.fc = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        # Handle input dimensions and ensure padding for patch divisibility
        batch_size, channels, seq_len = x.shape  # Assuming x has 3 dimensions
        x = x.squeeze(1) if channels == 1 else x  # Remove channel dimension if it's 1

        # Calculate required padding for divisibility by patch_size and pad input
        pad_length = (self.patch_size - (seq_len % self.patch_size)) % self.patch_size
        x = nn.functional.pad(x, (0, pad_length))
        
        # Dynamically calculate number of patches after padding
        num_patches = x.size(1) // self.patch_size
        x = x.view(batch_size, num_patches, self.patch_size)  # Reshape to patches
        
        # Embed patches and add positional encoding (resize pos_embedding if needed)
        if self.pos_embedding.size(1) != num_patches:
            self.pos_embedding = nn.Parameter(self.pos_embedding[:, :num_patches, :])
        x = self.patch_embed(x) + self.pos_embedding

        # Transformer forward pass
        x = self.transformer(x)

        # Classify based on the first token representation
        x = self.fc(x[:, 0])

        return x


In [10]:
batch_size = 128



# Example usage
if __name__ == "__main__":
    # Load and preprocess your data (example from original script)
    # Load and preprocess data
    X = pd.read_pickle("Pickles/fusionv0/train.pkl")
    y = X["label"]
    label_mapping = {'star': 0, 'binary_star': 1, 'galaxy': 2, 'agn': 3}
    y = y.map(label_mapping).values
    
    X = X.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "label"], axis=1).values
    
    # Read test data
    X_test = pd.read_pickle("Pickles/fusionv0/test.pkl")
    y_test = X_test["label"].map(label_mapping).values
    X_test = X_test.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "label"], axis=1).values
    
    # Split data
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # Clear memory
    del X, y
    gc.collect()

    # Convert to torch tensors and create datasets
    X_train = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1)
    X_val = torch.tensor(X_val, dtype=torch.float32).unsqueeze(1)
    y_train = torch.tensor(y_train, dtype=torch.long)
    y_val = torch.tensor(y_val, dtype=torch.long)

    train_dataset = BalancedDataset(X_train, y_train)
    val_dataset = BalancedValidationDataset(X_val, y_val)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(BalancedValidationDataset(torch.tensor(X_test, dtype=torch.float32).unsqueeze(1),
                                                    torch.tensor(y_test, dtype=torch.long)), batch_size=batch_size, shuffle=False)


test resampling

In [28]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=128, dim=32, depth=12, heads=16, mlp_dim=256, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-5, patience=30)
    
    # Save the model and finish WandB session
    wandb.finish()

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Early stopping triggered.


VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇██
test_accuracy,▁▁▄▅▄▃▃▇▆▆▆██▇▆▆▆▆▆▆▇▆▅▆▆▆▆▇▆▆▆▇█▆▇▇▇▇▆▇
test_loss,█▆▅▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▃▃▄▆▆▆▅▅▆▄▃▇██▆▆▄▅▇▆▇▇▆▅▆██▇▅▇▇▅▇▇▆█▅▅
train_loss,█▇▇▇▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▁▁▄▅▄▆▅▆▆▇▇▇▇▇▇▇█▇████████████▇██▇████
val_loss,█▇▆▆▅▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,90.0
test_accuracy,0.425
test_loss,0.93941
train_accuracy,0.62667
train_loss,0.95117
val_accuracy,0.52756
val_loss,0.96227


In [29]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=256, dim=32, depth=12, heads=16, mlp_dim=256, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-5, patience=30)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇███
test_accuracy,▁▁▄▄▅▇▆▆▆▇▇██▇▇████▇████████████████████
test_loss,█▇▆▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▃▄▃▆▆▆▇▅▆▇▇▆▇█▅▇▇▇▇▇▇█▇▇▆▆▇▆▇▆▆▆█▇▇▇▇▇
train_loss,█▇▇▆▅▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▂▁▂▁▁▁▂▁▁
val_accuracy,▁▂▄▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇█▇████████████████
val_loss,██▇▆▅▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,82.0
test_accuracy,0.625
test_loss,0.85897
train_accuracy,0.61333
train_loss,0.86807
val_accuracy,0.58268
val_loss,0.8871


In [30]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=128, dim=128, depth=12, heads=16, mlp_dim=256, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-5, patience=30)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▂▂▂▂▁▄▅▄█▄▇▄▄▆▇▇█▄▄▇▆▇▇▇▇▇▇▆▆▆▆▇▇▆▆▇▇▇▇▆
test_loss,██▇▇▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▂▁▃▃▄▅▆▃▇▅▄▆▆▆▆▃▅█▅▅▆▇▅▆▇▇▆▅▆▆▇▇▇▇▇▇▇▇▇▆
train_loss,█▇▇▆▆▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
val_accuracy,▂▁▃▄▇▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▆▇▇▇█▇▇▇█▇▇
val_loss,██▇▇▆▄▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,53.0
test_accuracy,0.55
test_loss,0.92131
train_accuracy,0.58667
train_loss,0.8899
val_accuracy,0.55118
val_loss,0.92642


In [31]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=512, dim=128, depth=12, heads=16, mlp_dim=256, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-5, patience=30)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇██
test_accuracy,▁▂▂▃▄▅▅▅▅▆▇▆▇▇▇▇▇▇▇▇████▇███████████████
test_loss,█▇▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▅▄▆▄▆▆▅▅▆▄▆▄▆▅▅▆▆▆▆▅▆▆▇▅▇▆▆▆▃█▇▅▄▆▅▆▇▅
train_loss,█▆▅▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▄▆▇▇█▇▇▇▇▇▇▇▇▇▇▇▇██▇█▇████████████████
val_loss,█▇▆▅▅▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,65.0
test_accuracy,0.75
test_loss,0.69016
train_accuracy,0.62667
train_loss,0.70097
val_accuracy,0.66142
val_loss,0.72708


In [32]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=1024, dim=128, depth=12, heads=16, mlp_dim=256, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-5, patience=100)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇███
test_accuracy,▁▃▄▅▆▇█▇██▇▇▇▇▇▇█▇▇▇▇▇█▇▇▇▇▇▇█▇▇█▇███▇█▇
test_loss,█▇▅▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▃▅▁▆▃▆▅▆█▇▅▂▇▄▅▅▅▄▆▃▃▆▆▃▄▄▆▆▅██▆▆▆▅▆▃▄
train_loss,█▄▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▆▆▆▆▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇▇██████
val_loss,█▇▅▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,143.0
test_accuracy,0.7
test_loss,0.60489
train_accuracy,0.69333
train_loss,0.55249
val_accuracy,0.73228
val_loss,0.65096


In [None]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=937, dim=140, depth=20, heads=7, mlp_dim=256, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-5, patience=100)
    
    # Save the model and finish WandB session
    wandb.finish()

In [33]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=3748, dim=128, depth=12, heads=4, mlp_dim=256, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-5, patience=100)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
test_accuracy,▁▁▄▅▅▆▇▇▇▇▇▇▇▇▇▇▇█▇▇▇█▇▇▇▇▇▇▇▆▆▇▇▇▆▇▇▆▇▅
test_loss,█▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▂▂▂▂▂▂▂▂▂▂
train_accuracy,▁▃▅▄▇▆▄▆▆▅▄▆▅▅█▇▅▅▆▇▆█▆█▅▅▇▆▇▅▇▆▇▇▆▆▆█▇▅
train_loss,█▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁
val_accuracy,▁▆▆▆▆▇▇▇▇▆▇▆▆▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇
val_loss,█▅▃▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▂▃▃▃▃▃▃▃

0,1
epoch,123.0
test_accuracy,0.725
test_loss,0.65202
train_accuracy,0.88
train_loss,0.46366
val_accuracy,0.73228
val_loss,0.71223


In [38]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=3748, dim=128, depth=40, heads=4, mlp_dim=256, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-5, patience=100)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇████
test_accuracy,█████▁▁████▁▁▁█▁▁▁▁▁▁▁█▁█▁▁▁▁▁▁▁▁▁██▁▁▁▁
test_loss,▄▅▄█▅▄▅▇▄▆▆▅▄▅▄▅▅▅▃▄▅▅▂▂▂▃▅▅▄▄▂▁▁▃▂▅▂▄▁▂
train_accuracy,▃▁▂▁▃▄▅▇▃▄▅▃▅▄▆▆▄▂▂▄▅▇▆▃▆▆▅▃▄▁▃█▅▆▇▆▆▄▄▇
train_loss,██▇█▇▆▆▇▇▇▅▃▄▃▅▅▃▄▄▃▃▄▄▄▂▄▂▃▃▃▃▂▂▁▂▁▃▁▂▂
val_accuracy,██▄█▁█▁▁▁██▄▄█▄█▄████▄███▄▄▄▄█▄▄█▄▄███▄█
val_loss,▇▇▇▇█▇▇▆▆▆▇▅▆▅▆▇▅▆▆▅▅▄▅▄▄▄▄▄▃█▃▃▄▃▄▂▂▁▃▂

0,1
epoch,119.0
test_accuracy,0.325
test_loss,1.38614
train_accuracy,0.26667
train_loss,1.38079
val_accuracy,0.30709
val_loss,1.37753


In [34]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=3748, dim=256, depth=20, heads=4, mlp_dim=256, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-5, patience=100)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇█
test_accuracy,▁▂▁▄▇▇▇▇▇▇█▇█████████████▇▇▇▇▇█▇▇▆█▇▇▇▇▇
test_loss,██▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▂▂▂▂▂▂
train_accuracy,▁▂▄▄▇▅▇▆▇▇▆▇▇▆▆▇▇▆▇▇▆▇▇▇▆▇▇▇█▇▇██▇█▇▆▆█▇
train_loss,█▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▄▄▇▇▇▇▇▇▇▇▇▇▇▇█▇█████▇████████████████
val_loss,█▇▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▂▂▁▁▁▂▂▁▂▂▂▂▁▂▂

0,1
epoch,116.0
test_accuracy,0.675
test_loss,0.70446
train_accuracy,0.72
train_loss,0.46491
val_accuracy,0.71654
val_loss,0.73869


In [35]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=3748, dim=256, depth=20, heads=4, mlp_dim=512, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-6, patience=100)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▁▂▂▁▂▅▅▄▆▅▆▆▇▅▇████▇███▇████████████████
test_loss,█████▇▆▄▄▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▂▁▁▁▂▄▄▅▅▆▇▅▆▆▆▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▆▅█▆▇█▆▆
train_loss,██████▇▅▄▄▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▁▁▁▅▅▅▅▅▅▆▇▇▇█████████████████████████
val_loss,███▅▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,197.0
test_accuracy,0.75
test_loss,0.5939
train_accuracy,0.73333
train_loss,0.60268
val_accuracy,0.72441
val_loss,0.6403


In [36]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=3748, dim=512, depth=20, heads=4, mlp_dim=512, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-6, patience=100)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇█
test_accuracy,▁▂▃▄▅▅▆▆▇▇▇▇▇▇▇████▇█▇█▇█▇██▇▇▇▇█▇███▇██
test_loss,█████▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▄▅▆▇▆▇▇▇▆▅▆▇▇▇▇▆▆▆▇▇▇█▇▇▇▆▇▇█▆▇█▇█▇▇▇█
train_loss,████▆▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▂▁▂▄▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇██████████████
val_loss,███▆▅▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,167.0
test_accuracy,0.7
test_loss,0.58999
train_accuracy,0.70667
train_loss,0.5827
val_accuracy,0.73228
val_loss,0.63297


In [37]:
    # Initialize WandB project
    wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
    # Initialize and train the model
    model_vit = VisionTransformer1D(num_classes=4, patch_size=3748, dim=512, depth=20, heads=4, mlp_dim=126, dropout=0.2)
    trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-6, patience=100)
    
    # Save the model and finish WandB session
    wandb.finish()

Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇█████
test_accuracy,▁▁▁▂▂▄▆▅▅▅▄▃▄▅▅▆▆▆▆▇▆▇▆▇▇▇▇▇▇▇▇▇███▇▇▇▇█
test_loss,███████▇▆▆▄▅▄▄▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▂▁▂▂▃▄▃▄▅▅▆▇▅▆▇▅█▆▇▇▇▇▇█▆▇▆█▇▇▆▆▆█▇▇▇▆▇▇
train_loss,████▇▅▅▄▄▃▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▁▂▃▅▅▆▆▇▇▇████████████████████████████
val_loss,██████▇▇▇▅▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,163.0
test_accuracy,0.75
test_loss,0.60869
train_accuracy,0.68
train_loss,0.6064
val_accuracy,0.70079
val_loss,0.66507


In [19]:
# Save the model
torch.save(model_vit.state_dict(), "Models/vit_model.pth")

# Print the model summary
print(model_vit)

# Print confusion matrix and classification report  
print_confusion_matrix_vit(trained_model, test_loader)

VisionTransformer1D(
  (patch_embed): Linear(in_features=17, out_features=128, bias=True)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Sequential(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=128, out_features=4, bias=True)
  )
)


Confusion Matrix: [[187 107  62  44]
 [141 151  73  35]
 [ 53  11 208 104]
 [ 18   6  70 306]]
Classification Report:               precision    recall  f1-score   support

           0       0.47      0.47      0.47       400
           1       0.55      0.38      0.45       400
           2       0.50      0.55      0.53       376
           3       0.63      0.77      0.69       400

    accuracy                           0.54      1576
   macro avg       0.54      0.54      0.53      1576
weighted avg       0.54      0.54      0.53      1576



In [4]:
# Initialize WandB project
wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
# Initialize and train the model
model_vit = VisionTransformer1D(input_size=3748, num_classes=4, patch_size=5, dim=64, depth=12, heads=4, mlp_dim=128, dropout=0.2)
trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-4, patience=30)

# Save the model and finish WandB session
wandb.finish()

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇██
test_accuracy,▇▇▇▇▆▇▃▇▄█▆▅▆▄▅▆▃▆▄▄▂▇▂▂▄▅▇▆▅▇▅▅▄▁▆▄▄▄▂▄
test_loss,█▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▅▃▅▄▅▄▄▆▅▇▇▅▇▆█▃▅▇▄▄▅▇▇▅▇▄▇▆▆▇▁▆▆▄▅▆▄█▆
train_loss,█▆▃▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▂▁▂▂▁▁▁▂▁▁▁▁▁▁▁
val_accuracy,▇▆▃▄▁▃▂▆▄▆▅▅▄▆▆▅▆▆▆▆█▆▆▇▆▇▇▅▇▆▆█▆▇█▆███▇
val_loss,█▅▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,49.0
test_accuracy,0.4
test_loss,1.04408
train_accuracy,0.56
train_loss,1.04456
val_accuracy,0.52756
val_loss,1.03123




Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇█
test_accuracy,▇▆▇█▇▄▆▅▅▂▄▄▅▆▄▂▄▁▂▂▃▁▂▂▂▂▆▂▄▃▂▆▅▂▂▁▂▄▄▂
test_loss,█▇▅▅▄▃▃▃▃▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▂▁▁▁▁▁
train_accuracy,▄▅▁▅▃▅▅▅▅█▅▅▅▄▆▂▆▆▆▅▄▅▅▆▄▅▅▇▆▅█▅▇▅▅▇▄▇▅▅
train_loss,█▆▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂
val_accuracy,▁▆▆▇▆█▇▆▆▅▆▇▇▆▆▆▆▆▇▆▆▆▆▆▆▆▆▆▇▆▅▆▆▆▆▇▇▇▇▆
val_loss,█▅▄▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,53.0
test_accuracy,0.35
test_loss,1.04072
train_accuracy,0.46667
train_loss,1.071
val_accuracy,0.50394
val_loss,1.03263


In [5]:
# Initialize WandB project
wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
# Initialize and train the model
model_vit = VisionTransformer1D(input_size=3748, num_classes=4, patch_size=25, dim=256, depth=8, heads=8, mlp_dim=256, dropout=0.2)
trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-4, patience=30)

# Save the model and finish WandB session
wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…



Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
test_accuracy,▁▃▆█▃▁▁▃█▃▃▅▅▅▅▆▃▃▃▆▃▅▆▅▅▅▃▃▆▆█▅▅▆▅▅▅▅▆▅
test_loss,█▄▄▃▃▃▄▂▂▃▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▂▁▂▁▂▂▂▂▂▂▂▃▂▂
train_accuracy,▁▁▃▃▆▇▃▃▄▆▃▄▃▄▄▅▆▃▃▅▅▃▆▃▅▆▅▃▃▂▅▄▄█▇▅▂▆▃▅
train_loss,█▄▄▃▃▃▃▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂▁▁▁
val_accuracy,▁▄▇▇▅▅▇▆▆█▆▆▆▆▇▇▇▇▆▇▇▆▇▅▆▅▆▆▇▆▆▆▆▅▆▆▆▇▅▆
val_loss,█▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▂▁▁▁▁▂▁▁▂▁▂▁▁▂▂▂▂▂▂▂

0,1
epoch,41.0
test_accuracy,0.55
test_loss,0.99951
train_accuracy,0.61333
train_loss,0.92496
val_accuracy,0.51969
val_loss,0.97739


In [None]:
# Initialize WandB project
wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
# Initialize and train the model
model_vit = VisionTransformer1D(input_size=3748, num_classes=4, patch_size=10, dim=1024, depth=10, heads=2, mlp_dim=512, dropout=0.2)
trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-4, patience=30)

# Save the model and finish WandB session
wandb.finish()

In [None]:
   
torch.cuda.empty_cache()
# Initialize WandB project
wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
# Initialize and train the model
model_vit = VisionTransformer1D(patch_size=10)
trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=50, lr=1e-3, patience=10)

# Save the model and finish WandB session
wandb.finish()

# Toy

In [6]:
# Print the model summary
model_vit = VisionTransformer1D(patch_size=10, num_classes=4, dim=64, depth=2, heads=8, mlp_dim=128, dropout=0.1)
print(model_vit)

VisionTransformer1D(
  (patch_embed): Linear(in_features=10, out_features=64, bias=True)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Sequential(
    (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=64, out_features=4, bias=True)
  )
)


In [7]:
torch.cuda.empty_cache()
# Initialize WandB project
wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
# Initialize and train the model
trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=50, lr=1e-3, patience=10)
# Save the model and finish WandB session
wandb.finish()

wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: joaoc (joaoc-university-of-southampton). Use `wandb login --relogin` to force relogin


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Early stopping triggered.


VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇██
test_accuracy,▁▄▄▅▆▅▆▇▆▆▇▆▄▇▇█▇▄▇
test_loss,█▆▄▄▃▄▃▂▃▂▂▂▂▃▂▁▁▂▂
train_accuracy,▂▁▅▅▃▆▆▃▃▅▃█▃▄▃▆▆▄▅
train_loss,█▄▄▃▃▃▂▂▂▂▁▂▂▁▁▁▁▁▁
val_accuracy,▁▃▄█▅▂▅▅▂▆▂▃▆▅▆▄▆▇▇
val_loss,█▅▅▇▄▄▃▄▄▂▂▃▃▃▂▂▁▃▄

0,1
epoch,18.0
test_accuracy,6.87813
test_loss,1.04064
train_accuracy,0.54667
train_loss,1.02085
val_accuracy,0.51969
val_loss,1.04904


: 

# SpectraTR Code

In [20]:
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # self.to_qkv = nn.Linear(dim, inner_dim , bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # qkv = self.to_qkv(x).chunk(1, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 1, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        #self.to_patch_embedding = nn.Sequential(
        #    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
        #    nn.Linear(patch_dim, dim),
        #)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (p s) -> b (p c) s', p=patch_dim),
            nn.Linear(patch_dim, dim)
        )


        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img): 
        print(img.shape)
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [24]:
   
torch.cuda.empty_cache()
# Initialize WandB project
#wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
# Initialize and train the model
model_vit = ViT(patch_size=2, image_size=(3748), num_classes=4, dim=64, depth=2, heads=8, mlp_dim=128, dropout=0.1)
print(model_vit)
# print the number of parameters
print('Number of parameters: ', sum(p.numel() for p in model_vit.parameters()))

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (p s) -> b (p c) s', p=4)
    (1): Linear(in_features=4, out_features=64, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-1): 2 x ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (to_qkv): Linear(in_features=64, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=64, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=True)
              (1): GELU(approximate='none')
              (2): Dropout

In [25]:
trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=50, lr=1e-4, patience=10)

# Save the model and finish WandB session
wandb.finish()

torch.Size([128, 1, 3748])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x937 and 4x64)