# Fine-Tuning with PyTorch

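Throughout this notebook, we use the CIFAR-100 dataset, a popular computer-vision dataset of 60,000 32x32 color images, each belonging to one of 100 classes (600 images per class). Starting from a ResNet-34 that was trained on CIFAR-10 and pruned to roughly 95% sparsity, we replace its 10-way classifier with a 100-way one and fine-tune the whole network on CIFAR-100.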

In [38]:
#from random import random
import os
import pandas as pd

## Datasets and DataLoaders <a name='data'></a>

We use torchvision's built-in `torchvision.datasets.CIFAR100` class to create the training dataset (the official training split) and the validation dataset (the official test split). Commented-out code for separate training+validation and test sets based on `ImageFolder` is kept for reference. From each dataset we then create a `DataLoader`, which the training/evaluation loops use to fetch images efficiently in batches.

Normalization uses the mean and standard deviation of the training data for each channel (R, G, B), computed across all images and all pixels. To avoid loading the entire dataset into memory, a mean and a standard deviation are computed batch by batch. The final values are the mean of the per-batch means and the mean of the per-batch standard deviations, and they are hard-coded further below.  
**NOTE**: averaging per-batch standard deviations only approximates the standard deviation of the whole dataset. If you have enough RAM (or GPU memory), you can use a batch size equal to the length of `train_dataset`, which gives exact per-channel means and standard deviations.

In [39]:
#!pip install -q --upgrade torchvision

In [40]:
import torchvision
import torch

In [41]:
import sklearn
import copy

In [42]:
#import os
import time
import math
import random
import numpy as np
import pandas as pd
from pathlib import Path
import glob

import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance, ImageOps

from tqdm import tqdm, tqdm_notebook

import torch
from torch import nn, cuda
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision as vision
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD, Optimizer
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR, ReduceLROnPlateau

from sklearn.metrics import f1_score

class CIFAR10Policy(object):
    """AutoAugment policy for CIFAR-10.

    Each call picks one of 25 sub-policies uniformly at random and applies it.
    Every sub-policy is a ``SubPolicy`` built from the spec
    ``(p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2)``.

    Example:
        >>> policy = CIFAR10Policy()
        >>> transformed = policy(image)

    Example as a PyTorch Transform:
        >>> transform = transforms.Compose([
        ...     transforms.Resize(256),
        ...     CIFAR10Policy(),
        ...     transforms.ToTensor()])

    Args:
        fillcolor: forwarded unchanged to every ``SubPolicy``.
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # (p1, op1, magnitude_idx1, p2, op2, magnitude_idx2), in the published order.
        specs = (
            (0.1, "invert", 7, 0.2, "contrast", 6),
            (0.7, "rotate", 2, 0.3, "translateX", 9),
            (0.8, "sharpness", 1, 0.9, "sharpness", 3),
            (0.5, "shearY", 8, 0.7, "translateY", 9),
            (0.5, "autocontrast", 8, 0.9, "equalize", 2),

            (0.2, "shearY", 7, 0.3, "posterize", 7),
            (0.4, "color", 3, 0.6, "brightness", 7),
            (0.3, "sharpness", 9, 0.7, "brightness", 9),
            (0.6, "equalize", 5, 0.5, "equalize", 1),
            (0.6, "contrast", 7, 0.6, "sharpness", 5),

            (0.7, "color", 7, 0.5, "translateX", 8),
            (0.3, "equalize", 7, 0.4, "autocontrast", 8),
            (0.4, "translateY", 3, 0.2, "sharpness", 6),
            (0.9, "brightness", 6, 0.2, "color", 8),
            (0.5, "solarize", 2, 0.0, "invert", 3),

            (0.2, "equalize", 0, 0.6, "autocontrast", 0),
            (0.2, "equalize", 8, 0.8, "equalize", 4),
            (0.9, "color", 9, 0.6, "equalize", 6),
            (0.8, "autocontrast", 4, 0.2, "solarize", 8),
            (0.1, "brightness", 3, 0.7, "color", 0),

            (0.4, "solarize", 5, 0.9, "autocontrast", 3),
            (0.9, "translateY", 9, 0.7, "translateY", 9),
            (0.9, "autocontrast", 2, 0.8, "solarize", 3),
            (0.8, "equalize", 8, 0.1, "invert", 3),
            (0.7, "translateY", 9, 0.9, "autocontrast", 1),
        )
        self.policies = [SubPolicy(*spec, fillcolor) for spec in specs]

    def __call__(self, img):
        # randrange(n) draws exactly like randint(0, n - 1).
        index = random.randrange(len(self.policies))
        chosen = self.policies[index]
        return chosen(img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SubPolicy(object):
    """One AutoAugment sub-policy: two image operations, each applied with its own probability.

    Args:
        p1, p2: probability of applying the first / second operation.
        operation1, operation2: operation names, i.e. keys of the ``ranges``
            and ``func`` tables below ("shearX", "rotate", "color", ...).
        magnitude_idx1, magnitude_idx2: index (0-9) into that operation's
            10-step magnitude range.
        fillcolor: colour (RGB tuple, or a single int) for pixels uncovered by
            shear, translate and rotate.

    Raises:
        KeyError: if an operation name is not one of the supported operations.
    """
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        # RGB part of `fillcolor`, used for the RGBA background in rotate_with_fill
        # so rotation fills with the same colour as the affine operations.
        if isinstance(fillcolor, (tuple, list)):
            fill_rgb = tuple(fillcolor)[:3]
        else:
            fill_rgb = (fillcolor,) * 3

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotating in RGBA leaves the exposed corners with alpha 0; using the
            # rotated image as its own mask composites those corners over a
            # solid background of the fill colour.
            rot = img.convert("RGBA").rotate(magnitude)
            background = Image.new("RGBA", rot.size, fill_rgb + (128,))
            return Image.composite(rot, background, rot).convert(img.mode)

        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            # Rotate by +/- magnitude degrees, like the other signed operations
            # (AutoAugment's rotate range is symmetric around 0).
            "rotate": lambda img, magnitude: rotate_with_fill(
                img, magnitude * random.choice([-1, 1])),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]


    def __call__(self, img):
        """Apply operation 1 with probability p1, then operation 2 with probability p2."""
        if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
        return img


class TestDataset(Dataset):
    """Unlabelled image dataset that loads the files listed in `df` for inference.

    Args:
        df: indexable collection of image file names; ``df[idx]`` must give a
            path relative to the image directory (presumably a list or pandas
            Series of file names - TODO confirm against the caller).
        mode: key selecting which entry of ``transforms`` to use.
        transforms: mapping from mode name to a callable transform, or None
            to return the raw PIL image.
        image_dir: directory holding the images; when None, the global
            ``TEST_IMAGE_PATH`` is used (looked up at load time).
    """

    def __init__(self, df, mode='test', transforms=None, image_dir=None):
        self.df = df
        self.mode = mode
        # `transforms` is optional: indexing the default None would raise TypeError.
        self.transform = transforms[self.mode] if transforms is not None else None
        self.image_dir = image_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_dir = self.image_dir if self.image_dir is not None else TEST_IMAGE_PATH
        # Close the file handle once the pixels have been copied by convert().
        with Image.open(Path(image_dir) / self.df[idx]) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image

In [43]:
# Per-channel (R, G, B) normalisation statistics of the training data, hard-coded
# so the dataset does not have to be re-scanned on every run.
# NOTE(review): presumably produced by the batch-wise mean/stdev procedure
# described in the markdown - record the exact computation and batch size.
mean = torch.tensor([0.5070, 0.4865, 0.4408])
stdev = torch.tensor([0.2621, 0.2512, 0.2713])

In [44]:
# Training pipeline: upscale the 32x32 images to 224x224, apply the CIFAR-10
# AutoAugment policy (random augmentation, training only), then convert to a
# tensor and normalise with the per-channel `mean` / `stdev`.
train_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224,224)),
        CIFAR10Policy(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, stdev)
    ])

#train_dataset, train_valid_dataset = [torchvision.datasets.ImageFolder(folder, transform=train_transforms) for folder in [root/'train', root/'train_valid']]
# CIFAR-100 training split, stored under ./data and downloaded only if missing.
train_dataset = torchvision.datasets.CIFAR100(root="data",
                                             train=True,
                                             download=True,
                                             transform=train_transforms)

# Evaluation pipeline: same resize + normalisation, no augmentation.
valid_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224,224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, stdev)
    ])

#valid_dataset, test_dataset = [torchvision.datasets.ImageFolder(folder, transform=valid_transforms) for folder in [root/'valid', root/'test']]

# The official CIFAR-100 test split (train=False) serves as the validation set;
# no separate held-out test set is built in this notebook.
valid_dataset = torchvision.datasets.CIFAR100(root="data",
                                            train=False,
                                            download=True,
                                            transform=valid_transforms)

Files already downloaded and verified
Files already downloaded and verified


In [45]:
# Two DataLoader worker processes per visible GPU. On a CPU-only machine this
# is 0, so batches are loaded in the main process.
num_gpus = torch.cuda.device_count()

# Training batches are reshuffled every epoch; pin_memory puts batches in
# page-locked host memory to speed up copies to the GPU.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2*num_gpus, pin_memory=True)
#train_valid_dataloader = torch.utils.data.DataLoader(train_valid_dataset, batch_size=128, shuffle=True, num_workers=2*num_gpus, pin_memory=True)

# Validation keeps a fixed order (shuffle=False).
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers=2*num_gpus, pin_memory=True)
#test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2*num_gpus, pin_memory=True)

In [46]:
def get_net(num_classes=100):
    """Build a ResNet-34 with ImageNet-pretrained weights and a new classifier head.

    Args:
        num_classes: output size of the replacement ``fc`` layer
            (default 100, matching CIFAR-100).

    Returns:
        A torchvision ResNet-34 whose ``fc`` is a fresh
        ``Linear(fc.in_features, num_classes)`` with Xavier-uniform weights
        (the bias keeps PyTorch's default initialisation).
    """
    # torchvision >= 0.13 replaced `pretrained=True` with the `weights=` enum
    # (the old flag only warns there); fall back to the flag on older versions.
    try:
        weights = torchvision.models.ResNet34_Weights.DEFAULT
    except AttributeError:
        resnet = torchvision.models.resnet34(pretrained=True)
    else:
        resnet = torchvision.models.resnet34(weights=weights)

    # Substitute the FC output layer
    resnet.fc = torch.nn.Linear(resnet.fc.in_features, num_classes)
    torch.nn.init.xavier_uniform_(resnet.fc.weight)
    return resnet

In [47]:
import time

def _train_one_epoch(net, dataloader, criterion, optimizer, device):
    """One optimisation pass over `dataloader`; returns (summed per-sample loss, #correct) as tensors."""
    net.train()
    total_loss = torch.tensor(0., device=device)
    total_correct = torch.tensor(0., device=device)
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        preds = net(X)
        loss = criterion(preds, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # `criterion` averages over the batch; weight by the batch's real size
            # (the last batch is usually smaller than dataloader.batch_size).
            total_loss += loss * y.size(0)
            total_correct += (torch.argmax(preds, dim=1) == y).sum()
    return total_loss, total_correct


def _evaluate(net, dataloader, criterion, device):
    """Gradient-free pass over `dataloader`; returns (summed per-sample loss, #correct) as tensors."""
    net.eval()
    total_loss = torch.tensor(0., device=device)
    total_correct = torch.tensor(0., device=device)
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            preds = net(X)
            loss = criterion(preds, y)

            total_loss += loss * y.size(0)
            total_correct += (torch.argmax(preds, dim=1) == y).sum()
    return total_loss, total_correct


def train(net, train_dataloader, valid_dataloader, criterion, optimizer, scheduler=None, epochs=10, device='cpu', checkpoint_epochs=10):
    """Train `net` for `epochs` epochs, optionally validating and checkpointing.

    Args:
        net: model to train; updated in place and also returned. It is not
            moved here, so it must already be on `device`.
        train_dataloader: yields ``(X, y)`` batches used for optimisation.
        valid_dataloader: yields ``(X, y)`` batches evaluated after every epoch,
            or None to skip validation.
        criterion: loss ``criterion(preds, y)``; assumed to average over the
            batch (the ``torch.nn.CrossEntropyLoss`` default). It is re-weighted
            by each batch's size, so the reported loss is a per-sample mean.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: optional LR scheduler, stepped once per epoch with no
            arguments (so ``ReduceLROnPlateau`` is not supported).
        epochs: number of epochs to run.
        device: device the batches are moved to.
        checkpoint_epochs: every this many epochs, save
            ``./checkpoint_epoch{epoch}.pth.tar`` (epoch, model and optimizer state).

    Returns:
        The trained ``net``.
    """
    start = time.time()
    print(f'Training for {epochs} epochs on {device}')

    for epoch in range(1, epochs + 1):
        epoch_start = time.time()
        print(f"Epoch {epoch}/{epochs}")

        train_loss, train_accuracy = _train_one_epoch(net, train_dataloader, criterion, optimizer, device)

        if valid_dataloader is not None:
            valid_loss, valid_accuracy = _evaluate(net, valid_dataloader, criterion, device)

        if scheduler is not None:
            scheduler.step()

        num_train = len(train_dataloader.dataset)
        print(f'Training loss: {train_loss/num_train:.2f}')
        print(f'Training accuracy: {100*train_accuracy/num_train:.2f}')

        if valid_dataloader is not None:
            num_valid = len(valid_dataloader.dataset)
            print(f'Valid loss: {valid_loss/num_valid:.2f}')
            print(f'Valid accuracy: {100*valid_accuracy/num_valid:.2f}')

        if epoch % checkpoint_epochs == 0:
            torch.save({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, f'./checkpoint_epoch{epoch}.pth.tar')
        elapsed = time.time()
        print(f'Epoch time: {elapsed-epoch_start:.1f} Total training time: {elapsed-start:.1f}')
        print()

    end = time.time()
    print(f'Total training time: {end-start:.1f} seconds')
    return net

In [61]:
def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):
    """Count zero entries in a module's weight and/or bias tensors.

    Args:
        module: ``torch.nn.Module`` to inspect. ``named_parameters`` /
            ``named_buffers`` are used with their default ``recurse=True``,
            so tensors of child modules are included too.
        weight: count tensors whose name contains "weight" (or "weight_mask"
            when ``use_mask`` is True).
        bias: count tensors whose name contains "bias" (or "bias_mask").
        use_mask: if True, read the pruning masks that ``torch.nn.utils.prune``
            stores as buffers; otherwise read the parameters. For a pruned
            module the parameter is ``weight_orig`` (it matches "weight"),
            which is not zeroed by the mask.

    Returns:
        ``(num_zeros, num_elements, sparsity)``. ``sparsity`` is 0.0 when no
        matching tensor was found, instead of raising ``ZeroDivisionError``.
    """
    if use_mask:
        named_tensors = module.named_buffers()
        weight_key, bias_key = "weight_mask", "bias_mask"
    else:
        named_tensors = module.named_parameters()
        weight_key, bias_key = "weight", "bias"

    num_zeros = 0
    num_elements = 0
    for name, tensor in named_tensors:
        # Independent checks (not elif): each flag is applied on its own.
        if weight and weight_key in name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()
        if bias and bias_key in name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()

    sparsity = num_zeros / num_elements if num_elements else 0.0

    return num_zeros, num_elements, sparsity

In [62]:
def measure_global_sparsity(model,
                            weight=True,
                            bias=False,
                            conv2d_use_mask=False,
                            linear_use_mask=False):
    """Aggregate zero counts over every Conv2d and Linear layer of `model`.

    Each such layer is measured with ``measure_module_sparsity``. Conv2d
    layers read pruning masks when ``conv2d_use_mask`` is True; Linear layers
    do so when ``linear_use_mask`` is True. Other module types are skipped.

    Returns:
        ``(num_zeros, num_elements, sparsity)`` summed over those layers.
        ``sparsity`` is 0.0 when nothing was counted, instead of raising
        ``ZeroDivisionError``.
    """
    num_zeros = 0
    num_elements = 0

    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            use_mask = conv2d_use_mask
        elif isinstance(module, torch.nn.Linear):
            use_mask = linear_use_mask
        else:
            continue

        module_num_zeros, module_num_elements, _ = measure_module_sparsity(
            module, weight=weight, bias=bias, use_mask=use_mask)
        num_zeros += module_num_zeros
        num_elements += module_num_elements

    sparsity = num_zeros / num_elements if num_elements else 0.0

    return num_zeros, num_elements, sparsity

In [48]:
import torch.nn.utils.prune as prune

In [63]:
# Same architecture as the CIFAR-10 source model (ResNet-34, 10-way head),
# randomly initialised: the trained, pruned weights are loaded from a
# checkpoint further below.
net = torchvision.models.resnet34(num_classes=10, pretrained=False)



In [64]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [65]:
# Add the pruning re-parametrisation to every Conv2d weight *before* loading the
# checkpoint. torch.nn.utils.prune replaces `weight` with a `weight_orig`
# parameter plus a `weight_mask` buffer. With amount=0 nothing is pruned here;
# the point is only to create those entries so the pruned checkpoint's
# state_dict keys match. The loaded masks then restore the real sparsity.
parameters_to_prune = []
for module_name, module in net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        parameters_to_prune.append((module, "weight"))
prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=0,
            )

In [66]:
# Load the pruned CIFAR-10 ResNet-34 (per its file name: ~95% sparsity, ~96%
# accuracy, saved with pruning masks). Its keys only match because the pruning
# re-parametrisation was applied to `net` in the preceding cell.
# NOTE(review): absolute Kaggle input path - make it configurable to run elsewhere.
net.load_state_dict(torch.load("/kaggle/input/resnet34-for-cifar10/pytorch/sparsity0.95_accuracy0.96_with_masks/1/sparsity0.95_final_acc0.96 (1).pt",
                               map_location=device))

<All keys matched successfully>

In [67]:
# Replace the 10-way CIFAR-10 head with a fresh 100-way CIFAR-100 head, keeping
# the backbone's feature width instead of hard-coding it (512 for ResNet-34).
net.fc = torch.nn.Linear(net.fc.in_features, 100, bias=True)

In [68]:
torch.nn.init.xavier_uniform_(net.fc.weight)

Parameter containing:
tensor([[-0.0295, -0.0525,  0.0830,  ...,  0.0026,  0.0546,  0.0394],
        [ 0.0718,  0.0655, -0.0764,  ..., -0.0309,  0.0945, -0.0495],
        [ 0.0491,  0.0332, -0.0281,  ..., -0.0881,  0.0074,  0.0225],
        ...,
        [-0.0461, -0.0546, -0.0443,  ..., -0.0958,  0.0205,  0.0218],
        [-0.0324,  0.0385, -0.0196,  ..., -0.0204,  0.0054,  0.0113],
        [ 0.0015, -0.0329,  0.0705,  ...,  0.0321, -0.0976,  0.0513]],
       requires_grad=True)

In [69]:
measure_global_sparsity(net, conv2d_use_mask=True)

(20204266, 21318848, 0.9477184695908522)

In [70]:
#device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fine-tuning hyper-parameters: Adam learning rate, Adam weight decay, number of epochs.
lr, weight_decay, epochs = 1e-3, 5e-4, 30

#net = get_net().to(device)


criterion = torch.nn.CrossEntropyLoss()


#params_1x = [param for name, param in net.named_parameters() if 'fc' not in str(name)]
# Every parameter (pruned backbone + new fc head) is trained with the same LR.
# In the pruned Conv2d layers the trainable parameter is `weight_orig`. The
# pruning mask is a buffer, so it is not in net.parameters(), and pruned entries
# of the effective weight stay zero during fine-tuning.
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

# Multiply the LR by MultiStepLR's default gamma (0.1 per the PyTorch docs)
# after 20 scheduler steps; `train` steps the scheduler once per epoch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20])

In [76]:
!nvidia-smi

Wed May  8 14:54:59 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0              32W / 250W |   3136MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [74]:
import gc

In [75]:
gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()

In [79]:
net.to(device);

In [80]:
net = train(net, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, epochs, device)

Training for 30 epochs on cuda
Epoch 1/30
Training loss: 2.33
Training accuracy: 40.01
Valid loss: 1.58
Valid accuracy: 53.94
Epoch time: 130.7 Total training time: 130.7

Epoch 2/30
Training loss: 1.57
Training accuracy: 55.90
Valid loss: 1.39
Valid accuracy: 59.76
Epoch time: 130.9 Total training time: 261.6

Epoch 3/30
Training loss: 1.35
Training accuracy: 61.71
Valid loss: 1.30
Valid accuracy: 62.88
Epoch time: 130.8 Total training time: 392.4

Epoch 4/30
Training loss: 1.23
Training accuracy: 65.05
Valid loss: 1.19
Valid accuracy: 65.85
Epoch time: 130.7 Total training time: 523.1

Epoch 5/30
Training loss: 1.13
Training accuracy: 67.43
Valid loss: 1.18
Valid accuracy: 65.52
Epoch time: 130.7 Total training time: 653.9

Epoch 6/30
Training loss: 1.06
Training accuracy: 69.38
Valid loss: 1.09
Valid accuracy: 68.23
Epoch time: 131.3 Total training time: 785.2

Epoch 7/30
Training loss: 1.01
Training accuracy: 70.90
Valid loss: 1.10
Valid accuracy: 68.48
Epoch time: 131.2 Total trai

In [81]:
# Second fine-tuning stage reusing the same optimizer, so it starts from the LR
# the first stage ended with (10x below `lr` if that stage passed its epoch-20
# milestone). This scheduler applies the default gamma (0.1) after the 1st
# epoch of this stage and again after the 15th.
scheduler1 = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1, 15])

In [82]:
net = train(net, train_dataloader, valid_dataloader, criterion, optimizer, scheduler1, 25, device)

Training for 25 epochs on cuda
Epoch 1/25
Training loss: 0.28
Training accuracy: 93.02
Valid loss: 0.84
Valid accuracy: 76.58
Epoch time: 130.7 Total training time: 130.7

Epoch 2/25
Training loss: 0.26
Training accuracy: 93.65
Valid loss: 0.82
Valid accuracy: 76.87
Epoch time: 130.6 Total training time: 261.3

Epoch 3/25
Training loss: 0.25
Training accuracy: 93.71
Valid loss: 0.82
Valid accuracy: 77.01
Epoch time: 130.6 Total training time: 391.9

Epoch 4/25
Training loss: 0.25
Training accuracy: 93.79
Valid loss: 0.82
Valid accuracy: 77.04
Epoch time: 130.6 Total training time: 522.5

Epoch 5/25
Training loss: 0.24
Training accuracy: 94.00
Valid loss: 0.83
Valid accuracy: 77.01
Epoch time: 130.6 Total training time: 653.1

Epoch 6/25
Training loss: 0.25
Training accuracy: 93.89
Valid loss: 0.83
Valid accuracy: 76.89
Epoch time: 130.9 Total training time: 784.0

Epoch 7/25
Training loss: 0.25
Training accuracy: 93.93
Valid loss: 0.83
Valid accuracy: 77.05
Epoch time: 130.7 Total trai

In [83]:
def create_classification_report(model, device, test_loader, **report_kwargs):
    """Evaluate `model` on `test_loader` and return scikit-learn's classification report.

    Args:
        model: classifier whose output holds one score per class along dim 1.
        device: device used for inference; ``model`` is moved there in place
            and switched to eval mode.
        test_loader: iterable of batches whose first element is the input
            tensor and whose second element holds the integer labels.
        **report_kwargs: forwarded to ``sklearn.metrics.classification_report``
            (e.g. ``target_names``, ``digits``, ``output_dict``).

    Returns:
        Whatever ``sklearn.metrics.classification_report`` returns - a text
        report by default.
    """
    model.eval()
    model.to(device)

    y_pred = []
    y_true = []

    with torch.no_grad():
        for batch in test_loader:
            images, labels = batch[0], batch[1]
            # Only the inputs go to `device`; labels are only used for the report.
            outputs = model(images.to(device))
            y_pred.extend(outputs.argmax(dim=1).cpu().tolist())
            y_true.extend(labels.tolist())

    return sklearn.metrics.classification_report(
        y_true=y_true, y_pred=y_pred, **report_kwargs)

In [84]:
print(create_classification_report(net, device, valid_dataloader))

              precision    recall  f1-score   support

           0       0.87      0.92      0.89       100
           1       0.93      0.89      0.91       100
           2       0.57      0.64      0.60       100
           3       0.64      0.55      0.59       100
           4       0.60      0.63      0.61       100
           5       0.80      0.81      0.81       100
           6       0.78      0.75      0.77       100
           7       0.77      0.75      0.76       100
           8       0.90      0.91      0.91       100
           9       0.87      0.88      0.88       100
          10       0.76      0.70      0.73       100
          11       0.47      0.45      0.46       100
          12       0.88      0.82      0.85       100
          13       0.78      0.69      0.73       100
          14       0.80      0.73      0.76       100
          15       0.75      0.82      0.78       100
          16       0.79      0.78      0.78       100
          17       0.88    

In [87]:
measure_global_sparsity(net, conv2d_use_mask=True)

(20204266, 21318848, 0.9477184695908522)

In [88]:
# Save only the model weights. prune.remove() was never applied, so the
# state_dict stores every pruned Conv2d weight as `weight_orig` + `weight_mask`.
# Loading it therefore requires a model with the same pruning re-parametrisation.
torch.save({
                #'epoch': epoch,
                'state_dict': net.state_dict(),
                #'optimizer': optimizer.state_dict(),
            }, 'sp0.95_acc0.77_finetuned_from_cifar10.pth.tar')

In [None]:
model = torchvision.models.resnet34()
model.fc = torch.nn.Linear(model.fc.in_features, 100)
# The checkpoints written during fine-tuning come from the pruned `net`, whose
# Conv2d weights are stored as `weight_orig` + `weight_mask`. A plain ResNet-34
# only has `weight`, so a strict load_state_dict would fail. Give `model` the
# same re-parametrisation (amount=0 creates all-ones masks); the loaded masks
# then restore the real sparsity pattern.
prune.global_unstructured(
    [(module, "weight") for module in model.modules() if isinstance(module, torch.nn.Conv2d)],
    pruning_method=prune.L1Unstructured,
    amount=0,
)

In [None]:
# map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
# NOTE(review): checkpoint_epoch30 is written by the first 30-epoch run. The
# second 25-epoch run (checkpoint_epochs=10) only rewrites epochs 10 and 20, so
# this is NOT the final fine-tuned model, which was saved to
# 'sp0.95_acc0.77_finetuned_from_cifar10.pth.tar'. Confirm which one is intended.
checkpoint = torch.load('/kaggle/working/checkpoint_epoch30.pth.tar', map_location=device)
model.load_state_dict(checkpoint['state_dict'])

In [None]:
print(create_classification_report(model, device, valid_dataloader))