In [1]:
import matplotlib.pyplot as plt
import torch
import shutil
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
import numpy as np
import wandb
import time

In [2]:
transform = transforms.Compose([
    transforms.ToTensor()
])

# Datasets
trainset = torchvision.datasets.FashionMNIST(root='.', train=True, transform=transform)
testset = torchvision.datasets.FashionMNIST(root='.', train=False, transform=transform)

# Dataloaders to feed the data in batches
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1000, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)

In [3]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mxavierohan[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
class Network(nn.Module):
    
    def __init__(self):
        super(Network, self).__init__()
    
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)

        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool(x)
        
        x = x.reshape(-1, 12 * 4 * 4)
        
        x = self.fc1(x)
        x = F.relu(x)
        
        x = self.fc2(x)
        x = F.relu(x)
        
        x = self.fc3(x)
        return x

In [5]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
def train(network, epoch, criterion, optimizer, trainloader):
    network.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = network(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    total_loss = running_loss / len(trainloader)
    accuracy = 100 * correct / total
    return total_loss, accuracy

def validate(network, epoch, criterion, testloader):
    network.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for i, data in enumerate(testloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = network(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    total_loss = running_loss / len(testloader)
    accuracy = 100 * correct / total
    return total_loss, accuracy

In [9]:
num_epochs = 5
num_runs = 3

import os
os.environ["WANDB_SILENT"] = "true"

for run in range(num_runs):

    # Set different seeds for each run
    torch.manual_seed(run)

    # Initialize a new wandb run
    wandb.init(name=f'run_{run}', project="ds598", group="experiment_1", job_type="run_{}".format(run+1),)

    # optional
    wandb.config.lr = 0.01
    
    network = Network().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(network.parameters(), wandb.config.lr)

    # Log the network weight histograms (optional)
    wandb.watch(network)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    start_time = time.time()
    for epoch in range(1, num_epochs+1):
        loss_train, acc_train = train(network, epoch, criterion, optimizer, trainloader)
        loss_valid, acc_valid = validate(network, epoch, criterion, testloader)
        
        # Log metrics to wandb
        wandb.log({
            "Epoch": epoch,
            "Train Loss": loss_train,
            "Train Acc": acc_train,
            "Valid Loss": loss_valid,
            "Valid Acc": acc_valid
        })
    
    print("Time Elapsed : {:.4f}s".format(time.time() - start_time))
    
    # Finish the current run before starting the next
    wandb.finish()

Time Elapsed : 38.5441s




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113776463187404, max=1.0…

Time Elapsed : 40.0877s




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113681509676907, max=1.0…

Time Elapsed : 39.8059s




In [10]:
api = wandb.Api()
project_name = "ds598"
runs = api.runs("xavierohan/" + project_name, {"group": "experiment_1"})

valid_accuracies = []

for run in runs:
    # Assuming you want to analyze the validation accuracy at the last epoch
    # If you want to do this for every epoch, you would collect all epochs' accuracies
    history = run.scan_history(keys=["Valid Acc"])
    valid_acc = [x["Valid Acc"] for x in history]
    if valid_acc:
        # Taking the last epoch's accuracy
        valid_accuracies.append(valid_acc[-1])

if valid_accuracies:
    mean_acc = np.mean(valid_accuracies)
    std_acc = np.std(valid_accuracies)
    print(f'Validation Accuracy across runs: {mean_acc:.2f} ± {std_acc:.2f}')
else:
    print("No validation accuracies found. Please check your wandb setup and project name.")

Validation Accuracy across runs: 86.31 ± 0.30
