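# Train ResNet-18 on a 100-class ImageNet subset

Two phases: first train only a new classification head on top of a frozen pretrained backbone, then fine-tune the whole network end to end.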

In [1]:
import os
import json
import PIL.Image
import random
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision

from src.data.ImageNet100ClassesDataset import ImageNet100ClassesDataset, prepare_dataloaders_ImageNet100ClassesDataset

from src.training.Trainer import Trainer

In [2]:
def seed_everything(seed_value=4995):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch's CPU and
    all-GPU generators with ``seed_value``. It also forces cuDNN into
    deterministic mode and turns off its autotuner.

    Note: ``PYTHONHASHSEED`` is only read when a Python interpreter starts.
    Setting it here affects interpreters launched afterwards, not the one
    already running this notebook.
    """
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    # Same seeding order as before: stdlib, NumPy, torch CPU, torch CUDA.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed_value)

    # Give up cuDNN autotuning speed in exchange for repeatable kernel choice.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything()

In [3]:
# Root of the 100-class ImageNet data. This is a relative path, so it resolves against the notebook's working directory.
root_dir = 'Data/imagenet100classes'

In [4]:
# Shared training hyperparameters. The fine-tuning cell uses its own, smaller learning rate.
config = dict(
    lr = 0.001,
    batch_size = 64,
    weight_decay = 0.01,
)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
def load_model(num_classes = 100, model_path = None, to_cuda = True):
    """Build a ResNet-18 whose final ``fc`` layer outputs ``num_classes`` logits.

    Parameters
    ----------
    num_classes : int
        Output size of the replacement ``fc`` layer.
    model_path : str or None
        If falsy, start from torchvision's pretrained ResNet-18 weights
        (``pretrained = True``) with a freshly initialised ``fc`` head.
        Otherwise, build an untrained ResNet-18 and load the checkpoint at
        this path. The checkpoint must be a dict holding the weights under
        ``'model_state_dict'``.
        NOTE(review): that format comes from the project's Trainer; confirm it there.
    to_cuda : bool
        Move the model to CUDA if it is available, otherwise to the CPU.
        Previously this always moved to 'cuda' and crashed on CPU-only machines.

    Returns
    -------
    tuple
        ``(model, loaded_state_dict)``, where ``loaded_state_dict`` is True
        when the weights came from ``model_path``.
    """
    loaded_state_dict = bool(model_path)

    if loaded_state_dict:
        model = torchvision.models.resnet18()
    else:
        model = torchvision.models.resnet18(pretrained = True)
    # Both branches use the same head replacement, so build it once.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    if loaded_state_dict:
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
        # machine. The model is moved to its target device further down.
        # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location = 'cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        # Report success only after the weights have actually been loaded.
        print("Loaded", model_path)

    if to_cuda:
        model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

    return model, loaded_state_dict

## Train the classification head (backbone frozen)

In [None]:
# Phase 1: start from torchvision's pretrained weights and train only the new `fc` head.
resnet, loaded_state_dict = load_model()

# Freeze every parameter whose name does not mention `fc`. The head keeps its default requires_grad=True.
backbone_params = [tensor for param_name, tensor in resnet.named_parameters() if 'fc' not in param_name]
for tensor in backbone_params:
    tensor.requires_grad = False

In [None]:
# Data, loss, optimiser and LR schedule for the head-only phase.
train_loader, val_loader = prepare_dataloaders_ImageNet100ClassesDataset(root_dir, batch_size = config['batch_size'])
dataloaders = dict(train = train_loader, val = val_loader)

loss_fn = nn.CrossEntropyLoss()

# Give the optimiser only the parameters that were left trainable (the fc head).
train_params = list(filter(lambda tensor: tensor.requires_grad, resnet.parameters()))

# Adam's weight_decay adds an L2 penalty to the gradients (it is coupled, not AdamW-style).
optimizer = torch.optim.Adam(train_params, lr = config['lr'], weight_decay = config['weight_decay'])
# Multiply the LR by 0.1 at scheduler steps 3 and 5. These are presumably
# epochs, assuming the Trainer steps the scheduler once per epoch.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones = [3, 5], gamma = 0.1)

training_params = dict(
    dataloaders = dataloaders,
    optimizer = optimizer,
    scheduler = scheduler,
)

In [None]:
# Class labels in the dataset dict's iteration order.
# NOTE(review): this assumes that order matches the integer targets; confirm in ImageNet100ClassesDataset.
classes = [*train_loader.dataset.class_name_dict.values()]

In [None]:
# Head-only training run. The checkpoint name here is what the fine-tuning
# section reloads (presumably saved as models/resnet_100_imagenet.pt; confirm in Trainer).
trainer = Trainer(
    resnet,
    loss_fn,
    classes,
    training_params,
    DEVICE,
    num_epochs = 6,
    model_name = 'resnet_100_imagenet',
    save_model = True,
    model_dir = 'models',
)

In [10]:
# Run the head-only training loop. It logs the running loss periodically (see output below).
trainer.train()

Current Epoch: 0
Train Loop:
Batch: 0 of 2032. Loss: 4.971340656280518. Mean so far: 4.971340656280518. Mean of 100: 4.971340656280518
Batch: 20 of 2032. Loss: 3.8565421104431152. Mean so far: 4.382260992413475. Mean of 100: 4.382260992413475
Batch: 40 of 2032. Loss: 3.0460102558135986. Mean so far: 3.868049377348365. Mean of 100: 3.868049377348365
Batch: 60 of 2032. Loss: 2.0944831371307373. Mean so far: 3.43997694234379. Mean of 100: 3.43997694234379
Batch: 80 of 2032. Loss: 1.8737000226974487. Mean so far: 3.1182604130403497. Mean of 100: 3.1182604130403497
Batch: 100 of 2032. Loss: 1.678606629371643. Mean so far: 2.8528005156186547. Mean of 100: 2.831615114212036
Batch: 120 of 2032. Loss: 1.3443585634231567. Mean so far: 2.6433636590468983. Mean of 100: 2.278195219039917
Batch: 140 of 2032. Loss: 1.4672715663909912. Mean so far: 2.4709372427446623. Mean of 100: 1.8981212675571442
Batch: 160 of 2032. Loss: 1.0509557723999023. Mean so far: 2.319302809904821. Mean of 100: 1.6356915891

## Fine-tune the whole network

In [6]:
# Phase 2: reload the head-trained checkpoint and make every layer trainable again.
resnet, loaded_state_dict = load_model(model_path = 'models/resnet_100_imagenet.pt')
resnet.requires_grad_(True)

Loaded models/resnet_100_imagenet.pt


In [7]:
# Data, loss, optimiser and LR schedule for end-to-end fine-tuning.
# This cell rebuilds everything so it works on a fresh kernel.

# 100x smaller than config['lr'] (1e-3). Fine-tuning the pretrained backbone
# uses small steps so its weights are adjusted rather than overwritten.
FINE_TUNE_LR = 1e-5

train_loader, val_loader = prepare_dataloaders_ImageNet100ClassesDataset(root_dir, batch_size = config['batch_size'])
dataloaders = {
    'train': train_loader,
    'val': val_loader
}

loss_fn = nn.CrossEntropyLoss()

# Every parameter was unfrozen above, so this is the full parameter list.
train_params = [p for p in resnet.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(train_params, lr = FINE_TUNE_LR, weight_decay = config['weight_decay'])
# Multiply the LR by 0.1 at scheduler steps 3 and 5 (presumably epochs).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3,5], gamma=0.1)

training_params = {
    'dataloaders': dataloaders,
    'optimizer': optimizer,
    'scheduler': scheduler
}

In [8]:
# Class labels in the dataset dict's iteration order.
# NOTE(review): this assumes that order matches the integer targets; confirm in ImageNet100ClassesDataset.
classes = [*train_loader.dataset.class_name_dict.values()]

In [9]:
# End-to-end fine-tuning run. It saves under a separate name so the head-only checkpoint is kept.
trainer = Trainer(
    resnet,
    loss_fn,
    classes,
    training_params,
    DEVICE,
    num_epochs = 6,
    model_name = 'resnet_100_imagenet_fine_tuned',
    save_model = True,
    model_dir = 'models',
)

In [10]:
# Run the fine-tuning loop. It logs the running loss periodically (see output below).
trainer.train()

Current Epoch: 0
Train Loop:
Batch: 0 of 2032. Loss: 0.7929084300994873. Mean so far: 0.7929084300994873. Mean of 100: 0.7929084300994873
Batch: 100 of 2032. Loss: 0.5967435240745544. Mean so far: 0.6240629769197785. Mean of 100: 0.6223745223879814
Batch: 200 of 2032. Loss: 0.6304162740707397. Mean so far: 0.6015580545610456. Mean of 100: 0.5788280829787255
Batch: 300 of 2032. Loss: 0.5305232405662537. Mean so far: 0.59579783123593. Mean of 100: 0.5842197823524475
Batch: 400 of 2032. Loss: 0.4131864607334137. Mean so far: 0.5842891300556963. Mean of 100: 0.5496479395031929
Batch: 500 of 2032. Loss: 0.5630130171775818. Mean so far: 0.577194420520417. Mean of 100: 0.548744635283947
Batch: 600 of 2032. Loss: 0.3642585575580597. Mean so far: 0.5735786809104056. Mean of 100: 0.5554638254642487
Batch: 700 of 2032. Loss: 0.47939738631248474. Mean so far: 0.5714079341857817. Mean of 100: 0.5583617463707924
Batch: 800 of 2032. Loss: 0.5330355167388916. Mean so far: 0.5695001342323389. Mean of 1