# CNN classifier

Congratulations, here's your first homework! You'll learn the art of training deep image classifiers. You might remember `03 seminar` about training CIFAR10 classifier, this homework is also about training **CIFAR10 classifier**, but this time you'll have to do it on your own and with some extra features.

## Data
Your dataset is CIFAR10. Check out `03 seminar` on how to load train and val data splits.

**Note:** for training you can only use `train` dataset.

## Game rules:
Maximum score you can get for this task is **10.0**.

Half of 10 points you can get by reaching high val accuracy (as listed in table below):

- accuracy > 60.0 -> **1 point**
- accuracy > 70.0 -> **2 points**
- accuracy > 80.0 -> **3 points**
- accuracy > 90.0 -> **4 points**
- accuracy > 92.5 -> **5 points**

Another half of 10 points you can get by adding following features to your training pipeline. It's okay if you see some technics for the first time (that was the idea). Feel free to google and dive into topic on your own, it's homework after all:
1. Data augmentations. Check out [this article](https://medium.com/nanonets/how-to-use-deep-learning-when-you-have-limited-data-part-2-data-augmentation-c26971dc8ced) (**1 point**)
2. [LR schedule](https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau) (**0.5 point**)
3. Finetune pretrained model from [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html) (except AlexNet!) (**1 point**)
4. Implement [ResNet model](https://medium.com/@14prakash/understanding-and-implementing-architectures-of-resnet-and-resnext-for-state-of-the-art-image-cf51669e1624) (**2 points**)
5. Use of [tensorboardX](https://github.com/lanpa/tensorboardX) to monitor training process (**0.5 points**)

As a result you have to submit **notebook with working code** (results will be reproduced during homework cheking) and **short report** (write it in the same notebook) about things you tried and what tasks you managed to implement. Good luck and have fun!

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn

from tqdm import tqdm_notebook as tqdm

from tensorboardX import SummaryWriter
from datetime import datetime
import os

In [2]:
batch_size = 4
device = torch.device('cuda')#torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [3]:
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

print("len(train_dataset) =", len(train_dataset))

val_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)
val_dataloader= torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print("len(val_dataset) =", len(val_dataset))

Files already downloaded and verified
len(train_dataset) = 50000
Files already downloaded and verified
len(val_dataset) = 10000


In [3]:
model = torchvision.models.resnet50(num_classes=10).to(device)

In [4]:
#for feature in model.parameters():
#    feature.requires_grad = False

model.fc = nn.Sequential(nn.Linear(8192, 512), nn.ReLU(inplace=True), nn.Linear(512, 10)).to(device)

In [5]:
criterion = nn.CrossEntropyLoss().to(device)
opt = optim.Adam(model.parameters(), lr=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.1, patience=4)

In [8]:
experiment_title = 'resnet101_cifar10'
experiment_name = "{}@{}".format(experiment_title, datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
writer = SummaryWriter(log_dir=os.path.join("./tb_untr", experiment_name))

In [12]:
n_epochs = 25
n_epochs_init = 15
n_iters_total = 187500

for epoch in range(n_epochs_init, n_epochs):
    total_train_loss = 0
    total_val_loss = 0
    correct = 0
    
    model.train()
    for batch in tqdm(train_dataloader):
        # unpack batch
        image_batch, label_batch = batch
        image_batch, label_batch = image_batch.cuda(), label_batch.cuda()
        
        # forward
        outputs = model(image_batch)
        loss = criterion(outputs, label_batch)
        total_train_loss += loss.item()
        
        # optimize
        opt.zero_grad()
        loss.backward()
        opt.step()
        # dump statistics
        writer.add_scalar("train/loss", loss.item(), global_step=n_iters_total)
        
        n_iters_total += 1
        
    print("Epoch {} done, total train loss {}.".format(epoch, total_train_loss / len(train_dataset)))
    
    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            image_batch, label_batch = batch
            image_batch, label_batch = image_batch.to(device), label_batch.to(device)
            outputs = model(image_batch)
            loss = criterion(outputs, label_batch)
            total_val_loss += loss
            predicted = torch.argmax(outputs, dim=1)
            correct += (predicted == label_batch).sum().item()
    print("Accuracy {:.4}%, total val loss {}".format(100 * correct / len(val_dataset), total_val_loss / len(val_dataset)))
    
    scheduler.step(total_val_loss)
    
    if epoch % 2 == 0:
        torch.save({'model': model.state_dict(), 
                    'optimizer': opt.state_dict(), 
                    'epoch': epoch, 
                    'accuracy': 100 * correct / len(val_dataset)
                   }, 
                   'checkpoints/resnet101_{}epochs.pth'.format(epoch))

HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 15 done, total train loss 0.003052017563506961.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.19%, total val loss 0.1471068263053894


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 16 done, total train loss 0.00233458678945899.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 85.86%, total val loss 0.15376444160938263


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 17 done, total train loss 0.00201752926774323.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.0%, total val loss 0.16032493114471436


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 18 done, total train loss 0.001225614498257637.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.15%, total val loss 0.1539306789636612


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 19 done, total train loss 0.0010847580142319202.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.46%, total val loss 0.14959123730659485


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 20 done, total train loss 0.0009198202818632126.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.3%, total val loss 0.15434479713439941


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 21 done, total train loss 0.0009687646854668856.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.25%, total val loss 0.1575690656900406


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 22 done, total train loss 0.0008647281095385551.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.14%, total val loss 0.15694677829742432


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 23 done, total train loss 0.0008671411618590355.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.02%, total val loss 0.15670767426490784


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 24 done, total train loss 0.0008437503151595592.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.32%, total val loss 0.15133218467235565


In [31]:
model.eval()
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        image_batch, label_batch = batch
        image_batch, label_batch = image_batch.to(device), label_batch.to(device)
        outputs = model(image_batch)
        loss = criterion(outputs, label_batch)
        total_val_loss += loss
        predicted = torch.argmax(outputs, dim=1)
        correct += (predicted == label_batch).sum().item()
print("Accuracy {:.4}%, total val loss {}".format(100 * correct / len(val_dataset), total_val_loss / len(val_dataset)))

HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))

Accuracy 70.92%, total val loss 0.0002574552781879902


In [45]:
torch.save({'model': model.state_dict(), 'optimizer': opt.state_dict(), 'epoch': 5, 'loss': loss}, 'checkpoints/resnet50_pre_5epochs.pth')

In [10]:
for param_group in opt.param_groups:
    print(param_group['lr'])

1e-05


In [10]:
n_iters_total

187500

In [6]:
transform_random = transforms.Compose([
    transforms.RandomChoice([transforms.RandomCrop(28), transforms.RandomHorizontalFlip(p=0.75)]),
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_random
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

print("len(train_dataset) =", len(train_dataset))

val_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)
val_dataloader= torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print("len(val_dataset) =", len(val_dataset))

Files already downloaded and verified
len(train_dataset) = 50000
Files already downloaded and verified
len(val_dataset) = 10000


In [7]:
checkpoint = torch.load('checkpoints/resnet101_12epochs.pth')
model.load_state_dict(checkpoint['model'])
opt.load_state_dict(checkpoint['optimizer'])

In [8]:
experiment_title = 'resnet50_cont_cifar10'
experiment_name = "{}@{}".format(experiment_title, datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
writer = SummaryWriter(log_dir=os.path.join("./tb_untr", experiment_name))

In [9]:
n_epochs = 22
n_epochs_init = 18
n_iters_total = 150000

for epoch in range(n_epochs_init, n_epochs):
    total_train_loss = 0
    total_val_loss = 0
    correct = 0
    
    model.train()
    for batch in tqdm(train_dataloader):
        # unpack batch
        image_batch, label_batch = batch
        image_batch, label_batch = image_batch.cuda(), label_batch.cuda()
        
        # forward
        outputs = model(image_batch)
        loss = criterion(outputs, label_batch)
        total_train_loss += loss.item()
        
        # optimize
        opt.zero_grad()
        loss.backward()
        opt.step()
        # dump statistics
        writer.add_scalar("train/loss", loss.item(), global_step=n_iters_total)
        
        n_iters_total += 1
        
    print("Epoch {} done, total train loss {}.".format(epoch, total_train_loss / len(train_dataset)))
    
    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            image_batch, label_batch = batch
            image_batch, label_batch = image_batch.to(device), label_batch.to(device)
            outputs = model(image_batch)
            loss = criterion(outputs, label_batch)
            total_val_loss += loss
            predicted = torch.argmax(outputs, dim=1)
            correct += (predicted == label_batch).sum().item()
    print("Accuracy {:.4}%, total val loss {}".format(100 * correct / len(val_dataset), total_val_loss / len(val_dataset)))
    
    scheduler.step(total_val_loss)
    
    if epoch % 2 == 0:
        torch.save({'model': model.state_dict(), 
                    'optimizer': opt.state_dict(), 
                    'epoch': epoch, 
                    'accuracy': 100 * correct / len(val_dataset)
                   }, 
                   'checkpoints/resnet50_cont_{}epochs.pth'.format(epoch))

HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 13 done, total train loss 0.0851934731952846.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.13%, total val loss 0.11438147723674774


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 14 done, total train loss 0.07495894526660442.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 85.58%, total val loss 0.11662707477807999


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 15 done, total train loss 0.06751117350965738.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.78%, total val loss 0.10897631198167801


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 16 done, total train loss 0.06459670354261994.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.97%, total val loss 0.10363928973674774


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))


Epoch 17 done, total train loss 0.06001895448833704.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))


Accuracy 86.92%, total val loss 0.10566907376050949


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))

KeyboardInterrupt: 