# New start

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np
import pandas as pd
import wandb
import gc
from sklearn.model_selection import train_test_split
# Create dataset classes (using your BalancedDataset approach) and training function
class BalancedDataset(Dataset):
    def __init__(self, X, y, limit_per_label=1600):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
# Custom Dataset for validation with limit per class
class BalancedValidationDataset(Dataset):
    def __init__(self, X, y, limit_per_label=400):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices
    
    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
# Training function (similar to your ConvNet setup but using WandB)
def train_model_vit(model, train_loader, val_loader, test_loader, num_epochs=500, lr=1e-4, max_patience=5, device='cuda'):
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    best_test_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Re-sample training data at the start of each epoch
        train_loader.dataset.re_sample()
        model.train()
        train_loss = 0.0
        
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
            train_accuracy = (outputs.argmax(dim=1) == y_batch).float().mean()
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val, y_val = X_val.to(device), y_val.to(device)
                outputs = model(X_val)
                loss = criterion(outputs, y_val)
                val_loss += loss.item() * X_val.size(0)
                val_accuracy = (outputs.argmax(dim=1) == y_val).float().mean()
        
        # Test phase
        test_loss = 0.0
        test_accuracy = 0.0
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                outputs = model(X_test)
                loss = criterion(outputs, y_test)
                test_loss += loss.item() * X_test.size(0)
                test_accuracy = (outputs.argmax(dim=1) == y_test).float().mean()


        # Log metrics to WandB
        train_loss /= len(train_loader.dataset)
        val_loss /= len(val_loader.dataset)
        test_loss /= len(test_loader.dataset)
        wandb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch, 
                   "train_accuracy": train_accuracy.item(), "val_accuracy": val_accuracy.item(), 
                   "test_accuracy": test_accuracy.item(), "test_loss": test_loss})
        
        # Early stopping
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            patience = max_patience
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    return model
class VisionTransformer1D(nn.Module):
    def __init__(self, input_size=3748, num_classes=4, patch_size=5, dim=128, depth=12, heads=16, mlp_dim=256, dropout=0.2):
        super(VisionTransformer1D, self).__init__()

        # Store patch size and dimensionality for embedding
        self.patch_size = patch_size
        self.dim = dim

        # Patch Embedding layer
        self.patch_embed = nn.Linear(patch_size, dim)

        # Positional Encoding (initialize to a reasonable size, but we’ll adjust it dynamically)
        max_patches = (input_size + patch_size - 1) // patch_size  # Approximate max patches
        self.pos_embedding = nn.Parameter(torch.randn(1, max_patches, dim))

        # Transformer blocks
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout), depth)

        # MLP Head
        self.fc = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, x):
        # Handle input dimensions and ensure padding for patch divisibility
        batch_size, channels, seq_len = x.shape  # Assuming x has 3 dimensions
        x = x.squeeze(1) if channels == 1 else x  # Remove channel dimension if it's 1

        # Calculate required padding for divisibility by patch_size and pad input
        pad_length = (self.patch_size - (seq_len % self.patch_size)) % self.patch_size
        x = nn.functional.pad(x, (0, pad_length))
        
        # Dynamically calculate number of patches after padding
        num_patches = x.size(1) // self.patch_size
        x = x.view(batch_size, num_patches, self.patch_size)  # Reshape to patches
        
        # Embed patches and add positional encoding (resize pos_embedding if needed)
        if self.pos_embedding.size(1) != num_patches:
            self.pos_embedding = nn.Parameter(self.pos_embedding[:, :num_patches, :])
        x = self.patch_embed(x) + self.pos_embedding

        # Transformer forward pass
        x = self.transformer(x)

        # Classify based on the first token representation
        x = self.fc(x[:, 0])

        return x


In [3]:
batch_size = 256



# Example usage
if __name__ == "__main__":
    # Load and preprocess your data (example from original script)
    # Load and preprocess data
    X = pd.read_pickle("Pickles/fusionv0/trainv2.pkl")
    y = X["label"]
    label_mapping = {'star': 0, 'binary_star': 1, 'galaxy': 2, 'agn': 3}
    y = y.map(label_mapping).values
    X = X.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "label", "obsid"], axis=1).values
    
    # Read test data
    X_test = pd.read_pickle("Pickles/fusionv0/testv2.pkl")
    y_test = X_test["label"].map(label_mapping).values
    X_test = X_test.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "label", "obsid"], axis=1).values
    
    # Split data
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    # Clear memory
    del X, y
    gc.collect()

    # Convert to torch tensors and create datasets
    X_train = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1)
    X_val = torch.tensor(X_val, dtype=torch.float32).unsqueeze(1)
    y_train = torch.tensor(y_train, dtype=torch.long)
    y_val = torch.tensor(y_val, dtype=torch.long)

    train_dataset = BalancedDataset(X_train, y_train)
    val_dataset = BalancedValidationDataset(X_val, y_val)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(BalancedValidationDataset(torch.tensor(X_test, dtype=torch.float32).unsqueeze(1),
                                                    torch.tensor(y_test, dtype=torch.long)), batch_size=batch_size, shuffle=False)


In [None]:
# Define the hyperparameters
num_classes = 4
patch_size = 3748
dim = 128
depth = 20
heads = 20
mlp_dim = 512
dropout = 0.3
batch_size = 128
lr = 1e-7
patience = 150
num_epochs = 1000


# Define the config dictionary object
config = {"num_classes": num_classes, "patch_size": patch_size, "dim": dim, "depth": depth, "heads": heads, "mlp_dim": mlp_dim, 
          "dropout": dropout, "batch_size": batch_size, "lr": lr, "patience": patience}

# Initialize WandB project
wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
model_vit = VisionTransformer1D(num_classes=num_classes, patch_size=patch_size, dim=dim, depth=depth, heads=heads, mlp_dim=mlp_dim, dropout=dropout)
trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=num_epochs, lr=lr, patience=patience)

# Save the model and finish WandB session
wandb.finish()

In [19]:
# Save the model
torch.save(model_vit.state_dict(), "Models/vit_model.pth")

# Print the model summary
print(model_vit)

# Print confusion matrix and classification report  
print_confusion_matrix_vit(trained_model, test_loader)

VisionTransformer1D(
  (patch_embed): Linear(in_features=17, out_features=128, bias=True)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Sequential(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=128, out_features=4, bias=True)
  )
)


Confusion Matrix: [[187 107  62  44]
 [141 151  73  35]
 [ 53  11 208 104]
 [ 18   6  70 306]]
Classification Report:               precision    recall  f1-score   support

           0       0.47      0.47      0.47       400
           1       0.55      0.38      0.45       400
           2       0.50      0.55      0.53       376
           3       0.63      0.77      0.69       400

    accuracy                           0.54      1576
   macro avg       0.54      0.54      0.53      1576
weighted avg       0.54      0.54      0.53      1576



# Sweep

In [None]:
import wandb

# Define the hyperparameters
num_classes = 4
num_epochs = 500
patience = 50

# Define sweep config
sweep_config = {
    "method": "random",
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "lr": {"values": [1e-7, 1e-6, 1e-5]},
        "patch_size": {"values": [16, 256, 3748]},
        "dim": {"values": [32, 64, 128, 256, 512]},
        "heads": {"values": [8, 16, 32, 64]},
        "mlp_dim": {"values": [512, 1024, 2048]},
        #"batch_size": {"values": [64, 128, 256]},
        "depth": {"values": [6, 12, 20]},
        "dropout": {"values": [0.1, 0.2, 0.3, 0.4]}
    }
}

def sweep_train():
    with wandb.init() as run:
        config = run.config
        model_vit = VisionTransformer1D(
            num_classes=num_classes,
            patch_size=config.patch_size,
            dim=config.dim,
            depth=config.depth,
            heads=config.heads,
            mlp_dim=config.mlp_dim,
            dropout=config.dropout
        )
        
        # Pass config.num_epochs explicitly
        trained_model = train_model_vit(
            model_vit,
            train_loader,
            val_loader,
            test_loader,
            num_epochs=num_epochs,
            lr=config.lr,
            patience=patience,
            device='cuda'
        )


# Start sweep
sweep_id = wandb.sweep(sweep_config, project="spectra-classification-vit")
wandb.agent(sweep_id, function=sweep_train, count=24)


Create sweep with ID: 55gfycot
Sweep URL: https://wandb.ai/joaoc-university-of-southampton/spectra-classification-vit/sweeps/55gfycot


wandb: Agent Starting Run: v7nnc5u6 with config:
wandb: 	depth: 12
wandb: 	dim: 256
wandb: 	dropout: 0.1
wandb: 	heads: 8
wandb: 	lr: 1e-06
wandb: 	mlp_dim: 1024
wandb: 	patch_size: 3748


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


# Sweep with new normalization

In [4]:
import wandb

# Define the hyperparameters
num_classes = 4
num_epochs = 300
patience = 30

# Define sweep config
sweep_config = {
    "method": "random",
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "lr": {"values": [3e-7, 3e-6, 3e-5]},
        "patch_size": {"values": [3748]},
        "dim": {"values": [16, 64, 128, 512]},
        "heads": {"values": [2, 8, 32]},
        "mlp_dim": {"values": [64, 128, 256]}, # Adjusted for faster runs
        #"batch_size": {"values": [64, 128, 256]},
        "depth": {"values": [3, 10, 20]},
        "dropout": {"values": [0.1, 0.2, 0.3, 0.4]}
    }
}

def sweep_train():
    with wandb.init() as run:
        config = run.config
        model_vit = VisionTransformer1D(
            num_classes=num_classes,
            patch_size=config.patch_size,
            dim=config.dim,
            depth=config.depth,
            heads=config.heads,
            mlp_dim=config.mlp_dim,
            dropout=config.dropout
        )
        
        # Pass config.num_epochs explicitly
        trained_model = train_model_vit(
            model_vit,
            train_loader,
            val_loader,
            test_loader,
            num_epochs=num_epochs,
            lr=config.lr,
            max_patience=patience,
            device='cuda'
        )


# Start sweep
sweep_id = wandb.sweep(sweep_config, project="spectra-classification-vit")
wandb.agent(sweep_id, function=sweep_train, count=50)


Create sweep with ID: cdqoepyl
Sweep URL: https://wandb.ai/joaoc-university-of-southampton/spectra-classification-vit/sweeps/cdqoepyl


wandb: Agent Starting Run: p988mh1q with config:
wandb: 	depth: 20
wandb: 	dim: 128
wandb: 	dropout: 0.2
wandb: 	heads: 2
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
test_accuracy,▁▁▁▅▅▅██████████████████████████████████
test_loss,█▆▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▄▅▆▃▄▃▆▄▃▄▅▁▆▂▅▄▆▆▆▇▄█▅▅▅▄▅▃▆▄▅▅▆▄▃▆▅▄▃▇
train_loss,█▇▆▅▅▄▄▃▃▂▃▃▂▃▃▃▂▂▂▂▁▂▂▂▂▂▂▃▂▁▂▂▂▃▂▃▂▂▂▁
val_accuracy,██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▆▅▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,45.0
test_accuracy,0.35
test_loss,1.38911
train_accuracy,0.23645
train_loss,1.39102
val_accuracy,0.22745
val_loss,1.38435


wandb: Agent Starting Run: fl0880kh with config:
wandb: 	depth: 3
wandb: 	dim: 512
wandb: 	dropout: 0.4
wandb: 	heads: 8
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748




Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test_accuracy,▁▅▅▅▅▅▆▅▆▅▆▇▅▅▅▅▅▆▄▆▆▅▆▄▅▅▃▆▆▅▆▄▅▇▇▆▆▅▇█
test_loss,█▅▃▃▂▂▂▂▁▂▁▁▁▁▁▂▁▁▁▂▁▁▂▂▁▂▂▂▂▂▂▃▃▂▃▃▃▄▃▄
train_accuracy,▁▅▄▅▄▅▅▅▄▅▆▆▆▆▆▅▇▆▇▆▆▆▆▅█▆▇▆▇▇▇▆▆▇▆▇▇▆▇▆
train_loss,█▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▆▆▅▇▇▇▇▇██▇▇▇▇▇▇█▆▆▇██▇█▆▆███▇██▇▇█▇▇▇
val_loss,█▅▃▂▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁▂▁▁▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃

0,1
epoch,46.0
test_accuracy,0.775
test_loss,0.69452
train_accuracy,0.76847
train_loss,0.47295
val_accuracy,0.7098
val_loss,0.72937


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: xpyu2bh9 with config:
wandb: 	depth: 10
wandb: 	dim: 64
wandb: 	dropout: 0.1
wandb: 	heads: 8
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
test_accuracy,▁▂▂▇█▇▇▇▆▅▅▅▅▅▅▅▅▇▆▇▇▇▇▇▆▇▇▇▇▇▇▆▇▇▇▇▆▇▆▇
test_loss,███▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁
train_accuracy,▁▁▁▃▃▄▃▄▅▄▅▅▄▆▅▆▆▅▄▇▆▅▇▅▇▆▆▅▆▆▇█▆▇▆██▆██
train_loss,███▇▇▇▇▆▆▆▅▅▅▅▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
val_accuracy,▁▁▂▂▂▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████████
val_loss,███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.475
test_loss,0.92242
train_accuracy,0.63054
train_loss,0.93801
val_accuracy,0.59608
val_loss,0.92652


wandb: Agent Starting Run: fs35fm1f with config:
wandb: 	depth: 3
wandb: 	dim: 16
wandb: 	dropout: 0.1
wandb: 	heads: 8
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
test_accuracy,▁▁▁▃▃▄▄▄▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇█████
test_loss,█▇▇▇▇▆▆▆▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁
train_accuracy,▁▁▂▁▂▃▃▄▄▃▃▄▄▄▅▅▆▅▆▆▅▅▇▆▇▆▆█▇▇▇▇▇▇▇▇▇█▇█
train_loss,██▇▇▇▇▇▆▆▅▅▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁
val_accuracy,▁▁▁▁▁▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇██████
val_loss,█▇▇▆▆▅▅▅▅▅▄▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.6
test_loss,1.03355
train_accuracy,0.58621
train_loss,1.0516
val_accuracy,0.63137
val_loss,1.04951


wandb: Agent Starting Run: o098epzf with config:
wandb: 	depth: 20
wandb: 	dim: 16
wandb: 	dropout: 0.1
wandb: 	heads: 32
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


Traceback (most recent call last):
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\1700346631.py", line 27, in sweep_train
    model_vit = VisionTransformer1D(
                ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\2276252545.py", line 145, in __init__
    self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout), depth)
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\transformer.py", line 590, in __init__
    self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\sit

Run o098epzf errored:
Traceback (most recent call last):
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\wandb\agents\pyagent.py", line 306, in _run_job
    self._function()
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\1700346631.py", line 27, in sweep_train
    model_vit = VisionTransformer1D(
                ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\2276252545.py", line 145, in __init__
    self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout), depth)
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\transformer.py", line 590, in __init__
    self.self_attn = MultiheadAttention(d_mo

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
test_accuracy,▂▂▂▂▂▂▁▃▃▃▃▃▃▃▃▃▄▂▂▄▆▅▆▆▆▆▇▇██▇▇▇▇▇▇▇▇▇▇
test_loss,████████▇▇▇▇▇▇▇▆▅▅▅▅▄▄▄▄▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁
train_accuracy,▃▃▃▂▂▁▂▃▁▃▃▁▂▁▂▁▃▃▂▂▆▄▅▃▅▇▆▅█▇█▇█▇▆▇▇▇██
train_loss,█████████████▇▇▇▇▇▇▆▅▅▅▅▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
val_accuracy,▂▂▂▃▁▁▁▁▁▁▂▄▅▅▅▄▅▅▅▆▆▆▇▇▆▇▆▇▅█▇█▇▇██████
val_loss,████████████▇▇▇▆▆▅▅▅▄▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁

0,1
epoch,299.0
test_accuracy,0.5
test_loss,1.00468
train_accuracy,0.51232
train_loss,1.0632
val_accuracy,0.57647
val_loss,1.02069


wandb: Agent Starting Run: nesipibu with config:
wandb: 	depth: 20
wandb: 	dim: 64
wandb: 	dropout: 0.2
wandb: 	heads: 8
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▂▁▄▆▆▆▆▇▆▆▇▇▆▇▇▆▇▇▇█▆▇▇▇▇▆▆▇▆▇▇▇▇█▇▇▇▇█▆
test_loss,██▇▅▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▁▁▅▇▆▆▆▆▇▇▆▇▆▇█▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇██
train_loss,████▅▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▂▃▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇████████████████████
val_loss,██▇▅▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,110.0
test_accuracy,0.6
test_loss,0.59985
train_accuracy,0.72414
train_loss,0.53105
val_accuracy,0.70588
val_loss,0.65305


wandb: Agent Starting Run: ojk86zsv with config:
wandb: 	depth: 3
wandb: 	dim: 16
wandb: 	dropout: 0.4
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


0,1
epoch,▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇██
test_accuracy,▁▃▃▅▆▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇████████████████
test_loss,█▆▆▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▅▆▄▆▅▅▅▅▆▇▆▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▆▆▇▇▆█▇▇▆█▆
train_loss,██▇▇▆▆▅▅▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▅▆▆▆▇▇▇▇██████████████████████████████
val_loss,██▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.6
test_loss,0.66687
train_accuracy,0.65517
train_loss,0.72215
val_accuracy,0.66667
val_loss,0.69161


wandb: Agent Starting Run: 2pmencdd with config:
wandb: 	depth: 20
wandb: 	dim: 64
wandb: 	dropout: 0.2
wandb: 	heads: 8
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▁▁▄▄▅▇▆▆▅▆▇▇▇█▇▇▇█▇▇█▇█▇▇█▇▇▇███▇▇▇▇█▇█▇
test_loss,██▇▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▂▅▅▆▅▆▇▆▆▆██▇▇▇▇▇▇▇▇▆▇▇█▇▇▇▇▇▇▆▇▇▇▇▇▇▇
train_loss,█▇▇▆▅▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▃▅▆▆▇▇▇▇▇█▇▇▇▇█▇█▇▇█▇▇▇▇█▇▇▇█▇█▇▇▇▇▇▇█
val_loss,██▇▆▅▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,78.0
test_accuracy,0.575
test_loss,0.62236
train_accuracy,0.73399
train_loss,0.56191
val_accuracy,0.68627
val_loss,0.65943


wandb: Agent Starting Run: xpo53z6g with config:
wandb: 	depth: 3
wandb: 	dim: 512
wandb: 	dropout: 0.2
wandb: 	heads: 32
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▇▇▇▇▇▇▇▇█
test_accuracy,▁▂▃▄▃▂▂▃▄▅▆▆▆▆▆▅▆▅▅▆▅▆▆▆▅▆▆▆▆▇▇█████████
test_loss,████▆▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▃▄▄▄▄▆▆▅▄▄▅▅▆▆▆▆▆▄▇▆▅▇▅▆▇▇▆▇▆▆▅▆█▅▆▆▅▆
train_loss,███▇▇▆▆▆▆▅▅▅▄▄▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▂▂▂▃▄▄▅▅▅▅▅▅▅▅▅▅▅▅▇▆▇▇▇▇█████████████▇
val_loss,██▇▆▆▅▅▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.625
test_loss,0.60311
train_accuracy,0.66502
train_loss,0.60239
val_accuracy,0.71373
val_loss,0.63071


wandb: Agent Starting Run: 5doqbn4h with config:
wandb: 	depth: 20
wandb: 	dim: 128
wandb: 	dropout: 0.3
wandb: 	heads: 32
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█
test_accuracy,▃▃▁▁▁▁▁▁▁▁▁▁▁▁▃▆▃▅▆▇▅▅▆▆▇▅▆▆▆▆▆▇████████
test_loss,████████████▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▃▃▃▂▂▂▂▂▂▁
train_accuracy,▁▂▂▃▃▂▃▁▅▃▃▂▁▃▁▃▃▄▃▅▃▄▃▃▄▄▅▅▅▄▆▆▄▇▆▇▇▇█▆
train_loss,█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁
val_accuracy,▂▂▂▂▂▂▁▂▁▂▂▂▂▃▂▂▂▂▁▃▅▂▂▅▅▇▆▇▇▇▇▇████████
val_loss,████████▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▄▄▄▄▃▄▃▃▃▃▃▃▃▂▃▁

0,1
epoch,299.0
test_accuracy,0.5
test_loss,1.11765
train_accuracy,0.53202
train_loss,1.14893
val_accuracy,0.54902
val_loss,1.14006


wandb: Agent Starting Run: 1nuywzgu with config:
wandb: 	depth: 10
wandb: 	dim: 64
wandb: 	dropout: 0.2
wandb: 	heads: 2
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
test_accuracy,▂▁▁▁▃▇▇▆▃▃▃▃▃▃▃▆▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▆▇▇▆█████
test_loss,█████▇▇▇▇▇▆▆▅▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁
train_accuracy,▁▂▁▁▁▂▂▄▄▄▅▅▅▅▅▅▆▆▆▅▅▆▆▆▆▇▆█▇▆▆▆▇▇▇█▇█▇█
train_loss,█████▇▇▇▇▇▇▇▇▇▇▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁
val_accuracy,▁▁▁▁▁▁▁▁▁▁▂▃▄▄▄▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇██
val_loss,█▇▇▇▇▇▇▇▇▇▆▆▆▆▅▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁

0,1
epoch,299.0
test_accuracy,0.55
test_loss,0.82241
train_accuracy,0.66995
train_loss,0.87997
val_accuracy,0.64314
val_loss,0.82331


wandb: Agent Starting Run: tr1rczsy with config:
wandb: 	depth: 10
wandb: 	dim: 16
wandb: 	dropout: 0.3
wandb: 	heads: 32
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


Traceback (most recent call last):
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\1700346631.py", line 27, in sweep_train
    model_vit = VisionTransformer1D(
                ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\2276252545.py", line 145, in __init__
    self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout), depth)
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\transformer.py", line 590, in __init__
    self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\sit

Run tr1rczsy errored:
Traceback (most recent call last):
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\wandb\agents\pyagent.py", line 306, in _run_job
    self._function()
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\1700346631.py", line 27, in sweep_train
    model_vit = VisionTransformer1D(
                ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\2276252545.py", line 145, in __init__
    self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout), depth)
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\transformer.py", line 590, in __init__
    self.self_attn = MultiheadAttention(d_mo

0,1
epoch,▁▁▁▁▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
test_accuracy,▁▄▅▅▅▅▅▆▆▅▇▇▆▆▇▆▆▆▆▆▆▆▆▇▆▆▆▇█▆▇▇█████▇▇▆
test_loss,█▆▅▅▄▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▂▅▅▅▄▃▅▆█▆▇▆▆▆▇▆▄▇▅▅▅▆▄▇▅▄▆▇▆▇▅█▆▇▇▄▄▆
train_loss,█▆▆▅▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▂▃▅▆▆▆▇▇▇▇▇▇▇▇████████████████████████
val_loss,█▅▅▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.725
test_loss,0.55162
train_accuracy,0.76355
train_loss,0.52083
val_accuracy,0.72549
val_loss,0.58906


wandb: Agent Starting Run: hkkuul7b with config:
wandb: 	depth: 10
wandb: 	dim: 16
wandb: 	dropout: 0.4
wandb: 	heads: 2
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██████
test_accuracy,▁▁▄▄▅▆▇█▇▇▇▇▇▇▇▇██▇█▇█▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
test_loss,█▇▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▃▄▄▅▄▆▅▅▆▇▇▆▆▆▆▆▆▇▇▇▇▇█▆▇██▇▇▇▇█▇▆█▇▆▇
train_loss,█▇▇▇▇▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
val_accuracy,▁▅▆▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███▇██████████
val_loss,██▇▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.6
test_loss,0.70675
train_accuracy,0.6601
train_loss,0.76746
val_accuracy,0.67059
val_loss,0.7138


wandb: Agent Starting Run: qawe8en9 with config:
wandb: 	depth: 20
wandb: 	dim: 16
wandb: 	dropout: 0.2
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇██
test_accuracy,▁▁▃▂▃▅▄▄▄▄▄▄▄▄▅▆▆▆▇▆▆▆▇▇▇▇██████████████
test_loss,██████▇▆▆▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
train_accuracy,▂▂▁▁▁▁▁▂▃▄▅▄▆▅▅▅▆▇▇▆▇▇▇█▆█▇█▇███▇█▇█▇▇▇▇
train_loss,██████████▇▆▆▅▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
val_accuracy,▂▁▁▁▁▃▄▆▆▆▆▆▆▇▇▇▇▇▇█████████████████████
val_loss,██████████████▆▆▆▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.625
test_loss,0.70924
train_accuracy,0.70443
train_loss,0.7393
val_accuracy,0.67059
val_loss,0.71494


wandb: Agent Starting Run: y58ksjfq with config:
wandb: 	depth: 10
wandb: 	dim: 64
wandb: 	dropout: 0.4
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇███
test_accuracy,▁▃▃▃▄▆▆▇▆▆▆▆▇▆▆▇▇▇▇▇▇▆▆▆▇█▇▇▇▇▆▆▇▆▆▆▆▆▆▅
test_loss,█▆▆▆▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▄▄▆▇▆▆▆▆▆▇▇▆▆▆▆▇▇▆▇▇▇▇▆▆▆▇▇▇▇█▇▇▇▇▇█▇▇
train_loss,█▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▂▂▂▅▆▆▆▅▆▇█▇▇▇▇▇▇▇▇███▇████▇█▇▇███████
val_loss,█▇▆▅▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.55
test_loss,0.58022
train_accuracy,0.70443
train_loss,0.58836
val_accuracy,0.70196
val_loss,0.6167


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: jtjdco2a with config:
wandb: 	depth: 10
wandb: 	dim: 512
wandb: 	dropout: 0.4
wandb: 	heads: 2
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


0,1
epoch,▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇█████
test_accuracy,▁▁▃▅▄▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇█▇▇▇▇
test_loss,███▇▇▇▆▅▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▂▃▅▅▆▆▇▆▆▆▆▇▇▇▇▇▆▆▇▇███▇█▇▇▇█▇▇▇█▇█▇█▇
train_loss,██▇▇▆▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▄▅▅▆▆▆▆▆▇▇▇▇▇█████████████████████████
val_loss,████▇▆▆▅▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.65
test_loss,0.6248
train_accuracy,0.66995
train_loss,0.64483
val_accuracy,0.67451
val_loss,0.63034


wandb: Agent Starting Run: 52pjybl6 with config:
wandb: 	depth: 20
wandb: 	dim: 128
wandb: 	dropout: 0.4
wandb: 	heads: 32
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇████
test_accuracy,▁▁▁▃▄▅▅▇▅▆▇▆▆▇▆█▆▅▅▆▆▇▆▆▇▆▆▆▆█▇▆█▆▇▇▇▇▇█
test_loss,█████▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▁▅▅▇▆▆▆▇▇▆▇▇▇▆▆▇▇▇▇██▇▇█▇▇▇█▇█▇▇▇█▇██▇
train_loss,████▇▅▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▂▁▃▄█▇██████▇███████████████████████████
val_loss,████▆▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂

0,1
epoch,79.0
test_accuracy,0.6
test_loss,0.61599
train_accuracy,0.75369
train_loss,0.52937
val_accuracy,0.70588
val_loss,0.66924


wandb: Agent Starting Run: r2tkzbso with config:
wandb: 	depth: 20
wandb: 	dim: 16
wandb: 	dropout: 0.1
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█
test_accuracy,▁▂▃▄▃▅▅▄▄▄▅▅▅▅▅▅▆▆▇▆▇▇█▇▇████████▇██████
test_loss,█▇▇▇▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▂▄▆▅▅▆▆▆▅▇▆▇▆█▇▇▆▇▆█▇▇▆▇▇██▇▇▇▇█▇█▇▇▇▇
train_loss,█▇▇▇▆▅▅▅▄▄▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▁▅▅▆▆▆▆▆▇▇▇▇█▇████████████████████████
val_loss,█▇▇▆▆▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.675
test_loss,0.66123
train_accuracy,0.6798
train_loss,0.68189
val_accuracy,0.68235
val_loss,0.68292


wandb: Agent Starting Run: wgljwaja with config:
wandb: 	depth: 20
wandb: 	dim: 512
wandb: 	dropout: 0.3
wandb: 	heads: 8
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
test_accuracy,▁▁▁████████████████████████████████
test_loss,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▅▄▃▄▄▄▄▅▂▃▃▅▂▅▂█▄▄▃▃▄▅█▅▃▃▄▄▄▃▃▅▃
train_loss,█▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▄▄▄▄▄▄▄▄█▇▄▄▄▄▄▄▄▄▄▅▄▄▄▄▄▄▄▄▄▄▄▄▄
val_loss,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,34.0
test_accuracy,0.275
test_loss,1.38764
train_accuracy,0.24631
train_loss,1.39746
val_accuracy,0.25882
val_loss,1.38148


wandb: Agent Starting Run: ebdqj2fp with config:
wandb: 	depth: 20
wandb: 	dim: 512
wandb: 	dropout: 0.2
wandb: 	heads: 2
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
test_accuracy,▁▂▃▅▄▄▅▅▆▇▆▇█▇▇▇█▇▇▇▆█▇▆▇▇███▇▇▇██▆█▇▇▇▇
test_loss,███▇▇▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▁▄▅▆▆▇▇▇▇▇▇█▆▇▇▇▇▇█▆▇█▇▇███▇▇▇▇▇█▇█▇▇█
train_loss,█████▇▃▃▃▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▁▂▄▆▆▇▇███████████████████████████████
val_loss,████▆▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,131.0
test_accuracy,0.675
test_loss,0.61
train_accuracy,0.76355
train_loss,0.55212
val_accuracy,0.71373
val_loss,0.65381


wandb: Agent Starting Run: 3r1lgqpl with config:
wandb: 	depth: 20
wandb: 	dim: 64
wandb: 	dropout: 0.3
wandb: 	heads: 8
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇█
test_accuracy,▆▆▆▆▆▆█▃▃▃▁▁▁▁▁▁▁▁▁▃▄▃▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
test_loss,██▇▇▇▇▇▇▇▇▆▆▆▆▆▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁
train_accuracy,▂▂▂▃▃▂▄▃▂▂▁▃▃▁▃▄▃▂▃▄▃▁▄▂▁▄▄▅▄▄▅▆▅▂▅▄▆█▆▅
train_loss,███▇▇▇▇▇▇▇▇▇▇▇▆▅▅▅▅▅▅▅▅▅▅▅▄▄▄▄▃▃▃▂▂▂▁▁▁▁
val_accuracy,▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▃▃▄█▇▇▇▇▇▇▆▆▆▆▅▅▆▆▆▆▆▆▇▆
val_loss,█▇▇▇▇▇▇▇▇▇▇▇▆▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁

0,1
epoch,299.0
test_accuracy,0.325
test_loss,1.29538
train_accuracy,0.40394
train_loss,1.30239
val_accuracy,0.37647
val_loss,1.29195


wandb: Agent Starting Run: zm7za4s8 with config:
wandb: 	depth: 20
wandb: 	dim: 128
wandb: 	dropout: 0.4
wandb: 	heads: 8
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▆▆▆▆▆▆▆▆▇▇▇█
test_accuracy,▁▃▃▃▃▃▃▃▃▃▃▃▃▃▄▄▃▄▁▃▂▃▇▇▇█▇▆▇▆▆▅▆▇▆▆▆▆▇▇
test_loss,███▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▅▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁
train_accuracy,▂▄▁▃▂▃▄▃▃▃▃▂▅▃▃▃▅▂▃▃▂▅▄▆▆▆▆▆▇▅█▇▇▇▄▇▇▅▆▆
train_loss,██████████████▇▇▇▇▇█▇▆▆▆▆▅▄▄▄▄▃▃▃▃▃▃▃▂▂▁
val_accuracy,▃▂▂▃▂▂▂▂▂▁▁▂▂▃▃▂▂▄▆▅▇███████████████████
val_loss,█████████████████████▇▇▇▆▆▅▄▄▃▃▃▃▃▂▂▂▂▂▁

0,1
epoch,299.0
test_accuracy,0.45
test_loss,1.16718
train_accuracy,0.36946
train_loss,1.25884
val_accuracy,0.5098
val_loss,1.17724


wandb: Agent Starting Run: 47xwlesd with config:
wandb: 	depth: 10
wandb: 	dim: 512
wandb: 	dropout: 0.2
wandb: 	heads: 32
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇██
test_accuracy,▁▁▁▂▁▃▃▃▄▄▅▅▆▅▄█▇▇▆▆▇▇▇▇▇█▇▆▆▆▆▇▇▇▇▇█▇▇█
test_loss,████▆▅▄▄▃▃▃▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▂▂▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▆▇▇█▇▇▇▇▆▇▇▇▆▇▇█▇█
train_loss,██▇▇▆▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▅▅▅▅▆▅▇▇▇▇▇█▇▇██▇▇█▇█▇▇▇▇▇▇███████████
val_loss,█▆▆▆▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.6
test_loss,0.60135
train_accuracy,0.78325
train_loss,0.63081
val_accuracy,0.72157
val_loss,0.61598


wandb: Agent Starting Run: nldig0ly with config:
wandb: 	depth: 3
wandb: 	dim: 128
wandb: 	dropout: 0.4
wandb: 	heads: 2
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▃▃▃▃▃▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇████
test_accuracy,▁▁▁▁▂▄▄▄▄▄▄▄▄▅▅▅▅▅▄▅▅▅█▅▇▇▇▇▇▇▇▇▇█▄▅▄▇▄▅
test_loss,█▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▃▄▅▇▅█▄▅▆▆▅▅▇▆▅▆▅▅▆▇▆▅▅▇▇█▇▆▇▅▆▆█▆▇▇▆▇
train_loss,█▇▅▅▅▃▃▃▂▃▂▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▂▆▅▆▆▆▆▆▆▆▅▆▆▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▇▇██▇█▇█
val_loss,█▅▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,276.0
test_accuracy,0.65
test_loss,0.56305
train_accuracy,0.76355
train_loss,0.55798
val_accuracy,0.71373
val_loss,0.61206


wandb: Agent Starting Run: 3unsa4p3 with config:
wandb: 	depth: 10
wandb: 	dim: 16
wandb: 	dropout: 0.3
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▇▇▇▇▇██
test_accuracy,▁▄▄▆▆▆▆▆▇▆▇█▇█▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
test_loss,█▆▆▆▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇▇▇▇█▇█▇▇▇
train_loss,█▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁
val_accuracy,▁▃▃▃▄▂▆▄▆▆▅▆▆▇▇█████████████▇███▇▇██▇███
val_loss,█▅▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.6
test_loss,0.68877
train_accuracy,0.71429
train_loss,0.73084
val_accuracy,0.6549
val_loss,0.70453


wandb: Agent Starting Run: ywrmodyj with config:
wandb: 	depth: 10
wandb: 	dim: 64
wandb: 	dropout: 0.2
wandb: 	heads: 2
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▂▂▂▂▁▂▃▃▃▄▄▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▇█▇▇▇
test_loss,████▇▇▆▅▅▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
train_accuracy,▂▂▁▂▃▃▃▃▃▄▆▇▅▅▇▆▇▆▇▆▆▆▇█▇▇▇▇▆▇█▆▇█▇▇█▇▇█
train_loss,██▇▆▆▆▆▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁
val_accuracy,▁▁▂▅▅▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████████
val_loss,█▇▇▇▆▅▅▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.525
test_loss,0.84933
train_accuracy,0.59113
train_loss,0.88524
val_accuracy,0.59216
val_loss,0.85726


wandb: Agent Starting Run: d40ylvz2 with config:
wandb: 	depth: 3
wandb: 	dim: 512
wandb: 	dropout: 0.1
wandb: 	heads: 2
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▁▂▂▄▅▄▅▆▇▆▅▇▇▆▇▅▇▇█▇██▆▇▇▇▆█▆▇▆▅▆▅▆▆▆▇▇▆
test_loss,█▇▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▃▅▅▆▆▆▇▅█▇▆▇▇▇█▇▇▇▇▅▆▇▇█▆▇▇█▇▇▇▇▆▇▇▇▇▇
train_loss,█▇▇▆▅▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▃▄▇▆▆▆▇▇▇▇█▇█▇▇█▇██▇██▇███████▇▇▇█████
val_loss,█▇▆▅▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,133.0
test_accuracy,0.65
test_loss,0.57528
train_accuracy,0.73892
train_loss,0.51047
val_accuracy,0.7098
val_loss,0.63195


wandb: Agent Starting Run: blxvplkv with config:
wandb: 	depth: 10
wandb: 	dim: 128
wandb: 	dropout: 0.2
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇█████
test_accuracy,▁▁▂▂▂▃▄▄▄▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▆▇▇▇▇█▇▇▇▇▇▇▇▆
test_loss,██▇▆▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▅▆▅▆▇▆▇▇▇▇▇▇▇▆▇████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇
train_loss,██▆▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▂▃▄▄▄▆▆▆▇▆▇▇▇▇▇▇▇▇▇█▇▇▇█▇█████████████
val_loss,██▇▆▆▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,192.0
test_accuracy,0.675
test_loss,0.57853
train_accuracy,0.72906
train_loss,0.55578
val_accuracy,0.70588
val_loss,0.61394


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: piv3r11g with config:
wandb: 	depth: 3
wandb: 	dim: 128
wandb: 	dropout: 0.1
wandb: 	heads: 2
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
test_accuracy,▁▃▅▅▅▄▇▅▅▅▅▆▅▅▆▅▅▅▆▅▅▅▅▅▅▆▅▅▅▇▅▇▆▆▇▅█▆█▇
test_loss,█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂
train_accuracy,▁▄▅▇▅▆▅▆▆▅▆▅▅▅▇▅▆▅▅▆▆▆▇▇▇▆▇▇█▅█▇▆▅█▇▇█▇▆
train_loss,█▆▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▅▅▆▇▇▇▇█▇█▇█▇▇█▇████▇▇█▇█▇█▇▇▇▇█▇▇▇█▇█
val_loss,█▅▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂

0,1
epoch,61.0
test_accuracy,0.675
test_loss,0.61095
train_accuracy,0.75369
train_loss,0.47839
val_accuracy,0.72941
val_loss,0.66798


wandb: Agent Starting Run: zvrjsnse with config:
wandb: 	depth: 3
wandb: 	dim: 64
wandb: 	dropout: 0.2
wandb: 	heads: 32
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
test_accuracy,▁▃▄▄▅▄▅▅▄▅▅▅▅▆▅▅▅▄▆▅▅▇▆▅▆▇█▇▇▇▇▇▇▇▇▇▅▆▆▇
test_loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂
train_accuracy,▁▄▆▅▆▇▆▅▆▅▇▇▄▆▆▇▅▆▆▆▇▆▇▇▆▇▆▅▆▆▆▇▆█▇▇▇█▆█
train_loss,█▇▅▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▄▅▆▆▆▆▇▇▇▇▇▇███▇█▇▇▇█▇██▇▇▆▇▇▆▇▇▇▇▇▇▆▆
val_loss,█▆▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂

0,1
epoch,77.0
test_accuracy,0.75
test_loss,0.60446
train_accuracy,0.72414
train_loss,0.48693
val_accuracy,0.72549
val_loss,0.6539


wandb: Agent Starting Run: 9cr11b8h with config:
wandb: 	depth: 20
wandb: 	dim: 64
wandb: 	dropout: 0.4
wandb: 	heads: 8
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇▇█████
test_accuracy,▁▁▁▁▁▃▃▃▃▃▄▅▅▅▅▇▅▅▅▅▅▅▅▅▅▃▄▄▃▃▂▂▂▄▅▆▆▇▇█
test_loss,███▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁
train_accuracy,▁▁▃▃▁▂▂▂▁▃▂▃▁▃▄▄▃▄▃▆▂▃▆▄▅▅▇▅▇▆▇▆▆▆▆▇▄███
train_loss,███▇▆▆▆▅▅▅▅▅▅▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁
val_accuracy,▁▁▁▁▁▃▃▃▃▃▃▃▆▃▃▅▅▅▅▇▅▅▅▆▅▄▄▇▇▇▇█▇▇▇█████
val_loss,█▇▇▇▆▅▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.4
test_loss,1.15611
train_accuracy,0.42365
train_loss,1.23186
val_accuracy,0.41961
val_loss,1.19098


wandb: Agent Starting Run: pi56wwam with config:
wandb: 	depth: 20
wandb: 	dim: 16
wandb: 	dropout: 0.2
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇█
test_accuracy,▁▄▄▄▃▄▅▃▅▆▃▅▅▆▆▆▆▆▆▆▆▆▆▆▆▇▆▇▇▇█▇▇▇█████▇
test_loss,████▇▆▆▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁
train_accuracy,▁▁▂▂▂▃▄▅▄▅▅▅▅▅▆▆▆▄▅▅▅▆▇▇▇▇▇▇▆█▆█▇█▇█▇▇█▇
train_loss,█████▇▇▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁
val_accuracy,▁▁▁▃▃▃▆▅▆▆▆▆▅▆▆▆▆▅▆▆▇▇▇▇▇▇▇████▇▇▇▇█████
val_loss,██▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁

0,1
epoch,299.0
test_accuracy,0.6
test_loss,0.80058
train_accuracy,0.62562
train_loss,0.84204
val_accuracy,0.58039
val_loss,0.79897


wandb: Agent Starting Run: 0rro7n4t with config:
wandb: 	depth: 20
wandb: 	dim: 512
wandb: 	dropout: 0.2
wandb: 	heads: 2
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
test_accuracy,▂▁▃▄▄▄▄▆▆▆▆▇▆█▆▇▆▆▇▆▇▆▆▆▆▆▆██▇▆▇▆▇▇▆█▇▇█
test_loss,████▆▅▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▂▂▃▆▅▆▆▇▆▆▇▆▆▇▇▇▇▆▇▇▇▇▇▆█▇▆▇▆▇▇▇█▇▇▇▇▇
train_loss,█████▅▄▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▃▃▅▆▆▇▇▇▇▇▇▇█▇███▇█▇██▇█▇█▇██████████▇
val_loss,████▆▄▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,58.0
test_accuracy,0.725
test_loss,0.6373
train_accuracy,0.73892
train_loss,0.50807
val_accuracy,0.70196
val_loss,0.67272


wandb: Agent Starting Run: 6i77gzj1 with config:
wandb: 	depth: 10
wandb: 	dim: 16
wandb: 	dropout: 0.4
wandb: 	heads: 8
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇█████
test_accuracy,▁▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▆▆▇▆▅▅▇▇█▇▇▇▇█▆▆▆▆▅▇
test_loss,█▅▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▅▇▅▅▆▅▅▅▅▅▇▆▇▄▇▆▇▆▇▇█▆▆▇▇██▇▇▆▇▇▆▆██▇▅▆
train_loss,█▇▇▆▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▆▆▆▇▇▇█▇█▇▇▇▇▇▇▇▇█████████████████▇████
val_loss,█▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,188.0
test_accuracy,0.6
test_loss,0.59746
train_accuracy,0.69458
train_loss,0.56509
val_accuracy,0.67843
val_loss,0.64327


wandb: Agent Starting Run: xg6iapcz with config:
wandb: 	depth: 3
wandb: 	dim: 128
wandb: 	dropout: 0.1
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇█████
test_accuracy,▁▁▂▃▄▄▄▅▅▅▅▅▆▆▇▆▆▇██▇▇█▇██▇█▇▇▇▆▆▇▆▇▇▇▇▇
test_loss,█▆▆▅▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▄▄▅▅▅▇▅▇▇▇▇▇▇▇▇▆▆▇▇▆█▇▇▇██▆▇▇▇▇█▆█▇█▇▆
train_loss,█▆▅▅▅▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▂▃▃▅▆▇▇▆▇▇▇▇▇▇▇▇▇███████▇█████████████
val_loss,█▇▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,211.0
test_accuracy,0.6
test_loss,0.56011
train_accuracy,0.7734
train_loss,0.52098
val_accuracy,0.72157
val_loss,0.59495


wandb: Agent Starting Run: nq006e3j with config:
wandb: 	depth: 20
wandb: 	dim: 512
wandb: 	dropout: 0.1
wandb: 	heads: 2
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇██
test_accuracy,▂▂▂▄▁▃▁▂▁▃▆▆▆▇▆▅▄▄▅▇▇█▇▇▇█▂▂▄▂▂▂▂▂▂▂▂▂▂▂
test_loss,▅▅▅▃▅▄▃▃▃▃▂▄▄▃▃▂▁▁▁▂▁▁▂▂█▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
train_accuracy,▂▃▃▁▁▁▄▅▆▅▇▆▆▆▃▇▄▇█▆██▃▄▄▄▅▅▄▄▆▃▄▄▄▄▅▄▄▅
train_loss,▇▇▇▇▇▆▇▅▃▅▄▄▃▄▆▃▃▁▁▂▁▁▁▅█▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅
val_accuracy,▂▂▂▂▁▂▃▇▇█▇▄▃▄▄▅█▆▇██████▇▁▃▅▃▆▆▆▆▆▆▆▆▆▆
val_loss,███████▅▄▆▄▇▅▅▅▃▂▂▃▃▁▁▃▃▂▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆

0,1
epoch,64.0
test_accuracy,0.25
test_loss,1.31117
train_accuracy,0.35961
train_loss,1.30292
val_accuracy,0.38824
val_loss,1.31035


wandb: Agent Starting Run: 799bpznn with config:
wandb: 	depth: 10
wandb: 	dim: 512
wandb: 	dropout: 0.2
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
test_accuracy,▂▁▂▃▆▆▇▇▇▆▆▇▆▆▅▆▇▇▅▇▇▇▇▇▆▇█▇▇▇▆▇▇▇▆▇▇▇▇▇
test_loss,█▇▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▄▅▅▆▇▇▆▆▆▇▆▆▇▆▇▆█▆▇█████▇▆█▇▆▇▇▇█▆▆▇▆█
train_loss,██▆▄▃▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▄▄▆▇▇▇▇▇▇▇▇▇▇███▇▇█████▇███▇█▇█▇▇████
val_loss,█▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,133.0
test_accuracy,0.65
test_loss,0.56905
train_accuracy,0.7734
train_loss,0.51375
val_accuracy,0.72941
val_loss,0.59985


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: xiyfcymi with config:
wandb: 	depth: 20
wandb: 	dim: 16
wandb: 	dropout: 0.1
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█████
test_accuracy,▁▁▁▄▄▄▄▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███▇█▇██████
test_loss,███▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▂▄▄▅▆▆▆▆▆▇▆▇▆▆▇▆▆▆▇█▆▇▆▇█▆▇▆▇▇▇▆▇▆▆▇▇▇
train_loss,██▇▇▇▆▆▅▅▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▃▃▅▅▅▅▅▇▇████▇▇▇▇▇▇█▇█████▇██▇███▇██▇▇
val_loss,██▇▇▇▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.625
test_loss,0.67706
train_accuracy,0.72414
train_loss,0.69333
val_accuracy,0.64706
val_loss,0.70214


wandb: Agent Starting Run: c0abyqam with config:
wandb: 	depth: 20
wandb: 	dim: 128
wandb: 	dropout: 0.4
wandb: 	heads: 8
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
test_accuracy,▂▁▂▅▆▇▇▆▇▇█▇▇▇▇▇▆▇███▇▇██▇▇▇▇█▇▇▇▇▇▇▇▇▇▇
test_loss,███▅▄▂▂▂▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▃▅▇▇█▇▇▇▇▇▇▇▇▇▇█▇█▆█▇█▇▇▇█▇▇████▇███▇▇
train_loss,███▆▅▃▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▁▁▄▆▇██████████████████████████▇██████
val_loss,█▅▄▃▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁

0,1
epoch,74.0
test_accuracy,0.65
test_loss,0.63143
train_accuracy,0.74877
train_loss,0.56208
val_accuracy,0.69804
val_loss,0.68545


wandb: Agent Starting Run: g8np9v68 with config:
wandb: 	depth: 20
wandb: 	dim: 512
wandb: 	dropout: 0.2
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▂▂▃▃▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇█████
test_accuracy,▁▁▁▁▁▂▅▅▆▇█▇█████████▇▇█████▇▇▇█████▇▇▇█
test_loss,████▇▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▂▃▅▆▆▆▆▆▇▇▇▇▇▆▇▇▆▇▇▆█▆▆▇█▇▇▇▆▇▇█▇██▇██
train_loss,███▆▆▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▁▁▁▄▃▅▆▆███▇██████████████████████████
val_loss,█████▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,143.0
test_accuracy,0.65
test_loss,0.60025
train_accuracy,0.72906
train_loss,0.5535
val_accuracy,0.7098
val_loss,0.6314


wandb: Agent Starting Run: 1vuc0y32 with config:
wandb: 	depth: 20
wandb: 	dim: 16
wandb: 	dropout: 0.3
wandb: 	heads: 2
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
test_accuracy,▂▂▂▂▂▂▁▁▂▂▃▃▃▃▄▄▄▆▅▅▆▆▇█▇▇▇▇▇▇▇▇▇▇███▇█▇
test_loss,█▇▇▇▇▆▆▆▆▅▅▅▅▅▅▅▅▅▄▄▄▄▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▂▂▁▂▃▅▄▅▆▄▄▆▅▅▅▆▅▅▆▅▅▆▆▇▆▆▇█▆▇▇█▇█▇█▇▇▇▇
train_loss,███▇▆▆▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁
val_accuracy,▁▁▁▁▁▁▂▃▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇███████████████
val_loss,█▇▇▇▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.575
test_loss,0.73953
train_accuracy,0.6601
train_loss,0.77345
val_accuracy,0.64314
val_loss,0.75439


wandb: Agent Starting Run: k4xzvip0 with config:
wandb: 	depth: 20
wandb: 	dim: 128
wandb: 	dropout: 0.3
wandb: 	heads: 8
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_accuracy,▁▂▃▂▄▄▆▆▇▆▇▇▆█▆▆▆▇▆▇▆▇▇▆▇▇▇██▇█▇▇████▇██
test_loss,███▇▅▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▂▂▅▆▆▆▆▇▇▇▇▆▇▇▇▇▇▇▇▇▇█▇▇▇▇█▇▇▇▇██▇▇▇▇▇
train_loss,█████▅▅▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▂▁▂▁▂▅▆▆▇▇▇█████████████████████████████
val_loss,█████▅▄▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,58.0
test_accuracy,0.675
test_loss,0.61506
train_accuracy,0.74384
train_loss,0.55819
val_accuracy,0.69412
val_loss,0.64891


wandb: Agent Starting Run: 2lpu8alj with config:
wandb: 	depth: 10
wandb: 	dim: 64
wandb: 	dropout: 0.3
wandb: 	heads: 2
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
test_accuracy,▁▄▅██▇▇▇██▇▇█▇▇▇▇▇█▇▇█▇▇████▇▇████▇██▇█▇
test_loss,█▆▅▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▅▅▅▆▆▇▇▆▇▆▆▇▇▇▇▇▇▇▇▇█▇█▇█▇▇▇▇▇▇▇▇▇█▇██
train_loss,█▆▅▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▅▆▆▇▇▇▇▇▇▇█▇▇▇▇▇█▇██▇█▇▇██████████████
val_loss,█▇▅▄▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,73.0
test_accuracy,0.675
test_loss,0.58267
train_accuracy,0.80788
train_loss,0.5276
val_accuracy,0.71373
val_loss,0.62856


wandb: Agent Starting Run: xuuzy81e with config:
wandb: 	depth: 3
wandb: 	dim: 16
wandb: 	dropout: 0.1
wandb: 	heads: 8
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 128
wandb: 	patch_size: 3748


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▇▇▇▇▇▇███
test_accuracy,▁▁▁▁▁▂▃▃▃▃▃▃▃▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████
test_loss,██▇▇▆▆▅▅▅▅▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
train_accuracy,▁▂▂▂▁▂▂▃▃▃▂▃▄▄▄▄▄▆▆▆▇▅▆▇▆▆▇█▅▇▇▇▇▆███▇█▇
train_loss,████▇▆▆▅▅▅▄▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
val_accuracy,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇██████
val_loss,█▇▆▆▆▅▅▅▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.525
test_loss,1.07223
train_accuracy,0.53202
train_loss,1.08107
val_accuracy,0.61176
val_loss,1.08731


wandb: Agent Starting Run: b4ogjscg with config:
wandb: 	depth: 10
wandb: 	dim: 128
wandb: 	dropout: 0.2
wandb: 	heads: 32
wandb: 	lr: 3e-05
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
test_accuracy,▁▁▂▄▆▅▅▇▅▇▅▆▇▅▄▇▆▅▄▇▄▆▄▇▅▇▆▅▅▇▆▆▇█▇▆▇▆█▅
test_loss,█▇▅▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▅▆▅▅▆▆▅▇▅▅▆▇▆▆▅▇▇▇▇▆▇▆▇▅██▇█▇▇▇█▆█▇▆▇█
train_loss,█▇▆▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▄▅▇▇▇▇▇▇▇▇█▇▇▇▇▇██████████████████▇▇██
val_loss,█▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂

0,1
epoch,56.0
test_accuracy,0.7
test_loss,0.60144
train_accuracy,0.78325
train_loss,0.48733
val_accuracy,0.7098
val_loss,0.6656


wandb: Agent Starting Run: mt3k76cq with config:
wandb: 	depth: 10
wandb: 	dim: 16
wandb: 	dropout: 0.4
wandb: 	heads: 2
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


0,1
epoch,▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
test_accuracy,▂▂▁▃▃▄▄▅▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████▇▇
test_loss,██▇▇▆▆▆▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▂▂▄▄▄▄▆▄▇▅▆▅▇▇▆▆▇▇▇▇▆▆▇▇█▇▇█████▇█▇▇██
train_loss,████▇▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▂▃▄▆▆▆▇▇▇▇▇▇▇██▇████▇▇▇▇▇▇████████████
val_loss,█▇▇▇▆▆▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.575
test_loss,0.69607
train_accuracy,0.61576
train_loss,0.76508
val_accuracy,0.67059
val_loss,0.71544


wandb: Agent Starting Run: u8j043nd with config:
wandb: 	depth: 3
wandb: 	dim: 512
wandb: 	dropout: 0.3
wandb: 	heads: 8
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 256
wandb: 	patch_size: 3748


VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇██
test_accuracy,▁▁▃▃▂▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇██▇
test_loss,███▇▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▅▆▆▆▆▇▆▇▇█▇▆▆▇▆▇▆▇▆▇▇█▇
train_loss,██▇▆▆▆▅▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▃▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████████████████████
val_loss,███▇▆▆▆▅▅▅▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,299.0
test_accuracy,0.7
test_loss,0.59934
train_accuracy,0.70443
train_loss,0.61562
val_accuracy,0.71765
val_loss,0.6135


wandb: Agent Starting Run: pfgfdcby with config:
wandb: 	depth: 20
wandb: 	dim: 16
wandb: 	dropout: 0.3
wandb: 	heads: 32
wandb: 	lr: 3e-07
wandb: 	mlp_dim: 64
wandb: 	patch_size: 3748


Traceback (most recent call last):
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\1700346631.py", line 27, in sweep_train
    model_vit = VisionTransformer1D(
                ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\2276252545.py", line 145, in __init__
    self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout), depth)
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\transformer.py", line 590, in __init__
    self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\sit

Run pfgfdcby errored:
Traceback (most recent call last):
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\wandb\agents\pyagent.py", line 306, in _run_job
    self._function()
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\1700346631.py", line 27, in sweep_train
    model_vit = VisionTransformer1D(
                ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Temp\ipykernel_11012\2276252545.py", line 145, in __init__
    self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout), depth)
                                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jcwin\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\transformer.py", line 590, in __init__
    self.self_attn = MultiheadAttention(d_mo

In [None]:
import wandb

# Define the hyperparameters
num_classes = 4
num_epochs = 300
patience = 30

# Define sweep config
sweep_config = {
    "method": "random",
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "lr": {"values": [3e-7, 3e-6, 3e-5]},
        "patch_size": {"values": [3748]}, # Adjusted for faster runs
        "dim": {"values": [32, 128, 512]},
        "heads": {"values": [2, 8, 32]},
        "mlp_dim": {"values": [64, 128, 512]}, # Adjusted for faster runs
        #"batch_size": {"values": [64, 128, 256]},
        "depth": {"values": [3, 10, 20]},
        "dropout": {"values": [0., 0.2, 0.4]}
    }
}

def sweep_train():
    with wandb.init() as run:
        config = run.config
        model_vit = VisionTransformer1D(
            num_classes=num_classes,
            patch_size=config.patch_size,
            dim=config.dim,
            depth=config.depth,
            heads=config.heads,
            mlp_dim=config.mlp_dim,
            dropout=config.dropout
        )
        
        # Pass config.num_epochs explicitly
        trained_model = train_model_vit(
            model_vit,
            train_loader,
            val_loader,
            test_loader,
            num_epochs=num_epochs,
            lr=config.lr,
            max_patience=patience,
            device='cuda'
        )


# Start sweep
sweep_id = wandb.sweep(sweep_config, project="spectra-classification-vit")
wandb.agent(sweep_id, function=sweep_train, count=200)


In [None]:
import wandb

# Define the hyperparameters
num_classes = 4
num_epochs = 300
patience = 30

# Define sweep config
sweep_config = {
    "method": "random",
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "lr": {"values": [3e-7, 3e-6, 3e-5]},
        "patch_size": {"values": [16, 256, 3748]},
        "dim": {"values": [32, 64, 128, 256, 512]},
        "heads": {"values": [2, 8, 32]},
        "mlp_dim": {"values": [64, 128, 512]}, # Adjusted for faster runs
        #"batch_size": {"values": [64, 128, 256]},
        "depth": {"values": [3, 10, 20]},
        "dropout": {"values": [0.1, 0.2, 0.3, 0.4]}
    }
}

def sweep_train():
    with wandb.init() as run:
        config = run.config
        model_vit = VisionTransformer1D(
            num_classes=num_classes,
            patch_size=config.patch_size,
            dim=config.dim,
            depth=config.depth,
            heads=config.heads,
            mlp_dim=config.mlp_dim,
            dropout=config.dropout
        )
        
        # Pass config.num_epochs explicitly
        trained_model = train_model_vit(
            model_vit,
            train_loader,
            val_loader,
            test_loader,
            num_epochs=num_epochs,
            lr=config.lr,
            max_patience=patience,
            device='cuda'
        )


# Start sweep
sweep_id = wandb.sweep(sweep_config, project="spectra-classification-vit")
wandb.agent(sweep_id, function=sweep_train, count=50)


In [None]:
import wandb
import torch
from torch import nn
from einops import rearrange, repeat

# Using Spectrum Transformer: An Attention-Based Wideband Spectrum Detector

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.attend = nn.Softmax(dim=-1)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ])
            for _ in range(depth)
        ])
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class VisionTransformer1D(nn.Module):
    def __init__(self, input_size=3748, num_classes=4, patch_size=5, dim=128, depth=12, heads=16, mlp_dim=256, pool='cls', dim_head=64, dropout=0.2, emb_dropout=0.1):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (input_size + patch_size - 1) // patch_size
        self.to_patch_embedding = nn.Sequential(
            nn.Linear(patch_size, dim)
        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        b, _, seq_len = x.shape
        pad_len = (self.patch_size - (seq_len % self.patch_size)) % self.patch_size
        x = nn.functional.pad(x, (0, pad_len))
        x = x.view(b, -1, self.patch_size)
        x = self.to_patch_embedding(x)

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(x.size(1))]
        x = self.dropout(x)
        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)


# Define the hyperparameters
num_classes = 4
num_epochs = 300
patience = 30

# Define sweep config
sweep_config = {
    "method": "random",
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "lr": {"values": [3e-7, 3e-6, 3e-5]},
        "patch_size": {"values": [16, 256, 3748]},
        "dim": {"values": [32, 64, 128, 256, 512]},
        "heads": {"values": [2, 8, 32]},
        "mlp_dim": {"values": [128, 256, 512]},
        #"batch_size": {"values": [64, 128, 256]},
        "depth": {"values": [3, 10, 20]},
        "dropout": {"values": [0.1, 0.2, 0.3, 0.4]}
    }
}

def sweep_train():
    with wandb.init() as run:
        config = run.config
        model_vit = VisionTransformer1D(
            num_classes=num_classes,
            patch_size=config.patch_size,
            dim=config.dim,
            depth=config.depth,
            heads=config.heads,
            mlp_dim=config.mlp_dim,
            dropout=config.dropout
        )
        
        # Pass config.num_epochs explicitly
        trained_model = train_model_vit(
            model_vit,
            train_loader,
            val_loader,
            test_loader,
            num_epochs=num_epochs,
            lr=config.lr,
            max_patience=patience,
            device='cuda'
        )


# Start sweep
sweep_id = wandb.sweep(sweep_config, project="spectra-classification-vit")
wandb.agent(sweep_id, function=sweep_train, count=500)


wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Create sweep with ID: ukgexvm9
Sweep URL: https://wandb.ai/joaoc-university-of-southampton/spectra-classification-vit/sweeps/ukgexvm9


wandb: Agent Starting Run: bw3zdkej with config:
wandb: 	depth: 20
wandb: 	dim: 128
wandb: 	dropout: 0.2
wandb: 	heads: 8
wandb: 	lr: 3e-06
wandb: 	mlp_dim: 512
wandb: 	patch_size: 16
wandb: Currently logged in as: joaoc (joaoc-university-of-southampton). Use `wandb login --relogin` to force relogin


: 

# SpectraTR Code

In [20]:
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # self.to_qkv = nn.Linear(dim, inner_dim , bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # qkv = self.to_qkv(x).chunk(1, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 1, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        #self.to_patch_embedding = nn.Sequential(
        #    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
        #    nn.Linear(patch_dim, dim),
        #)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (p s) -> b (p c) s', p=patch_dim),
            nn.Linear(patch_dim, dim)
        )


        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img): 
        print(img.shape)
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [24]:
   
torch.cuda.empty_cache()
# Initialize WandB project
#wandb.init(project="spectra-classification-vit", entity="joaoc-university-of-southampton")
# Initialize and train the model
model_vit = ViT(patch_size=2, image_size=(3748), num_classes=4, dim=64, depth=2, heads=8, mlp_dim=128, dropout=0.1)
print(model_vit)
# print the number of parameters
print('Number of parameters: ', sum(p.numel() for p in model_vit.parameters()))

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (p s) -> b (p c) s', p=4)
    (1): Linear(in_features=4, out_features=64, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-1): 2 x ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (to_qkv): Linear(in_features=64, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=64, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=True)
              (1): GELU(approximate='none')
              (2): Dropout

In [25]:
trained_model = train_model_vit(model_vit, train_loader, val_loader, test_loader, num_epochs=50, lr=1e-4, patience=10)

# Save the model and finish WandB session
wandb.finish()

torch.Size([128, 1, 3748])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x937 and 4x64)