In [3]:
!pip install wandb -q

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import wandb
import matplotlib.pyplot as plt
import numpy as np

# Authenticate with Weights & Biases before train() calls wandb.init().
# Without stored credentials this prompts interactively (the cell output below
# shows the key then being appended to /root/.netrc). Paste the API key only at
# the "Paste an API key" prompt: text typed at "Enter your choice" is echoed
# verbatim into the notebook output.
wandb.login()

class CustomCIFAR10(Dataset):
    """Map-style CIFAR-10 dataset returning ``(image, target)`` pairs.

    Wraps ``torchvision.datasets.CIFAR10`` (constructed with ``download=True``
    under ``root``) and keeps its raw image array and label list. Each image is
    converted to a PIL image before the optional ``transform`` runs, so
    PIL-based torchvision transforms can precede ``ToTensor``.

    Args:
        root: Directory torchvision downloads to / reads from.
        train: ``True`` for the training split, ``False`` for the test split.
        transform: Optional callable applied to the PIL image; skipped when falsy.
    """

    def __init__(self, root, train=True, transform=None):
        raw = torchvision.datasets.CIFAR10(root=root, train=train, download=True)
        self.cifar_raw = raw
        # Keep direct references so __getitem__ indexes the array/list itself.
        self.data, self.targets = raw.data, raw.targets
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pixels, target = self.data[idx], self.targets[idx]
        image = transforms.ToPILImage()(pixels)
        return (self.transform(image) if self.transform else image), target

class SimpleCNN(nn.Module):
    """Four-convolution CNN with BatchNorm, global average pooling and a linear head.

    Takes 3-channel images and returns 10 logits per sample. Two stages of
    (conv -> BN -> ReLU) x 2 are each followed by a 2x2 max-pool. Global average
    pooling then makes the classifier head independent of the input's spatial size.
    """

    @staticmethod
    def _conv_bn(in_channels, out_channels):
        """Build a 3x3 same-padding convolution and its matching BatchNorm."""
        return (
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        )

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 -> 64 channels.
        self.conv1, self.bn1 = self._conv_bn(3, 32)
        self.conv2, self.bn2 = self._conv_bn(32, 64)
        self.pool = nn.MaxPool2d(2, 2)  # parameter-free, reused after both stages

        # Stage 2: 64 -> 128 -> 128 channels.
        self.conv3, self.bn3 = self._conv_bn(64, 128)
        self.conv4, self.bn4 = self._conv_bn(128, 128)

        # Head: global average pool -> dropout -> linear classifier.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.pool(F.relu(self.bn2(self.conv2(h))))
        h = F.relu(self.bn3(self.conv3(h)))
        h = self.pool(F.relu(self.bn4(self.conv4(h))))
        # (N, 128, 1, 1) -> (N, 128)
        h = torch.flatten(self.global_pool(h), start_dim=1)
        return self.fc(self.dropout(h))

def count_flops(model, input_size=(1, 3, 32, 32)):
    """Estimate the multiply-accumulate cost of one forward pass of ``model``.

    Forward hooks on every ``nn.Conv2d`` and ``nn.Linear`` record, per layer:

    * convolutions: ``(kernel elements per output + bias) * number of outputs``
    * linear layers: ``(weight elements + bias elements) * rows``

    Other layers (batch norm, activations, pooling, dropout) are not counted.
    Each multiply-add is counted once, so the result is closer to a MAC count
    than a strict FLOP count.

    The dummy forward runs under ``torch.no_grad()`` with every submodule
    temporarily in eval mode. It therefore builds no autograd graph and does
    not update BatchNorm running statistics. Each submodule's original
    train/eval flag is restored and the hooks are removed, even if the
    forward raises.

    Args:
        model: Module to profile. The dummy input is placed on the device of its
            first parameter (CPU if it has none).
        input_size: Shape of the random dummy input, including the batch
            dimension. Counts scale linearly with the batch size.

    Returns:
        int: Total operation count summed over all hooked layers.
    """
    flops = []

    def conv_hook(module, inputs, output):
        # output is (N, C_out, H_out, W_out). Each output element needs one
        # multiply-add per kernel weight it sees, plus one for the bias.
        batch_size, out_channels, out_h, out_w = output.shape
        # Integer division keeps the total an int (the original "/" made it float).
        kernel_ops = (
            module.kernel_size[0]
            * module.kernel_size[1]
            * (module.in_channels // module.groups)
        )
        bias_ops = 1 if module.bias is not None else 0
        flops.append(batch_size * (kernel_ops + bias_ops) * out_channels * out_h * out_w)

    def linear_hook(module, inputs, output):
        # Every leading dimension of the input is a separate row. For the usual
        # (N, in_features) input this is just the batch size N.
        rows = inputs[0].shape[:-1].numel()
        weight_ops = module.weight.nelement()
        bias_ops = module.bias.nelement() if module.bias is not None else 0
        flops.append(rows * (weight_ops + bias_ops))

    hooks = []
    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            hooks.append(layer.register_forward_hook(conv_hook))
        elif isinstance(layer, nn.Linear):
            hooks.append(layer.register_forward_hook(linear_hook))

    # Record each submodule's own flag so a mixed train/eval setup is restored
    # exactly, rather than flattened by a single model.train(flag) call.
    saved_modes = [(module, module.training) for module in model.modules()]
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    try:
        model.eval()
        with torch.no_grad():
            model(torch.randn(input_size).to(device))
    finally:
        for hook in hooks:
            hook.remove()
        for module, was_training in saved_modes:
            module.training = was_training

    return sum(flops)

def plot_grad_flow(named_parameters):
    """Draw the mean absolute gradient of each non-bias parameter as a line plot.

    A parameter is included when it requires grad, its name does not contain
    ``"bias"``, and its ``.grad`` is populated (i.e. after ``backward()``).
    Points appear in the iteration order of ``named_parameters``.

    Args:
        named_parameters: Iterable of ``(name, tensor)`` pairs such as
            ``model.named_parameters()``. It is consumed once.

    Returns:
        matplotlib.figure.Figure: The newly created figure. It is left open,
        so the caller should ``plt.close`` it when done.
    """
    stats = [
        (name, param.grad.abs().mean().cpu().item())
        for name, param in named_parameters
        if param.requires_grad and "bias" not in name and param.grad is not None
    ]
    layers = [name for name, _ in stats]
    ave_grads = [value for _, value in stats]
    n_layers = len(ave_grads)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(ave_grads, alpha=0.3, color="b")
    # Zero-gradient reference line.
    ax.hlines(0, 0, n_layers + 1, linewidth=1, color="k")
    ax.set_xticks(range(0, n_layers, 1))
    ax.set_xticklabels(layers, rotation="vertical")
    ax.set_xlim(left=0, right=n_layers)
    ax.set_xlabel("Layers")
    ax.set_ylabel("average gradient")
    ax.set_title("Gradient flow")
    ax.grid(True)
    fig.tight_layout()
    return fig

def _snapshot_params(model):
    """Return detached copies of all trainable parameters, keyed by name."""
    return {
        name: param.clone().detach()
        for name, param in model.named_parameters()
        if param.requires_grad
    }


def _weight_update_magnitudes(model, old_weights):
    """Return mean |new - old| per trainable parameter and refresh ``old_weights``.

    ``old_weights`` must hold an entry for every trainable parameter name, as
    produced by ``_snapshot_params``. It is updated in place to the current
    values, so the next call measures the following epoch's change.
    """
    update_magnitudes = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            update_magnitudes[name] = (param - old_weights[name]).abs().mean().item()
            old_weights[name] = param.clone().detach()
    return update_magnitudes


def train(*, epochs=25, batch_size=128, lr=0.001, project="colab-cifar10",
          run_name="CNN_Run_Colab", data_root="./data"):
    """Train ``SimpleCNN`` on the CIFAR-10 training split and log to W&B.

    Logged to W&B:

    * once, before training: the model's FLOP estimate;
    * every epoch: the mean training loss, the mean absolute change of every
      trainable parameter ("weight_updates"), and a gradient-flow plot taken
      from the last batch.

    Only the training split is used; no evaluation is run. All arguments are
    keyword-only and default to the values previously hard-coded, so a bare
    ``train()`` call trains exactly as before.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size for the DataLoader.
        lr: Adam learning rate.
        project: W&B project name.
        run_name: W&B run display name.
        data_root: Directory where CIFAR-10 is downloaded/cached.
    """
    config = {"epochs": epochs, "batch_size": batch_size, "lr": lr}
    # Using the run as a context manager guarantees it is finished even if
    # training raises; the previous explicit wandb.finish() was skipped then.
    with wandb.init(project=project, name=run_name, config=config) as run:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Training on: {device}")

        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        train_dataset = CustomCIFAR10(root=data_root, train=True, transform=transform)
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=device.type == "cuda",  # faster host-to-GPU copies
        )

        model = SimpleCNN().to(device)

        # Profile in eval mode without autograd, so the random dummy batch used
        # by count_flops cannot update BatchNorm running statistics.
        model.eval()
        with torch.no_grad():
            total_flops = count_flops(model)
        print(f"Total FLOPs: {total_flops / 1e6:.2f} Million")
        run.log({"Total FLOPs (M)": total_flops / 1e6})

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Reference weights for the per-epoch update-magnitude metric.
        old_weights = _snapshot_params(model)
        last_batch = len(train_loader) - 1

        for epoch in range(epochs):
            model.train()
            running_loss = 0.0

            for i, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()

                # Gradient flow of the epoch's last batch, read before
                # optimizer.step(). commit=False defers it to the epoch-level
                # log call below.
                if i == last_batch:
                    grad_fig = plot_grad_flow(model.named_parameters())
                    run.log({"Gradient Flow": wandb.Image(grad_fig)}, commit=False)
                    plt.close(grad_fig)

                optimizer.step()
                running_loss += loss.item()

            update_magnitudes = _weight_update_magnitudes(model, old_weights)
            avg_loss = running_loss / len(train_loader)
            run.log({"epoch": epoch + 1, "loss": avg_loss, "weight_updates": update_magnitudes})
            print(f"Epoch [{epoch+1}/{epochs}] Loss: {avg_loss:.4f}")

        print("Training Complete. Check WandB dashboard for charts.")

# Run the training when executed as a script or notebook cell (both set
# __name__ to "__main__"), but not when this module is imported.
if __name__ == "__main__":
    train()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 [REDACTED — possible credential typed at the prompt]


[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: [32m[41mERROR[0m Invalid API key: API key must have 40+ characters, has 20.
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 [REDACTED W&B API key — it was exposed in this output; revoke and regenerate it]


[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mb22bb016[0m ([33mb22bb016-prom-iit-rajasthan[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training on: cuda


100%|██████████| 170M/170M [00:03<00:00, 48.7MB/s]


Total FLOPs: 76.55 Million
Epoch [1/25] Loss: 1.5830
Epoch [2/25] Loss: 1.2803
Epoch [3/25] Loss: 1.1582
Epoch [4/25] Loss: 1.0808
Epoch [5/25] Loss: 1.0143
Epoch [6/25] Loss: 0.9694
Epoch [7/25] Loss: 0.9256
Epoch [8/25] Loss: 0.8941
Epoch [9/25] Loss: 0.8629
Epoch [10/25] Loss: 0.8382
Epoch [11/25] Loss: 0.8092
Epoch [12/25] Loss: 0.7853
Epoch [13/25] Loss: 0.7629
Epoch [14/25] Loss: 0.7451
Epoch [15/25] Loss: 0.7256
Epoch [16/25] Loss: 0.7058
Epoch [17/25] Loss: 0.6913
Epoch [18/25] Loss: 0.6761
Epoch [19/25] Loss: 0.6616
Epoch [20/25] Loss: 0.6478
Epoch [21/25] Loss: 0.6323
Epoch [22/25] Loss: 0.6277
Epoch [23/25] Loss: 0.6108
Epoch [24/25] Loss: 0.6049
Epoch [25/25] Loss: 0.5934
Training Complete. Check WandB dashboard for charts.


0,1
Total FLOPs (M),▁
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
loss,█▆▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
Total FLOPs (M),76.54734
epoch,25.0
loss,0.5934
