# How to save all your trained model weights locally after every epoch
> This notebook provides working code for a Checkpoint Saver for the report - [How to save all your trained model weights locally after every epoch](https://wandb.ai/amanarora/melanoma/reports/How-to-save-all-your-trained-model-weights-locally-after-every-epoch--VmlldzoxNTkzNjY1).

## Download the Imagenette dataset
> Uncomment the first time when you are running this notebook.

In [1]:
# get dataset
# !mkdir data && cd data 
# !wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz
# !tar -xvf imagenette2-160.tgz

## Imports

In [2]:
import os
import torch
import torchvision
import timm
import torch.nn as nn
from tqdm.notebook import tqdm
import albumentations
from torchvision import transforms
import numpy as np 
import os
import wandb 

# set logging
import logging
logging.getLogger().setLevel(logging.INFO)

## Config

In [3]:
IMG_SIZE = 160 
MODEL_NAME = "resnet34"
LR = 1e-4
EPOCHS = 5

In [4]:
train_aug = transforms.Compose(
        [
            transforms.RandomCrop(IMG_SIZE),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
)

In [5]:
val_aug = transforms.Compose(
    [
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

Below, we initialize the Weights and Biases experiment by passing in the config values. 

In [6]:
run = wandb.init(project="melanoma-artifact", config={
    'image size': IMG_SIZE, 
    'model name': MODEL_NAME, 
    'learning rate': LR, 
    'epochs': EPOCHS, 
    'training augmentation': train_aug, 
    'valid augmentation': val_aug
})

[34m[1mwandb[0m: Currently logged in as: [33mamanarora[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [7]:
train_aug, val_aug

(Compose(
     RandomCrop(size=(160, 160), padding=None)
     RandomHorizontalFlip(p=0.5)
     ToTensor()
     Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
 ),
 Compose(
     CenterCrop(size=(160, 160))
     ToTensor()
     Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
 ))

## Checkpoint Saver with W&B artifacts integration
> For a complete explanation of how the below code works, please refer to the report - [How to save all your trained model weights locally after every epoch](https://wandb.ai/amanarora/melanoma/reports/How-to-save-all-your-trained-model-weights-locally-after-every-epoch--VmlldzoxNTkzNjY1).

In [8]:
class CheckpointSaver:
    def __init__(self, dirpath, decreasing=True, top_n=5):
        """
        dirpath: Directory path where to store all model weights 
        decreasing: If decreasing is `True`, then lower metric is better
        top_n: Total number of models to track based on validation metric value
        """
        if not os.path.exists(dirpath): os.makedirs(dirpath)
        self.dirpath = dirpath
        self.top_n = top_n 
        self.decreasing = decreasing
        self.top_model_paths = []
        self.best_metric_val = np.Inf if decreasing else -np.Inf
        
    def __call__(self, model, epoch, metric_val):
        model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')
        save = metric_val<self.best_metric_val if self.decreasing else metric_val>self.best_metric_val
        if save: 
            logging.info(f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
            self.best_metric_val = metric_val
            torch.save(model.state_dict(), model_path)
            self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
            self.top_model_paths.append({'path': model_path, 'score': metric_val})
            self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
        if len(self.top_model_paths)>self.top_n: 
            self.cleanup()
    
    def log_artifact(self, filename, model_path, metric_val):
        artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
        artifact.add_file(model_path)
        wandb.run.log_artifact(artifact)        
    
    def cleanup(self):
        to_remove = self.top_model_paths[self.top_n:]
        logging.info(f"Removing extra models.. {to_remove}")
        for o in to_remove:
            os.remove(o['path'])
        self.top_model_paths = self.top_model_paths[:self.top_n]

## Model training

In [9]:
def train_fn(model, train_data_loader, optimizer, epoch, device='cuda'):
    model.train()
    fin_loss = 0.0
    tk = tqdm(train_data_loader, desc="Epoch" + " [TRAIN] " + str(epoch + 1))

    for t, data in enumerate(tk):
        data[0] = data[0].to(device)
        data[1] = data[1].to(device)

        optimizer.zero_grad()
        out = model(data[0])
        loss = nn.CrossEntropyLoss()(out, data[1])
        loss.backward()
        optimizer.step()

        fin_loss += loss.item()
        tk.set_postfix(
            {
                "loss": "%.6f" % float(fin_loss / (t + 1)),
                "LR": optimizer.param_groups[0]["lr"],
            }
        )
    return fin_loss / len(train_data_loader), optimizer.param_groups[0]["lr"]

In [10]:
def eval_fn(model, eval_data_loader, epoch, device='cuda'):
    model.eval()
    fin_loss = 0.0
    tk = tqdm(eval_data_loader, desc="Epoch" + " [VALID] " + str(epoch + 1))

    with torch.no_grad():
        for t, data in enumerate(tk):
            data[0] = data[0].to(device)
            data[1] = data[1].to(device)
            out = model(data[0])
            loss = nn.CrossEntropyLoss()(out, data[1])
            fin_loss += loss.item()
            tk.set_postfix({"loss": "%.6f" % float(fin_loss / (t + 1))})
        return fin_loss / len(eval_data_loader)

In [11]:
def train(train_dir, test_dir):
    train_dataset = torchvision.datasets.ImageFolder(
        train_dir, transform=train_aug
    )
    eval_dataset = torchvision.datasets.ImageFolder(
        test_dir, transform=val_aug
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=128,
        shuffle=True,
        num_workers=4
    )
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=64, num_workers=4
    )

    # model
    model = timm.create_model(MODEL_NAME, pretrained=True)
    model = model.cuda()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # checkpoint saver
    checkpoint_saver = CheckpointSaver(dirpath='./model_weights', decreasing=True, top_n=5)
    for epoch in range(EPOCHS):
        avg_loss_train, lr = train_fn(
            model, train_dataloader, optimizer, epoch, device='cuda'
        )
        avg_loss_eval = eval_fn(model, eval_dataloader, epoch, device='cuda')
        checkpoint_saver(model, epoch, avg_loss_eval)
        wandb.run.log({'epoch': epoch, 'train loss': avg_loss_train, 'eval loss': avg_loss_eval})
        print(
            f"EPOCH = {epoch} | TRAIN_LOSS = {avg_loss_train} | EVAL_LOSS = {avg_loss_eval}"
        )

In [12]:
train(train_dir='./data/imagenette2-160/train/', test_dir='./data/imagenette2-160/val/')

Epoch [TRAIN] 1:   0%|          | 0/74 [00:00<?, ?it/s]

Epoch [VALID] 1:   0%|          | 0/62 [00:00<?, ?it/s]

INFO:root:Current metric value better than 0.19492560036240086 better than best inf, saving model at ./model_weights/ResNet_epoch0.pt, & logging model weights to W&B.


EPOCH = 0 | TRAIN_LOSS = 1.375664096527003 | EVAL_LOSS = 0.19492560036240086


Epoch [TRAIN] 2:   0%|          | 0/74 [00:00<?, ?it/s]

Epoch [VALID] 2:   0%|          | 0/62 [00:00<?, ?it/s]

INFO:root:Current metric value better than 0.1360785181619107 better than best 0.19492560036240086, saving model at ./model_weights/ResNet_epoch1.pt, & logging model weights to W&B.


EPOCH = 1 | TRAIN_LOSS = 0.11170286700330875 | EVAL_LOSS = 0.1360785181619107


Epoch [TRAIN] 3:   0%|          | 0/74 [00:00<?, ?it/s]

Epoch [VALID] 3:   0%|          | 0/62 [00:00<?, ?it/s]

EPOCH = 2 | TRAIN_LOSS = 0.04712816304527223 | EVAL_LOSS = 0.1487596299078676


Epoch [TRAIN] 4:   0%|          | 0/74 [00:00<?, ?it/s]

Epoch [VALID] 4:   0%|          | 0/62 [00:00<?, ?it/s]

INFO:root:Current metric value better than 0.12546487889748306 better than best 0.1360785181619107, saving model at ./model_weights/ResNet_epoch3.pt, & logging model weights to W&B.


EPOCH = 3 | TRAIN_LOSS = 0.034359857650800935 | EVAL_LOSS = 0.12546487889748306


Epoch [TRAIN] 5:   0%|          | 0/74 [00:00<?, ?it/s]

Epoch [VALID] 5:   0%|          | 0/62 [00:00<?, ?it/s]

EPOCH = 4 | TRAIN_LOSS = 0.027094046788202045 | EVAL_LOSS = 0.13780122134278738
