## Introduction: Logging using Weights & Biases

ITU KSADMAL1KU-NLP - Advanced Machine Learning for NLP in KCS 2024

by Bertram Højer, Stefan Heinrich, Christian H. Rasmussen, & material by Kevin Murphy.

All info and static material: https://learnit.itu.dk/course/view.php?id=3024579

-------------------------------------------------------------------------------

In this notebook we reuse code from week 2 that trains a simple model on MNIST, as the basis for a live coding session.
We first log the training loss simply by printing its values, and then add logging of additional metrics with Weights & Biases.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import os
import wandb

from types import SimpleNamespace

# Part 1

In [None]:
# Hyperparameters for the run, collected in a single attribute-style container.
config = SimpleNamespace(**{
    "learning_rate": 0.01,  # step size passed to the optimiser
    "momentum": 0.9,        # momentum passed to the optimiser
    "epochs": 5,            # number of passes over the training data
    "batch_size": 32,       # samples per training batch
})

In [3]:
# Preprocessing applied to every image: only a conversion to a tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Training split: downloaded to ../data if missing, shuffled, batched per config.
trainset = torchvision.datasets.MNIST(
    root='../data', train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=config.batch_size, shuffle=True,
)

# Test split: same storage location, served one sample at a time in fixed order.
testset = torchvision.datasets.MNIST(
    root='../data', train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=1, shuffle=False,
)

# Model dimensions derived from the training data: the size of dimension 1 of
# the image tensor, and the number of distinct target labels.
img_size = trainset.data.shape[1]
class_out = trainset.targets.unique().size(0)

In [4]:
# specify the model class
class Model(nn.Module):
    """Fully connected classifier: two hidden ReLU layers and a linear output layer.

    Args:
        img_size: side length of the (square) input image; the first layer
            expects ``img_size * img_size`` input features.
        fc1_out: number of units in the first hidden layer.
        fc2_out: number of units in the second hidden layer.
        class_out: number of output units (one per class).
    """

    def __init__(self, img_size, fc1_out, fc2_out, class_out):
        super().__init__()

        flat_features = img_size * img_size
        self.fc1 = nn.Linear(flat_features, fc1_out)
        self.fc2 = nn.Linear(fc1_out, fc2_out)
        self.output_layer = nn.Linear(fc2_out, class_out)

    def forward(self, img):
        """Return the output-layer activations (no softmax applied) for a batch."""
        # Collapse everything after the batch dimension into one feature vector.
        flat = img.flatten(start_dim=1)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))

        return self.output_layer(hidden)

In [5]:
# Hyperparameters, re-declared in this cell and set attribute by attribute.
config = SimpleNamespace()
config.learning_rate = 0.01
config.momentum = 0.9
config.epochs = 5
config.batch_size = 32

# Model dimensions read back from the dataset behind the training loader:
# the size of dimension 1 of the image tensor and the number of distinct labels.
img_size = trainloader.dataset.data.size(1)
class_out = len(trainloader.dataset.targets.unique())
# Pick the compute device: CUDA GPU first, then Apple's MPS backend, then CPU.
# The getattr guard keeps this working on PyTorch builds that have no
# `torch.backends.mps` module at all.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Build the network with two 128-unit hidden layers and place it on `device`.
model = Model(img_size, 128, 128, class_out).to(device)
model

In [7]:
# Loss function and optimiser for the Part 1 model.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=config.learning_rate,
    momentum=config.momentum,
)

In [8]:
# Training loop
def get_accuracy(logit, target, batch_size):
    """Return the percentage of rows in ``logit`` whose arg-max equals ``target``.

    Args:
        logit: 2-D tensor of scores, one row per sample.
        target: tensor of integer labels, one per sample.
        batch_size: value the number of correct predictions is divided by.

    Returns:
        ``100 * correct / batch_size`` as a Python float.
    """
    # torch.max along dim 1 yields (values, indices); only the indices matter.
    _, predicted = torch.max(logit, 1)
    hits = predicted.view_as(target).eq(target).sum()
    return (100.0 * hits / batch_size).item()

def train_model(model, config, trainloader, criterion, optimizer):
    """Train ``model`` for ``config.epochs`` epochs, printing metrics per epoch.

    Each batch is moved to the module-level ``device`` before the forward pass.
    The printed loss and accuracy are means over the batches of the epoch.

    Args:
        model: the network to train.
        config: object exposing an ``epochs`` attribute.
        trainloader: iterable yielding ``(images, labels)`` batches.
        criterion: callable computing the loss from ``(logits, labels)``.
        optimizer: optimiser updating the parameters of ``model``.

    Returns:
        Tuple ``(model, train_running_loss, train_acc)``; the last two are the
        sums over all batches of the final epoch (``0.0`` if no epoch ran).
    """
    # Defined up front so the return statement is valid even for zero epochs.
    train_running_loss = 0.0
    train_acc = 0.0

    for epoch in range(config.epochs):
        # Put the model in training mode
        model = model.train()

        train_running_loss = 0.0
        train_acc = 0.0
        num_batches = 0

        for images, labels in trainloader:
            images = images.to(device)
            labels = labels.to(device)

            # Clear gradients left over from the previous step, then do the
            # forward pass, backward pass and parameter update.
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            train_running_loss += loss.detach().item()
            # Use the real size of this batch: the last batch of an epoch may
            # hold fewer samples than the configured batch size.
            train_acc += get_accuracy(logits, labels, labels.size(0))
            num_batches += 1

        # Average over the number of batches actually seen (not the last
        # index); guard against an empty loader.
        denom = max(num_batches, 1)
        print('Epoch: %d | Loss: %.4f | Train Accuracy: %.2f'
              % (epoch, train_running_loss / denom, train_acc / denom))

    return model, train_running_loss, train_acc


def eval_model(model, testloader):
    """Evaluate ``model`` on ``testloader`` and print the accuracy in percent.

    Each batch is moved to the module-level ``device``. Accuracy is computed
    per sample, so the loader may use any batch size.

    Args:
        model: the network to evaluate; it is switched to evaluation mode.
        testloader: iterable yielding ``(images, labels)`` batches.
    """
    model = model.eval()
    weighted_acc = 0.0
    num_samples = 0

    # No gradients are needed for evaluation; skipping them saves memory/time.
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            batch = labels.size(0)
            # Weight each batch's percentage by its size so the final figure
            # is an average over samples rather than over batches.
            weighted_acc += get_accuracy(outputs, labels, batch) * batch
            num_samples += batch

    # Divide by the number of samples seen; an empty loader yields NaN instead
    # of an exception.
    test_acc = weighted_acc / num_samples if num_samples else float("nan")
    print(f"Test Accuracy: {test_acc}")

In [None]:
# Train the Part 1 model; only the returned model is kept.
model, _, _ = train_model(
    model=model,
    config=config,
    trainloader=trainloader,
    criterion=criterion,
    optimizer=optimizer,
)

# Persist the learned parameters (the state_dict) to disk.
torch.save(obj=model.state_dict(), f="model_weights.pth")

In [None]:
# Report the accuracy of the Part 1 model on the test loader.
eval_model(model=model, testloader=testloader)

In [None]:
# Write the model's state_dict (its parameters and buffers) to a local file.
torch.save(obj=model.state_dict(), f="model_weights.pth")

# Part 2

In [None]:
# Start a W&B run in the "aml-introduction" project and record the
# hyperparameters in the run's config.
run = wandb.init(project="aml-introduction")
run.config.update(config)

In [12]:
# Preprocessing applied to every image: only a conversion to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Training split: downloaded to ../data if missing, shuffled, batched per config.
trainset = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=config.batch_size,
    shuffle=True,
)

# Test split: same storage location, served one sample at a time in fixed order.
testset = torchvision.datasets.MNIST(
    root='../data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=1,
    shuffle=False,
)

In [None]:
# Fresh network for the W&B-logged run, placed on `device`.
model_wnb = Model(img_size, 128, 128, class_out).to(device)
model_wnb

In [15]:
# Loss function and optimiser for the W&B model.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model_wnb.parameters(),
    lr=config.learning_rate,
    momentum=config.momentum,
)

In [17]:
def train_model_wnb(model, config, trainloader, criterion, optimizer):
    """Train ``model`` for ``config.epochs`` epochs, logging metrics to W&B.

    Each batch is moved to the module-level ``device`` before the forward pass.
    After every epoch the mean loss and mean accuracy over that epoch's batches
    are sent to ``wandb.log`` under the keys ``'Train Loss'`` and
    ``'Train Accuracy'``.

    Args:
        model: the network to train.
        config: object exposing an ``epochs`` attribute.
        trainloader: iterable yielding ``(images, labels)`` batches.
        criterion: callable computing the loss from ``(logits, labels)``.
        optimizer: optimiser updating the parameters of ``model``.

    Returns:
        Tuple ``(model, train_running_loss, train_acc)``; the last two are the
        sums over all batches of the final epoch (``0.0`` if no epoch ran).
    """
    # Defined up front so the return statement is valid even for zero epochs.
    train_running_loss = 0.0
    train_acc = 0.0

    for epoch in range(config.epochs):
        # Put the model in training mode
        model = model.train()

        train_running_loss = 0.0
        train_acc = 0.0
        num_batches = 0

        for images, labels in trainloader:
            images = images.to(device)
            labels = labels.to(device)

            # Clear gradients left over from the previous step, then do the
            # forward pass, backward pass and parameter update.
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            train_running_loss += loss.detach().item()
            # Use the real size of this batch: the last batch of an epoch may
            # hold fewer samples than the configured batch size.
            train_acc += get_accuracy(logits, labels, labels.size(0))
            num_batches += 1

        # Average over the number of batches actually seen (not the last
        # index); guard against an empty loader.
        denom = max(num_batches, 1)
        wandb.log({
            'Train Loss': train_running_loss / denom,
            'Train Accuracy': train_acc / denom,
        })

    return model, train_running_loss, train_acc

In [None]:
# Attach W&B tracking to the model that is trained in this run (model_wnb),
# not to the Part 1 `model`, which receives no further updates here.
wandb.watch(model_wnb, log="all")

# Train the W&B model
model_wnb, _, _ = train_model_wnb(model_wnb, config, trainloader, criterion, optimizer)

# Saving and reloading trained models

### Save the model state as an artefact

In [None]:
# Write the trained W&B model's state_dict to a local file.
torch.save(obj=model_wnb.state_dict(), f="model_weights.pth")

# Wrap the weights file in an artifact named 'model' and upload it.
artifact = wandb.Artifact(name='model', type='model')
artifact.add_file(local_path="model_weights.pth")
wandb.log_artifact(artifact)

# Close the run.
wandb.finish()

### Load the model state into new model

In [None]:
# Fill these in before running the cell:
#   entity        - your username or the name of your team
#   project_name  - your project name
#   model_version - the version of the model you want to download
entity = ''
project_name = ''
model_version = ''

# Open a new run, declare the artifact as an input of that run, and fetch its
# files; `artifact_dir` is the value returned by the download call.
run = wandb.init()
artifact = run.use_artifact(
    f'{entity}/{project_name}/{model_version}',
    type='model',
)
artifact_dir = artifact.download()

In [21]:
# Locate the weights file inside the downloaded artifact directory.
# The artifact is expected to contain one .pth file; if it contains several,
# the alphabetically first one is used.
model_files = os.listdir(artifact_dir)
# os.listdir returns entries in arbitrary order, so sort to make the choice
# reproducible.
pth_files = sorted(f for f in model_files if f.endswith('.pth'))
if not pth_files:
    # Fail with a message that names the directory instead of a bare IndexError.
    raise FileNotFoundError(f"No .pth file found in artifact directory: {artifact_dir}")
model_file = pth_files[0]
model_path = os.path.join(artifact_dir, model_file)

In [None]:
# Re-instantiate the model on the active device and load the saved weights.
model_wnb = Model(img_size, 128, 128, class_out).to(device)
# map_location remaps the stored tensors onto `device`, so weights saved on a
# GPU can also be loaded on a machine without one.
model_wnb.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
# Evaluate the reloaded W&B model (not the Part 1 `model`). `.to(device)`
# returns the module itself and ensures it is on the device that eval_model
# moves the inputs to.
eval_model(model_wnb.to(device), testloader)
wandb.finish()