In [1]:
#!pip install torch torchvision wandb


In [2]:
import wandb

# Option 1: Interactive login (prompts for API key)
wandb.login()

# Option 2: Programmatic login using API key
# Replace 'your-api-key' with your actual W&B API key
wandb.login(key='669c0c4ffcd8985791f84ea79983b116c50b2624')


[34m[1mwandb[0m: Currently logged in as: [33m142201020[0m ([33m142201020-indian-institute-of-technology-palakkad[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/hemanth/.netrc


True

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


In [4]:
from torch.utils.data import Subset
import random

def get_cifar_loaders(dataset_name, batch_size=128, subset_fraction=0.2, seed=42):
    """
    subset_fraction: float between 0 and 1, e.g., 0.2 = use 20% of the dataset
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
    ])
    
    if dataset_name == "CIFAR10":
        full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        full_testset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
        num_classes = 10
    elif dataset_name == "CIFAR100":
        full_trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        full_testset  = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
        num_classes = 100
    else:
        raise ValueError("Unsupported dataset")
    
    # Determine subset sizes
    train_size = int(len(full_trainset) * subset_fraction)
    test_size  = int(len(full_testset) * subset_fraction)
    
    random.seed(seed)
    train_indices = random.sample(range(len(full_trainset)), train_size)
    test_indices  = random.sample(range(len(full_testset)), test_size)
    
    trainset = Subset(full_trainset, train_indices)
    testset  = Subset(full_testset, test_indices)
    
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader  = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    
    return trainloader, testloader, num_classes


In [5]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*8*8, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


In [6]:
def train_model(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return running_loss/total, correct/total

def evaluate_model(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return running_loss/total, correct/total


In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"

def sequential_training(sequence, epochs=100):
    for i, dataset_name in enumerate(sequence):
        run_name = f"seq_{'_'.join(sequence)}_part{i+1}_{dataset_name}"
        wandb.init(project="CIFAR-sequential", name=run_name, config={"dataset": dataset_name, "epochs": epochs})
        
        trainloader, testloader, num_classes = get_cifar_loaders(dataset_name)
        model = SimpleCNN(num_classes=num_classes).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        
        for epoch in range(epochs):
            train_loss, train_acc = train_model(model, trainloader, criterion, optimizer, device)
            test_loss, test_acc   = evaluate_model(model, testloader, criterion, device)
            
            # Log to W&B
            wandb.log({
                "epoch": epoch+1,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "test_loss": test_loss,
                "test_acc": test_acc
            })
            print(f"[{dataset_name}] Epoch {epoch+1}/{epochs} - Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")
        
        wandb.finish()


In [8]:
sequential_training(["CIFAR100", "CIFAR10"], epochs=100)


[CIFAR100] Epoch 1/100 - Train Acc: 0.0666, Test Acc: 0.1210
[CIFAR100] Epoch 2/100 - Train Acc: 0.1601, Test Acc: 0.1620
[CIFAR100] Epoch 3/100 - Train Acc: 0.2269, Test Acc: 0.1900
[CIFAR100] Epoch 4/100 - Train Acc: 0.2901, Test Acc: 0.2265
[CIFAR100] Epoch 5/100 - Train Acc: 0.3531, Test Acc: 0.2515
[CIFAR100] Epoch 6/100 - Train Acc: 0.4190, Test Acc: 0.2470
[CIFAR100] Epoch 7/100 - Train Acc: 0.4847, Test Acc: 0.2625
[CIFAR100] Epoch 8/100 - Train Acc: 0.5470, Test Acc: 0.2475
[CIFAR100] Epoch 9/100 - Train Acc: 0.6175, Test Acc: 0.2505
[CIFAR100] Epoch 10/100 - Train Acc: 0.6892, Test Acc: 0.2455
[CIFAR100] Epoch 11/100 - Train Acc: 0.7582, Test Acc: 0.2575
[CIFAR100] Epoch 12/100 - Train Acc: 0.8246, Test Acc: 0.2445
[CIFAR100] Epoch 13/100 - Train Acc: 0.8796, Test Acc: 0.2430
[CIFAR100] Epoch 14/100 - Train Acc: 0.9267, Test Acc: 0.2515
[CIFAR100] Epoch 15/100 - Train Acc: 0.9456, Test Acc: 0.2355
[CIFAR100] Epoch 16/100 - Train Acc: 0.9608, Test Acc: 0.2525
[CIFAR100] Epoch 

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
test_acc,▁▃▄▇▇█▇▇▇▇▇▇█▇▆▇▇▇██████████▆▇▇█▇▇▇▇▆▆▆▇
test_loss,▁▁▁▁▃▄▅▅▆▆▆▆▅▆▆▇▇▇▇▇▇▇█▇████▄▆▇▇▇▇█▆▇▇██
train_acc,▁▂▂▅▅▇▇████▇████████████████████████████
train_loss,█▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
test_acc,0.2485
test_loss,9.42936
train_acc,0.9999
train_loss,0.00071


100%|██████████| 170M/170M [00:37<00:00, 4.60MB/s] 


[CIFAR10] Epoch 1/100 - Train Acc: 0.3591, Test Acc: 0.4665
[CIFAR10] Epoch 2/100 - Train Acc: 0.5090, Test Acc: 0.5045
[CIFAR10] Epoch 3/100 - Train Acc: 0.5563, Test Acc: 0.5400
[CIFAR10] Epoch 4/100 - Train Acc: 0.6024, Test Acc: 0.5655
[CIFAR10] Epoch 5/100 - Train Acc: 0.6603, Test Acc: 0.5740
[CIFAR10] Epoch 6/100 - Train Acc: 0.6913, Test Acc: 0.5725
[CIFAR10] Epoch 7/100 - Train Acc: 0.7246, Test Acc: 0.5810
[CIFAR10] Epoch 8/100 - Train Acc: 0.7595, Test Acc: 0.5945
[CIFAR10] Epoch 9/100 - Train Acc: 0.7935, Test Acc: 0.5975
[CIFAR10] Epoch 10/100 - Train Acc: 0.8311, Test Acc: 0.6120
[CIFAR10] Epoch 11/100 - Train Acc: 0.8749, Test Acc: 0.5975
[CIFAR10] Epoch 12/100 - Train Acc: 0.9016, Test Acc: 0.6035
[CIFAR10] Epoch 13/100 - Train Acc: 0.9343, Test Acc: 0.5800
[CIFAR10] Epoch 14/100 - Train Acc: 0.9548, Test Acc: 0.5830
[CIFAR10] Epoch 15/100 - Train Acc: 0.9657, Test Acc: 0.6060
[CIFAR10] Epoch 16/100 - Train Acc: 0.9870, Test Acc: 0.6045
[CIFAR10] Epoch 17/100 - Train Ac

0,1
epoch,▁▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
test_acc,▁▅███▆▇▇████████████████████████████████
test_loss,▁▁▁▁▂▂▂▃▄▄▄▄▅▅▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█████
train_acc,▁▃▄▅▆██████▇████████████████████████████
train_loss,█▇▆▅▅▃▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
test_acc,0.612
test_loss,3.7414
train_acc,1.0
train_loss,2e-05


In [9]:
sequential_training(["CIFAR10", "CIFAR100"], epochs=100)

[CIFAR10] Epoch 1/100 - Train Acc: 0.3565, Test Acc: 0.4400
[CIFAR10] Epoch 2/100 - Train Acc: 0.4922, Test Acc: 0.4950
[CIFAR10] Epoch 3/100 - Train Acc: 0.5579, Test Acc: 0.5395
[CIFAR10] Epoch 4/100 - Train Acc: 0.6037, Test Acc: 0.5660
[CIFAR10] Epoch 5/100 - Train Acc: 0.6498, Test Acc: 0.5780
[CIFAR10] Epoch 6/100 - Train Acc: 0.6840, Test Acc: 0.5960
[CIFAR10] Epoch 7/100 - Train Acc: 0.7089, Test Acc: 0.6020
[CIFAR10] Epoch 8/100 - Train Acc: 0.7498, Test Acc: 0.6075
[CIFAR10] Epoch 9/100 - Train Acc: 0.7845, Test Acc: 0.6335
[CIFAR10] Epoch 10/100 - Train Acc: 0.8143, Test Acc: 0.6160
[CIFAR10] Epoch 11/100 - Train Acc: 0.8437, Test Acc: 0.5965
[CIFAR10] Epoch 12/100 - Train Acc: 0.8742, Test Acc: 0.6285
[CIFAR10] Epoch 13/100 - Train Acc: 0.9080, Test Acc: 0.6245
[CIFAR10] Epoch 14/100 - Train Acc: 0.9310, Test Acc: 0.6280
[CIFAR10] Epoch 15/100 - Train Acc: 0.9537, Test Acc: 0.6110
[CIFAR10] Epoch 16/100 - Train Acc: 0.9717, Test Acc: 0.6295
[CIFAR10] Epoch 17/100 - Train Ac

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇█
test_acc,▁▅▆▇▇▇██████████████████████████████████
test_loss,▂▂▁▁▁▂▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█████
train_acc,▁▂▅▅▅▇▇▇▇███████████████████████████████
train_loss,█▇▄▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
test_acc,0.6245
test_loss,3.63838
train_acc,1.0
train_loss,3e-05


[CIFAR100] Epoch 1/100 - Train Acc: 0.0666, Test Acc: 0.1115
[CIFAR100] Epoch 2/100 - Train Acc: 0.1611, Test Acc: 0.1715
[CIFAR100] Epoch 3/100 - Train Acc: 0.2298, Test Acc: 0.2035
[CIFAR100] Epoch 4/100 - Train Acc: 0.2884, Test Acc: 0.2280
[CIFAR100] Epoch 5/100 - Train Acc: 0.3517, Test Acc: 0.2435
[CIFAR100] Epoch 6/100 - Train Acc: 0.4092, Test Acc: 0.2515
[CIFAR100] Epoch 7/100 - Train Acc: 0.4828, Test Acc: 0.2610
[CIFAR100] Epoch 8/100 - Train Acc: 0.5442, Test Acc: 0.2595
[CIFAR100] Epoch 9/100 - Train Acc: 0.6108, Test Acc: 0.2445
[CIFAR100] Epoch 10/100 - Train Acc: 0.6746, Test Acc: 0.2530
[CIFAR100] Epoch 11/100 - Train Acc: 0.7423, Test Acc: 0.2470
[CIFAR100] Epoch 12/100 - Train Acc: 0.8174, Test Acc: 0.2430
[CIFAR100] Epoch 13/100 - Train Acc: 0.8618, Test Acc: 0.2510
[CIFAR100] Epoch 14/100 - Train Acc: 0.9064, Test Acc: 0.2460
[CIFAR100] Epoch 15/100 - Train Acc: 0.9353, Test Acc: 0.2515
[CIFAR100] Epoch 16/100 - Train Acc: 0.9671, Test Acc: 0.2370
[CIFAR100] Epoch 

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
test_acc,▁▅█▇▆███▇▄████████▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▆▆
test_loss,▂▁▁▁▁▄▄▄▅▆▆▅▅▆▆▆▆▇▇▇▇▇▇▇▇▆▆▇▇▇▇▇▇▇▇█████
train_acc,▁▂▄▄▆▇███████████████████▇▇█████████████
train_loss,█▇▆▃▂▂▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,100.0
test_acc,0.243
test_loss,9.98345
train_acc,0.9998
train_loss,0.00059
