In [1]:
pip install wandb



In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import wandb
import numpy as np
from collections import OrderedDict
import time

# Initialize Wandb
# This call sits at module level, so the run is started as soon as the file is
# executed or imported, before main() is called. The `config` dict passed
# here is the hyperparameter record attached to the run.
wandb.init(
    project="cifar10-cnn-lab2",
    name="resnet18-cifar10-25epochs",
    config={
        "architecture": "ResNet18",
        "dataset": "CIFAR-10",
        "epochs": 25,
        "batch_size": 128,
        "learning_rate": 0.001,
        "optimizer": "Adam"
    }
)

# Module-level alias for the run configuration; main() reads `config.epochs`
# from it to decide how many epochs to train.
config = wandb.config


# Custom CIFAR-10 Dataset Class
class CustomCIFAR10Dataset(Dataset):
    """Map-style dataset that wraps ``torchvision.datasets.CIFAR10``.

    The wrapped torchvision dataset is always built with ``transform=None``;
    the transform handed to this class is applied in ``__getitem__`` instead.
    """

    def __init__(self, root='./data', train=True, transform=None, download=True):
        """Create the underlying CIFAR-10 dataset and remember the transform.

        Args:
            root: Forwarded to ``CIFAR10`` as ``root``.
            train: Forwarded to ``CIFAR10`` as ``train``.
            transform: Optional callable applied to every image on access.
            download: Forwarded to ``CIFAR10`` as ``download``.
        """
        self.transform = transform
        self.cifar_data = torchvision.datasets.CIFAR10(
            root=root, train=train, download=download, transform=None,
        )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.cifar_data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, transformed when a transform is set."""
        sample, target = self.cifar_data[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target


# Define CNN Model (ResNet18)
class ResNet18CIFAR(nn.Module):
    """torchvision ResNet-18 with its stem and classifier replaced for CIFAR-10.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # weights=None: the network is built without loading pretrained weights.
        backbone = torchvision.models.resnet18(weights=None)
        # 3x3, stride-1 first convolution, intended for 32x32 inputs.
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # The stem max-pool is turned into a pass-through for small images.
        backbone.maxpool = nn.Identity()
        # Classifier head sized for `num_classes` outputs.
        backbone.fc = nn.Linear(512, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for the batch ``x``."""
        return self.model(x)


# FLOPs Counting Function
def count_flops(model, input_size=(1, 3, 32, 32), device='cpu'):
    """
    Estimate the FLOPs of one forward pass of ``model``.

    Forward hooks are attached to every ``nn.Conv2d``, ``nn.Linear``,
    ``nn.BatchNorm1d``/``nn.BatchNorm2d`` and ``nn.ReLU``/``nn.ReLU6`` module,
    a random tensor of shape ``input_size`` is run through the model under
    ``torch.no_grad()``, and the per-call counts are summed. Modules of any
    other type contribute nothing to the total.

    Counting rules:
      * Conv2d: prod(kernel_size) * in_channels * (out_channels // groups)
        per output position, plus out_channels per position when a bias exists.
      * Linear: in_features * out_features per sample, plus out_features per
        sample when a bias exists.
      * BatchNorm: 2 per input element.
      * ReLU: 1 per input element.

    The forward pass runs in eval mode. Afterwards every module's original
    train/eval flag is restored and all hooks are removed, also when the
    forward pass raises. No bookkeeping attributes are left on the modules.

    Args:
        model: The ``nn.Module`` to profile; its parameters must be on ``device``.
        input_size: Shape of the dummy input, batch dimension included.
        device: Device the dummy input is moved to.

    Returns:
        int: Total FLOPs counted over all hooked module calls.
    """
    total_flops = 0

    def conv_hook(module, inputs, output):
        nonlocal total_flops
        # Number of output positions across the whole batch.
        positions = output.shape[0] * int(np.prod(list(output.shape[2:])))
        per_position = (int(np.prod(list(module.kernel_size)))
                        * module.in_channels
                        * (module.out_channels // module.groups))
        flops = per_position * positions
        if module.bias is not None:
            flops += module.out_channels * positions
        total_flops += int(flops)

    def linear_hook(module, inputs, output):
        nonlocal total_flops
        batch = inputs[0].shape[0]
        flops = batch * module.in_features * module.out_features
        if module.bias is not None:
            flops += batch * module.out_features
        total_flops += int(flops)

    def bn_hook(module, inputs, output):
        nonlocal total_flops
        total_flops += 2 * inputs[0].numel()

    def relu_hook(module, inputs, output):
        nonlocal total_flops
        total_flops += inputs[0].numel()

    hook_table = (
        (nn.Conv2d, conv_hook),
        (nn.Linear, linear_hook),
        ((nn.BatchNorm2d, nn.BatchNorm1d), bn_hook),
        ((nn.ReLU, nn.ReLU6), relu_hook),
    )

    # Remember each module's mode so profiling does not leave the model in eval.
    saved_modes = [(module, module.training) for module in model.modules()]
    handles = []
    model.eval()
    try:
        # modules() yields every instance once, so a module shared between two
        # parents gets a single hook and is not double-counted.
        for module in model.modules():
            for module_types, hook in hook_table:
                if isinstance(module, module_types):
                    handles.append(module.register_forward_hook(hook))
                    break

        input_tensor = torch.randn(input_size).to(device)
        with torch.no_grad():
            model(input_tensor)
    finally:
        for handle in handles:
            handle.remove()
        for module, was_training in saved_modes:
            module.training = was_training

    return total_flops


# Gradient Flow Tracking
def plot_grad_flow(named_parameters, epoch):
    """
    Log per-parameter gradient magnitudes to wandb as two bar charts.

    Only parameters that require grad, currently hold a gradient and whose
    name does not contain "bias" are included. For each one the mean and the
    max of the absolute gradient are recorded.

    Args:
        named_parameters: Iterable of ``(name, parameter)`` pairs.
        epoch: Value embedded in the logged keys and chart titles.

    Returns:
        tuple: ``(ave_grads, max_grads)``, two lists of floats in iteration order.
    """
    rows = []
    for name, param in named_parameters:
        if not param.requires_grad or param.grad is None or "bias" in name:
            continue
        magnitude = param.grad.abs()
        rows.append([name, magnitude.mean().cpu().item(), magnitude.max().cpu().item()])

    # One table feeds both charts.
    table = wandb.Table(data=rows, columns=["Layer", "Average Gradient", "Max Gradient"])
    average_chart = wandb.plot.bar(
        table, "Layer", "Average Gradient", title=f"Gradient Flow - Epoch {epoch}"
    )
    max_chart = wandb.plot.bar(
        table, "Layer", "Max Gradient", title=f"Max Gradient Flow - Epoch {epoch}"
    )
    wandb.log({
        f"gradient_flow_epoch_{epoch}": average_chart,
        f"max_gradient_flow_epoch_{epoch}": max_chart,
    })

    ave_grads = [row[1] for row in rows]
    max_grads = [row[2] for row in rows]
    return ave_grads, max_grads


# Weight Update Flow Tracking
def track_weight_updates(model, old_weights, epoch):
    """
    Log how far each trainable parameter moved from its snapshot.

    For every parameter of ``model`` that requires grad and has an entry in
    ``old_weights``, the mean absolute difference between the current values
    and the snapshot is computed and logged under ``weight_update/<name>``.

    Args:
        model: Module whose current parameters are compared.
        old_weights: Mapping from parameter name to an earlier copy of its data.
        epoch: Accepted for the caller's convenience; not used here.

    Returns:
        dict: Parameter name -> mean absolute change (float).
    """
    weight_changes = {
        name: (param.data - old_weights[name]).abs().mean().item()
        for name, param in model.named_parameters()
        if name in old_weights and param.requires_grad
    }

    metrics = {}
    for name, delta in weight_changes.items():
        metrics[f"weight_update/{name}"] = delta
    wandb.log(metrics)

    return weight_changes


# Training Function
def train_epoch(model, trainloader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``trainloader``.

    Every 50th batch the current batch loss and the accuracy accumulated so
    far in this epoch are logged to wandb.

    Args:
        model: Network to train; switched to train mode here.
        trainloader: Sized iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device each batch is moved to.
        epoch: Current epoch index; not used inside this function.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)``, the mean of the per-batch losses
        and the accuracy in percent over all samples seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for step, (images, targets) in enumerate(trainloader, start=1):
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        loss_sum += batch_loss_value
        predictions = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += (predictions == targets).sum().item()

        # step is 1-based, so this fires on the 50th, 100th, ... batch.
        if step % 50 == 0:
            wandb.log(
                {
                    "batch_loss": batch_loss_value,
                    "batch_accuracy": 100. * n_correct / n_seen
                }
            )

    return loss_sum / len(trainloader), 100. * n_correct / n_seen


# Validation Function
def validate(model, testloader, criterion, device):
    """Evaluate ``model`` on ``testloader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        testloader: Sized iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device each batch is moved to.

    Returns:
        tuple: ``(test_loss, test_acc)``, the mean of the per-batch losses
        and the accuracy in percent over all samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, targets in testloader:
            images = images.to(device)
            targets = targets.to(device)

            logits = model(images)
            loss_sum += criterion(logits, targets).item()

            predictions = logits.max(1)[1]
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()

    return loss_sum / len(testloader), 100. * n_correct / n_seen


# Main Training Loop
def main():
    """Train ``ResNet18CIFAR`` on CIFAR-10 and log the run to Weights & Biases.

    Steps: build the augmented train and plain test datasets and loaders,
    count the model's FLOPs, then for ``config.epochs`` epochs train, validate,
    log weight-update magnitudes (and gradient statistics every fifth epoch),
    step the cosine learning-rate schedule and log the epoch metrics. The best
    test accuracy seen is tracked and logged at the end; no checkpoint is
    written.

    Batch size, learning rate and the scheduler horizon are read from the
    module-level ``config`` so they always match what the run records.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Data transforms: augmentation is applied to the training split only.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Create custom datasets
    print("Loading CIFAR-10 dataset...")
    trainset = CustomCIFAR10Dataset(root='./data', train=True, transform=transform_train, download=True)
    testset = CustomCIFAR10Dataset(root='./data', train=False, transform=transform_test, download=True)

    # Create dataloaders; the batch size comes from the run config.
    trainloader = DataLoader(trainset, batch_size=config.batch_size, shuffle=True,
                             num_workers=2, pin_memory=True)
    testloader = DataLoader(testset, batch_size=config.batch_size, shuffle=False,
                            num_workers=2, pin_memory=True)

    print(f"Training samples: {len(trainset)}, Test samples: {len(testset)}")

    # Initialize model
    print("Initializing ResNet18 model...")
    model = ResNet18CIFAR(num_classes=10).to(device)

    # Count FLOPs
    print("Counting FLOPs...")
    flops = count_flops(model, input_size=(1, 3, 32, 32), device=device)
    print(f"Total FLOPs: {flops:,} ({flops/1e6:.2f} MFLOPs)")
    wandb.config.update({"total_flops": flops, "flops_millions": flops/1e6})

    # Watch model with wandb
    wandb.watch(model, log="all", log_freq=100)

    # Loss and optimizer; learning rate and schedule length come from the run config.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)

    # Training loop
    print(f"\nStarting training for {config.epochs} epochs...")
    best_acc = 0.0

    for epoch in range(config.epochs):
        start_time = time.time()

        # Save old weights for tracking updates
        old_weights = {name: param.data.clone() for name, param in model.named_parameters()}

        # Learning rate in effect for this epoch. It is read before
        # scheduler.step() so the logged value is the one training used.
        epoch_lr = optimizer.param_groups[0]['lr']

        # Train
        train_loss, train_acc = train_epoch(model, trainloader, criterion, optimizer, device, epoch)

        # Validate
        test_loss, test_acc = validate(model, testloader, criterion, device)

        # Track gradient flow every 5 epochs. The gradients still held by the
        # parameters are those of the last training batch of this epoch.
        if epoch % 5 == 0:
            plot_grad_flow(model.named_parameters(), epoch)

        # Track weight updates
        track_weight_updates(model, old_weights, epoch)

        # Learning rate step
        scheduler.step()

        epoch_time = time.time() - start_time

        # Log metrics
        wandb.log(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "test_loss": test_loss,
                "test_accuracy": test_acc,
                "learning_rate": epoch_lr,
                "epoch_time": epoch_time
            }
        )

        print(f"Epoch [{epoch+1}/{config.epochs}] | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}% | "
              f"Time: {epoch_time:.2f}s")

        # Track the best test accuracy (no checkpoint is written).
        if test_acc > best_acc:
            best_acc = test_acc
            print(f"New best accuracy: {best_acc:.2f}%")

    print(f"\nTraining completed! Best Test Accuracy: {best_acc:.2f}%")
    wandb.log({"best_test_accuracy": best_acc})

    # Finish wandb run
    wandb.finish()

    print("\nAll visualizations have been logged to Weights & Biases!")
    print("Check your wandb dashboard for gradient flow and weight update visualizations.")

# Start training only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

Using device: cuda
Loading CIFAR-10 dataset...
Training samples: 50000, Test samples: 10000
Initializing ResNet18 model...
Counting FLOPs...
Total FLOPs: 557,208,586 (557.21 MFLOPs)

Starting training for 25 epochs...
Epoch [1/25] | Train Loss: 1.3352 | Train Acc: 51.28% | Test Loss: 1.2315 | Test Acc: 60.58% | Time: 46.94s
New best accuracy: 60.58%
Epoch [2/25] | Train Loss: 0.8470 | Train Acc: 70.20% | Test Loss: 0.8698 | Test Acc: 70.44% | Time: 44.86s
New best accuracy: 70.44%
Epoch [3/25] | Train Loss: 0.6632 | Train Acc: 76.97% | Test Loss: 0.7372 | Test Acc: 75.05% | Time: 45.75s
New best accuracy: 75.05%
Epoch [4/25] | Train Loss: 0.5512 | Train Acc: 80.85% | Test Loss: 0.5848 | Test Acc: 80.42% | Time: 44.40s
New best accuracy: 80.42%
Epoch [5/25] | Train Loss: 0.4826 | Train Acc: 83.34% | Test Loss: 0.5903 | Test Acc: 80.34% | Time: 44.59s
Epoch [6/25] | Train Loss: 0.4234 | Train Acc: 85.33% | Test Loss: 0.5243 | Test Acc: 82.71% | Time: 45.74s
New best accuracy: 82.71%
Epoc

0,1
batch_accuracy,▁▁▄▄▄▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇██████████████████
batch_loss,█▆▅▆▅▅▄▅▄▃▃▄▃▃▃▃▃▃▂▂▂▃▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁
best_test_accuracy,▁
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
epoch_time,█▃▅▁▂▅▁▄▃▂▃▃▂▂▃▃▄▂▂▅▂▂▄▂▁
learning_rate,████▇▇▇▆▆▆▅▅▄▄▃▃▃▂▂▂▁▁▁▁▁
test_accuracy,▁▃▄▅▅▆▆▇▇▇▇▇▇▇███████████
test_loss,█▅▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▅▅▆▆▆▆▇▇▇▇▇▇▇██████████
train_loss,█▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
batch_accuracy,98.88393
batch_loss,0.01224
best_test_accuracy,91.97
epoch,24
epoch_time,44.27794
learning_rate,0
test_accuracy,91.9
test_loss,0.32611
train_accuracy,98.856
train_loss,0.03501



All visualizations have been logged to Weights & Biases!
Check your wandb dashboard for gradient flow and weight update visualizations.
