<a href="https://colab.research.google.com/github/Kazi-Rakib-Hasan-Jawwad/Histo-FSL/blob/master/BDCSPN_BT_n1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from PIL import Image
from torchvision import transforms
import pytorch_lightning as pl
from collections import defaultdict
from copy import deepcopy
import os
import random
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt

In [2]:
def load_encoder_weights(encoder, weights):
    model_dict = encoder.state_dict()
    weights = {k: v for k, v in weights.items() if k in model_dict}
    if weights == {}:
        print('No weight could be loaded..')
    model_dict.update(weights)
    encoder.load_state_dict(model_dict)
    return encoder

In [3]:
def print_trainable_parameters(model: torch.nn) -> None:
    """Print number of trainable parameters."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param}"
        f" || trainable%: {100 * trainable_params / all_param:.2f}"
    )

In [4]:
# Load model

encoder = torchvision.models.__dict__['resnet50'](pretrained=False)
path = "/home/rakib/models/paper_benchmarking_ssl_diverse_pathology/bt_rn50_ep200.torch"
state_dict = torch.load(path, map_location='cuda:0')
# state_dict = state['state_dict']
encoder.fc = torch.nn.Identity()
model = load_encoder_weights(encoder, state_dict)



In [62]:
# Set the base network to non-trainable (for speedup fine-tuning):

for param in model.parameters():
    param.requires_grad = False

In [6]:
# Attach a trainable linear layer to adapt to the new task:

num_classes = 9  # Number of classes in the tuning dataset
model.fc = torch.nn.Linear(2048, num_classes)
#torch.nn.init.xavier_uniform_(model.fc.weight)

In [7]:
from easyfsl.methods import BDCSPN

net = BDCSPN(model)

In [8]:
print_trainable_parameters(net)

trainable params: 18441 || all params: 23526473 || trainable%: 0.08


In [9]:
from easyfsl.samplers import TaskSampler
from easyfsl.datasets import FeaturesDataset, WrapFewShotDataset

In [10]:
# Load data

# Instantiate the datasets
train_path = Path("/home/rakib/data/NCT-CRC-Modified-81K/train/")
val_path = Path("/home/rakib/data/NCT-CRC-Modified-81K/val/")

train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                    transforms.RandomResizedCrop((224, 224), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
                                    transforms.ToTensor()])
val_transform = transforms.Compose([transforms.CenterCrop((224, 224)), transforms.Resize((224, 224)), transforms.ToTensor()])

train_data = torchvision.datasets.ImageFolder(train_path, transform=train_transform)
val_data = torchvision.datasets.ImageFolder(val_path, transform=val_transform)

In [11]:
n_way = 5
n_shot = 6
n_query = 10

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_workers = 12
n_tasks_per_epoch = 100
n_validation_tasks = 100

In [12]:
train_set = WrapFewShotDataset(train_data)
val_set = WrapFewShotDataset(val_data)

# Those are special batch samplers that sample few-shot classification tasks with a pre-defined shape
train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

# Finally, the DataLoader. We customize the collate_fn so that batches are delivered
# in the shape: (support_images, support_labels, query_images, query_labels, class_ids)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)


Scrolling dataset's labels...: 100%|█████| 63000/63000 [02:20<00:00, 449.28it/s]
Scrolling dataset's labels...: 100%|█████| 15840/15840 [00:19<00:00, 799.05it/s]


In [13]:
net.to(DEVICE)

BDCSPN(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          

In [20]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 50
#scheduler_milestones = [120, 160]
#scheduler_gamma = 0.1
#learning_rate = 1e-2
tb_logs_dir = Path(".")
'''
train_optimizer = SGD(
    net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))
'''
train_optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=train_optimizer, T_max=20)

tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

In [17]:
import copy
from easyfsl.utils import evaluate
from tqdm import tqdm
from statistics import mean
from easyfsl.methods import FewShotClassifier

In [21]:
def training_epoch(
    model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer
):
    all_loss = []
    model.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

In [22]:
best_state = net.state_dict()
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(net, train_loader, train_optimizer)
    validation_accuracy = evaluate(
        net, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = copy.deepcopy(net.state_dict())
        # state_dict() returns a reference to the still evolving model's state so we deepcopy
        # https://pytorch.org/tutorials/beginner/saving_loading_models
        print("Ding ding ding! We found a new best model!")

    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler.step()


Epoch 0


Training: 100%|████████████████████| 100/100 [00:08<00:00, 12.16it/s, loss=0.93]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.51it/s, accuracy=0.956]

Ding ding ding! We found a new best model!
Epoch 1



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.31it/s, loss=0.892]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.65it/s, accuracy=0.966]


Ding ding ding! We found a new best model!
Epoch 2


Training: 100%|███████████████████| 100/100 [00:07<00:00, 12.54it/s, loss=0.888]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 14.11it/s, accuracy=0.967]


Ding ding ding! We found a new best model!
Epoch 3


Training: 100%|███████████████████| 100/100 [00:07<00:00, 12.80it/s, loss=0.867]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.75it/s, accuracy=0.974]

Ding ding ding! We found a new best model!
Epoch 4



Training: 100%|███████████████████| 100/100 [00:07<00:00, 12.88it/s, loss=0.868]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.00it/s, accuracy=0.977]

Ding ding ding! We found a new best model!
Epoch 5



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.25it/s, loss=0.866]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 14.19it/s, accuracy=0.978]


Ding ding ding! We found a new best model!
Epoch 6


Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.62it/s, loss=0.857]
Validation: 100%|██████████████| 100/100 [00:07<00:00, 13.18it/s, accuracy=0.98]


Ding ding ding! We found a new best model!
Epoch 7


Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.25it/s, loss=0.857]
Validation: 100%|██████████████| 100/100 [00:07<00:00, 13.18it/s, accuracy=0.98]

Epoch 8



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.62it/s, loss=0.862]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.07it/s, accuracy=0.983]

Ding ding ding! We found a new best model!
Epoch 9



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.85it/s, loss=0.857]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.94it/s, accuracy=0.981]

Epoch 10



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.89it/s, loss=0.851]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.19it/s, accuracy=0.979]

Epoch 11



Training: 100%|████████████████████| 100/100 [00:08<00:00, 11.97it/s, loss=0.85]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.96it/s, accuracy=0.978]

Epoch 12



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.87it/s, loss=0.848]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.31it/s, accuracy=0.983]

Ding ding ding! We found a new best model!
Epoch 13



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.94it/s, loss=0.846]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.18it/s, accuracy=0.984]

Ding ding ding! We found a new best model!
Epoch 14



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.72it/s, loss=0.849]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.05it/s, accuracy=0.981]

Epoch 15



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.43it/s, loss=0.844]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.16it/s, accuracy=0.983]

Epoch 16



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.38it/s, loss=0.846]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.38it/s, accuracy=0.982]

Epoch 17



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.15it/s, loss=0.847]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.13it/s, accuracy=0.984]

Epoch 18



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.87it/s, loss=0.844]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.08it/s, accuracy=0.983]

Epoch 19



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.86it/s, loss=0.845]
Validation: 100%|██████████████| 100/100 [00:07<00:00, 12.97it/s, accuracy=0.98]

Epoch 20



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.98it/s, loss=0.844]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.70it/s, accuracy=0.981]

Epoch 21



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.74it/s, loss=0.847]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.13it/s, accuracy=0.982]

Epoch 22



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.99it/s, loss=0.842]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.91it/s, accuracy=0.977]

Epoch 23



Training: 100%|████████████████████| 100/100 [00:08<00:00, 11.90it/s, loss=0.85]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.92it/s, accuracy=0.984]

Ding ding ding! We found a new best model!
Epoch 24



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.84it/s, loss=0.844]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.92it/s, accuracy=0.984]

Epoch 25



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.64it/s, loss=0.842]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.23it/s, accuracy=0.987]

Ding ding ding! We found a new best model!
Epoch 26



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.76it/s, loss=0.845]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.16it/s, accuracy=0.981]

Epoch 27



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.83it/s, loss=0.848]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.95it/s, accuracy=0.985]

Epoch 28



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.69it/s, loss=0.841]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.08it/s, accuracy=0.984]

Epoch 29



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.04it/s, loss=0.844]
Validation: 100%|██████████████| 100/100 [00:07<00:00, 13.15it/s, accuracy=0.98]

Epoch 30



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.13it/s, loss=0.845]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.17it/s, accuracy=0.984]

Epoch 31



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.11it/s, loss=0.842]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.23it/s, accuracy=0.986]

Epoch 32



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.07it/s, loss=0.843]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.13it/s, accuracy=0.983]

Epoch 33



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.13it/s, loss=0.844]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.24it/s, accuracy=0.981]

Epoch 34



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.84it/s, loss=0.843]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.07it/s, accuracy=0.979]

Epoch 35



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.99it/s, loss=0.845]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.02it/s, accuracy=0.981]

Epoch 36



Training: 100%|████████████████████| 100/100 [00:08<00:00, 11.79it/s, loss=0.84]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.28it/s, accuracy=0.984]

Epoch 37



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.90it/s, loss=0.841]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.14it/s, accuracy=0.986]

Epoch 38



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.19it/s, loss=0.841]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.09it/s, accuracy=0.987]

Epoch 39



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.91it/s, loss=0.842]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.03it/s, accuracy=0.985]

Epoch 40



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.98it/s, loss=0.839]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.15it/s, accuracy=0.984]

Epoch 41



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.18it/s, loss=0.836]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.06it/s, accuracy=0.982]

Epoch 42



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.23it/s, loss=0.838]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.08it/s, accuracy=0.984]

Epoch 43



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.08it/s, loss=0.837]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.07it/s, accuracy=0.987]

Ding ding ding! We found a new best model!
Epoch 44



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.95it/s, loss=0.836]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.08it/s, accuracy=0.988]


Ding ding ding! We found a new best model!
Epoch 45


Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.99it/s, loss=0.833]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.04it/s, accuracy=0.985]

Epoch 46



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.19it/s, loss=0.838]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.10it/s, accuracy=0.986]

Epoch 47



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.92it/s, loss=0.834]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.25it/s, accuracy=0.986]

Epoch 48



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.12it/s, loss=0.835]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.05it/s, accuracy=0.986]

Epoch 49



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.90it/s, loss=0.831]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.14it/s, accuracy=0.982]


In [23]:
net.load_state_dict(best_state)

<All keys matched successfully>

In [24]:
# Instantiate the datasets
test_path = Path("/home/rakib/data/CRC-VAL-HE-7K/")

test_data = torchvision.datasets.ImageFolder(test_path, transform=val_transform)

test_set = WrapFewShotDataset(test_data)

Scrolling dataset's labels...: 100%|███████| 7180/7180 [00:12<00:00, 554.96it/s]


In [39]:
n_test_tasks = 5000

In [40]:
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

In [41]:
accuracy = evaluate(net, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

100%|███████████████████████| 5000/5000 [05:35<00:00, 14.91it/s, accuracy=0.967]

Average accuracy : 96.66 %





In [45]:
model2 = load_encoder_weights(encoder, state_dict)

In [59]:
dataloader = DataLoader(test_data, batch_size=1, shuffle=False, drop_last=False)
data_labels = np.zeros(shape=(0))
embeddings = np.zeros(shape=(0, 2048))
for x, y in iter(dataloader):
    x = x.to(DEVICE)
    pred = net(x)


In [60]:
# prompt: find shape of pred

print(pred.shape)


torch.Size([1, 5])


In [61]:
print(pred)

tensor([[-0.1396, -0.1177,  0.0492, -0.3430,  0.9620]], device='cuda:0',
       grad_fn=<MmBackward0>)


In [63]:
# Set the base network to non-trainable (for speedup fine-tuning):

for param in model2.parameters():
    param.requires_grad = False

In [64]:
num_classes = 9  # Number of classes in the tuning dataset
model2.fc = torch.nn.Linear(2048, num_classes)

In [65]:
from easyfsl.methods import SimpleShot
net2 = SimpleShot(model2)

In [67]:
net2.to(DEVICE)

SimpleShot(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
      

In [71]:
train_optimizer2 = torch.optim.AdamW(model.parameters(), lr=0.001)
train_scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=train_optimizer, T_max=20)

best_state2 = net2.state_dict()
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(net2, train_loader, train_optimizer2)
    validation_accuracy = evaluate(
        net2, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state2 = copy.deepcopy(net2.state_dict())
        # state_dict() returns a reference to the still evolving model's state so we deepcopy
        # https://pytorch.org/tutorials/beginner/saving_loading_models
        print("Ding ding ding! We found a new best model!")

    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    # Warn the scheduler that we did an epoch
    # so it knows when to decrease the learning rate
    train_scheduler2.step()


Epoch 0


Training: 100%|████████████████████| 100/100 [00:08<00:00, 11.65it/s, loss=1.08]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.74it/s, accuracy=0.952]

Ding ding ding! We found a new best model!
Epoch 1



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.76it/s, loss=0.889]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.65it/s, accuracy=0.963]

Ding ding ding! We found a new best model!
Epoch 2



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.44it/s, loss=0.868]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.36it/s, accuracy=0.978]

Ding ding ding! We found a new best model!
Epoch 3



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.21it/s, loss=0.863]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.43it/s, accuracy=0.978]

Ding ding ding! We found a new best model!
Epoch 4



Training: 100%|████████████████████| 100/100 [00:08<00:00, 11.47it/s, loss=0.86]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.41it/s, accuracy=0.979]

Ding ding ding! We found a new best model!
Epoch 5



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.78it/s, loss=0.855]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.49it/s, accuracy=0.981]


Ding ding ding! We found a new best model!
Epoch 6


Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.53it/s, loss=0.857]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.29it/s, accuracy=0.979]

Epoch 7



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.25it/s, loss=0.856]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.39it/s, accuracy=0.981]

Epoch 8



Training: 100%|███████████████████| 100/100 [00:09<00:00, 10.03it/s, loss=0.854]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.01it/s, accuracy=0.984]

Ding ding ding! We found a new best model!
Epoch 9



Training: 100%|████████████████████| 100/100 [00:11<00:00,  8.51it/s, loss=0.85]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.06it/s, accuracy=0.985]

Ding ding ding! We found a new best model!
Epoch 10



Training: 100%|███████████████████| 100/100 [00:11<00:00,  8.78it/s, loss=0.845]
Validation: 100%|██████████████| 100/100 [00:07<00:00, 12.56it/s, accuracy=0.98]

Epoch 11



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.35it/s, loss=0.851]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.23it/s, accuracy=0.985]

Ding ding ding! We found a new best model!
Epoch 12



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.47it/s, loss=0.846]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.42it/s, accuracy=0.985]


Epoch 13


Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.62it/s, loss=0.843]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.53it/s, accuracy=0.984]

Epoch 14



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.50it/s, loss=0.843]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.48it/s, accuracy=0.984]

Epoch 15



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.81it/s, loss=0.842]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.36it/s, accuracy=0.986]

Ding ding ding! We found a new best model!
Epoch 16



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.53it/s, loss=0.846]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.33it/s, accuracy=0.988]

Ding ding ding! We found a new best model!
Epoch 17



Training: 100%|████████████████████| 100/100 [00:08<00:00, 11.82it/s, loss=0.84]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.32it/s, accuracy=0.988]

Epoch 18



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.59it/s, loss=0.838]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.48it/s, accuracy=0.983]

Epoch 19



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.56it/s, loss=0.847]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.42it/s, accuracy=0.984]

Epoch 20



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.38it/s, loss=0.836]
Validation: 100%|██████████████| 100/100 [00:08<00:00, 12.28it/s, accuracy=0.98]

Epoch 21



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.54it/s, loss=0.843]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.30it/s, accuracy=0.983]


Epoch 22


Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.30it/s, loss=0.839]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.41it/s, accuracy=0.982]

Epoch 23



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.33it/s, loss=0.839]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.31it/s, accuracy=0.992]

Ding ding ding! We found a new best model!
Epoch 24



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.71it/s, loss=0.844]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.21it/s, accuracy=0.986]

Epoch 25



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.58it/s, loss=0.839]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.09it/s, accuracy=0.989]

Epoch 26



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.41it/s, loss=0.841]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.42it/s, accuracy=0.984]

Epoch 27



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.39it/s, loss=0.834]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.59it/s, accuracy=0.986]

Epoch 28



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.47it/s, loss=0.837]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.93it/s, accuracy=0.989]

Epoch 29



Training: 100%|████████████████████| 100/100 [00:08<00:00, 11.94it/s, loss=0.83]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.48it/s, accuracy=0.987]

Epoch 30



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.03it/s, loss=0.834]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.23it/s, accuracy=0.986]

Epoch 31



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.23it/s, loss=0.834]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.04it/s, accuracy=0.984]

Epoch 32



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.31it/s, loss=0.836]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.26it/s, accuracy=0.989]

Epoch 33



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.03it/s, loss=0.832]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.35it/s, accuracy=0.985]

Epoch 34



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.21it/s, loss=0.831]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.27it/s, accuracy=0.985]

Epoch 35



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.46it/s, loss=0.837]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.99it/s, accuracy=0.987]

Epoch 36



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.27it/s, loss=0.836]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.06it/s, accuracy=0.989]

Epoch 37



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.37it/s, loss=0.834]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.61it/s, accuracy=0.987]

Epoch 38



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.75it/s, loss=0.832]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.06it/s, accuracy=0.986]

Epoch 39



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.90it/s, loss=0.836]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.76it/s, accuracy=0.986]

Epoch 40



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.39it/s, loss=0.835]
Validation: 100%|██████████████| 100/100 [00:07<00:00, 12.85it/s, accuracy=0.99]

Epoch 41



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.14it/s, loss=0.827]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.02it/s, accuracy=0.986]

Epoch 42



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.20it/s, loss=0.837]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.27it/s, accuracy=0.988]

Epoch 43



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.88it/s, loss=0.835]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.92it/s, accuracy=0.987]

Epoch 44



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.19it/s, loss=0.833]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.94it/s, accuracy=0.988]

Epoch 45



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.42it/s, loss=0.833]
Validation: 100%|█████████████| 100/100 [00:08<00:00, 12.41it/s, accuracy=0.988]

Epoch 46



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.69it/s, loss=0.834]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.56it/s, accuracy=0.989]

Epoch 47



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.14it/s, loss=0.833]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.71it/s, accuracy=0.987]

Epoch 48



Training: 100%|███████████████████| 100/100 [00:08<00:00, 11.62it/s, loss=0.828]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 12.74it/s, accuracy=0.986]

Epoch 49



Training: 100%|███████████████████| 100/100 [00:08<00:00, 12.31it/s, loss=0.833]
Validation: 100%|█████████████| 100/100 [00:07<00:00, 13.33it/s, accuracy=0.988]


In [72]:
accuracy = evaluate(net2, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.2f} %")

100%|███████████████████████| 5000/5000 [05:28<00:00, 15.22it/s, accuracy=0.961]

Average accuracy : 96.06 %



