# CNN classifier

Congratulations, here's your first homework! You'll learn the art of training deep image classifiers. You might remember `03 seminar`, which covered training a CIFAR10 classifier. This homework is also about training a **CIFAR10 classifier**, but this time you'll have to do it on your own and with some extra features.

## Data
Your dataset is CIFAR10. See `03 seminar` for how to load the train and val data splits.

**Note:** you may only use the `train` dataset for training.

## Game rules:
The maximum score you can get for this task is **10.0**.

You can earn half of the 10 points by reaching a high val accuracy (thresholds in percent, listed below):

- accuracy > 60.0 -> **1 point**
- accuracy > 70.0 -> **2 points**
- accuracy > 80.0 -> **3 points**
- accuracy > 90.0 -> **4 points**
- accuracy > 92.5 -> **5 points**
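
The thresholds above are strict inequalities, so a val accuracy of exactly 90.0 earns 3 points, not 4. A minimal sketch of the mapping follows; the function name `points_for_accuracy` is made up for illustration and is not part of the course materials:

```python
# Accuracy thresholds (in percent) from the list above, paired with the points they earn.
THRESHOLDS = [(60.0, 1), (70.0, 2), (80.0, 3), (90.0, 4), (92.5, 5)]


def points_for_accuracy(accuracy):
    """Return the points earned for a val accuracy given in percent."""
    points = 0
    for threshold, value in THRESHOLDS:
        if accuracy > threshold:  # strict comparison, as in the list
            points = value
    return points


print(points_for_accuracy(81.09))  # 3
```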

You can earn the other half of the 10 points by adding the following features to your training pipeline. It's okay if you see some of these techniques for the first time (that was the idea). Feel free to google and dive into each topic on your own; it's homework after all:
1. Data augmentations. Check out [this article](https://medium.com/nanonets/how-to-use-deep-learning-when-you-have-limited-data-part-2-data-augmentation-c26971dc8ced) (**1 point**)
2. [LR schedule](https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau) (**0.5 points**)
3. Finetune a pretrained model from [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html) (except AlexNet!) (**1 point**)
4. Implement a [ResNet model](https://medium.com/@14prakash/understanding-and-implementing-architectures-of-resnet-and-resnext-for-state-of-the-art-image-cf51669e1624) (**2 points**)
5. Use [tensorboardX](https://github.com/lanpa/tensorboardX) to monitor the training process (**0.5 points**)

As a result, you have to submit a **notebook with working code** (results will be reproduced during homework checking) and a **short report** (write it in the same notebook) about the things you tried and the tasks you managed to implement. Good luck and have fun!

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn

from tqdm.notebook import tqdm  # tqdm.tqdm_notebook is deprecated in recent tqdm releases
from tensorboardX import SummaryWriter
from apex import amp

from datetime import datetime
import os

In [2]:
batch_size = 16
device = torch.device('cuda')  # CPU fallback: torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

transform_random = transforms.Compose([
    transforms.RandomChoice([transforms.RandomCrop(28),
                             transforms.RandomHorizontalFlip(p=0.75),
                             transforms.RandomAffine(15)
                            ]),
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
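
`transforms.ToTensor()` scales pixel values to the range [0, 1], and `transforms.Normalize` then computes `(x - mean) / std` separately for each channel. The mean and std used in both pipelines are the standard ImageNet statistics. Below is a pure-Python sketch of that arithmetic for a single RGB pixel; the helper `normalize_pixel` is hypothetical and written only for this illustration:

```python
# ImageNet channel statistics, as passed to transforms.Normalize above.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (x - mean) / std per channel to one pixel with values in [0, 1]."""
    return tuple((c - m) / s for c, m, s in zip(rgb, mean, std))


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel((0.485, 0.456, 0.406)))  # (0.0, 0.0, 0.0)
# A pure white pixel ends up at a little over two standard deviations.
print(normalize_pixel((1.0, 1.0, 1.0)))
```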

In [3]:
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_random
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

print("len(train_dataset) =", len(train_dataset))

val_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print("len(val_dataset) =", len(val_dataset))

Files already downloaded and verified
len(train_dataset) = 50000
Files already downloaded and verified
len(val_dataset) = 10000
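
With `batch_size = 16` and the default `drop_last=False`, a `DataLoader` yields `ceil(n_samples / batch_size)` batches per epoch. A small sketch of that arithmetic follows; the helper name `batches_per_epoch` is an assumption for illustration:

```python
import math


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch for the given sizes."""
    if drop_last:
        return n_samples // batch_size  # the incomplete last batch is discarded
    return math.ceil(n_samples / batch_size)


print(batches_per_epoch(50000, 16))  # 3125 train batches
print(batches_per_epoch(10000, 16))  # 625 val batches
```

These values match the progress-bar totals (`max=3125` and `max=625`) in the first training log. The totals of 12500 and 2500 in the later log would correspond to a batch size of 4, which suggests, though the notebook does not show it, that `batch_size` was changed for that run.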


In [11]:
model = torchvision.models.resnet50(num_classes=10).to(device)

In [12]:
#for feature in model.parameters():
#    feature.requires_grad = False

model.fc = nn.Sequential(nn.Linear(2048, 512), nn.ReLU(inplace=True), nn.Linear(512, 10)).to(device)

In [13]:
#model.half()  # convert to half precision
#for layer in model.modules():
#  if isinstance(layer, nn.BatchNorm2d):
#    layer.float()

In [14]:
criterion = nn.CrossEntropyLoss().to(device)
opt = optim.Adam(model.parameters(), lr=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.1, patience=4)

In [15]:
model, opt = amp.initialize(model, opt, opt_level='O1')

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
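
`loss_scale: dynamic` means the loss is multiplied by a scale factor before `backward()` so that small fp16 gradients do not underflow. When a gradient overflows to inf or NaN, the optimizer step is skipped and the scale is halved, which is what the `Gradient overflow.  Skipping step` messages in the training log report. The following is a simplified pure-Python sketch of such a scaler, not Apex's actual code. The class name and the constants (initial scale 2^16, doubling after 2000 consecutive clean steps) are assumptions chosen to be consistent with the log, where the first overflow reduces the scale to 32768.

```python
class DynamicLossScaler:
    """Toy model of dynamic loss scaling (constants are assumptions, see above)."""

    def __init__(self, init_scale=2.0 ** 16, factor=2.0, growth_interval=2000):
        self.scale = init_scale
        self.factor = factor
        self.growth_interval = growth_interval
        self.clean_steps = 0  # consecutive steps without overflow

    def update(self, overflow):
        """Update the scale after one iteration; return True if the step is skipped."""
        if overflow:
            self.scale /= self.factor
            self.clean_steps = 0
            return True
        self.clean_steps += 1
        if self.clean_steps == self.growth_interval:
            self.scale *= self.factor  # try a larger scale again
            self.clean_steps = 0
        return False


scaler = DynamicLossScaler()
for _ in range(4):  # four overflows in a row, as at the start of epoch 0
    scaler.update(overflow=True)
print(scaler.scale)  # 4096.0
```

Because the scale grows back after a run of clean steps, an occasional overflow message per epoch is expected and does not by itself indicate a problem.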


In [10]:
experiment_title = 'resnet50_fp16_batch16_v2'
experiment_name = "{}@{}".format(experiment_title, datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
writer = SummaryWriter(log_dir=os.path.join("./tb_bench", experiment_name))

In [9]:
#checkpoint = torch.load('checkpoints/resnet50_aug_adam_v2_6epochs.pth')
#model.load_state_dict(checkpoint['model'])
#opt.load_state_dict(checkpoint['optimizer'])

FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints/resnet50_aug_adam_v2_6epochs.pth'

In [16]:
n_epochs = 40
n_epochs_init = 0
n_iters_total = 0

for epoch in range(n_epochs_init, n_epochs):
    total_train_loss = 0
    total_val_loss = 0
    correct = 0
    
    model.train()
    for batch in tqdm(train_dataloader):
        # unpack batch
        image_batch, label_batch = batch
        image_batch, label_batch = image_batch.to(device), label_batch.to(device)
        
        # forward
        outputs = model(image_batch)
        loss = criterion(outputs, label_batch)
        total_train_loss += loss.item()
        
        # optimize
        opt.zero_grad()
        with amp.scale_loss(loss, opt) as scaled_loss:
            scaled_loss.backward()
        opt.step()
        # dump statistics
        # n_iters_total advances by 1/4 per batch, so the train loss is logged every 4th batch
        if int(n_iters_total) == n_iters_total:
            writer.add_scalar("train/loss", loss.item(), global_step=n_iters_total)
        
        n_iters_total += 1/4
        
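    # note: this is the sum of per-batch mean losses divided by the number of samples,
    # i.e. roughly the mean batch loss divided by batch_size
    print("Epoch {} done, total train loss {}.".format(epoch, total_train_loss / len(train_dataset)))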
    
    model.eval()
    with torch.no_grad():
        val_n = 0
        for batch in tqdm(val_dataloader):
            image_batch, label_batch = batch
            image_batch, label_batch = image_batch.to(device).half(), label_batch.to(device)
            outputs = model(image_batch)
            loss = criterion(outputs, label_batch)
            total_val_loss += loss.item()  # accumulate a float, not a GPU tensor
            predicted = torch.argmax(outputs, dim=1)
            correct += (predicted == label_batch).sum().item()
            writer.add_scalar("val/loss", loss.item(), global_step=n_iters_total+val_n)
            val_n += 1
        writer.add_scalar("val/accuracy", 100 * correct / len(val_dataset), global_step=n_iters_total)
    print("Accuracy {:.4}%, total val loss {}".format(100 * correct / len(val_dataset), total_val_loss / len(val_dataset)))
    
    scheduler.step(total_val_loss)
    
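    if epoch % 2 == 0:
        os.makedirs('checkpoints', exist_ok=True)  # torch.save does not create missing directories
        torch.save({'model': model.state_dict(),
                    'optimizer': opt.state_dict(),
                    'epoch': epoch,
                    'iter_num': n_iters_total,
                    'loss': loss.item(),  # loss of the last val batch, stored as a plain float
                    'accuracy': 100 * correct / len(val_dataset)
                   },
                   'checkpoints/resnet50_aug_adam_fp16_batch16_{}epochs.pth'.format(epoch))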

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Epoch 0 done, total train loss 0.09770424658179283.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 56.87%, total val loss 0.07408476620912552


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 1 done, total train loss 0.0665425976896286.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 71.91%, total val loss 0.05090172961354256


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 2 done, total train loss 0.05202948563873768.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 76.01%, total val loss 0.042960163205862045


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 3 done, total train loss 0.044175849616229536.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 78.97%, total val loss 0.039128370583057404


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 4 done, total train loss 0.03880635402172804.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 81.09%, total val loss 0.034825682640075684


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Epoch 5 done, total train loss 0.03484663666963577.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 82.21%, total val loss 0.03216952830553055


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Epoch 6 done, total train loss 0.03161643602013588.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 84.8%, total val loss 0.028099147602915764


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 7 done, total train loss 0.0291702418589592.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 85.82%, total val loss 0.02584516443312168


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 8 done, total train loss 0.026794010414779187.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 85.13%, total val loss 0.02722126431763172


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 9 done, total train loss 0.02481424922913313.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 87.08%, total val loss 0.02410372532904148


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 10 done, total train loss 0.023157953354418278.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 88.0%, total val loss 0.02224443480372429


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 11 done, total train loss 0.021675165048539637.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 88.69%, total val loss 0.021168410778045654


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 12 done, total train loss 0.019873110725879668.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 87.63%, total val loss 0.02344292588531971


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 13 done, total train loss 0.01891987650513649.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 89.04%, total val loss 0.020619351416826248


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 14 done, total train loss 0.017654934478998185.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 88.96%, total val loss 0.02108781971037388


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 15 done, total train loss 0.016683078394085168.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 89.31%, total val loss 0.020399028435349464


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 16 done, total train loss 0.015866662561893462.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 89.33%, total val loss 0.019480621442198753


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 17 done, total train loss 0.014910211726278067.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 88.59%, total val loss 0.02151712030172348


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 18 done, total train loss 0.014459049625694751.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 89.7%, total val loss 0.01853351481258869


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 19 done, total train loss 0.013408398766219616.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 90.08%, total val loss 0.019246380776166916


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 20 done, total train loss 0.012641025468409062.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 89.44%, total val loss 0.02051013521850109


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 21 done, total train loss 0.012257171808630228.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 90.49%, total val loss 0.017900031059980392


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 22 done, total train loss 0.01136088792309165.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 90.75%, total val loss 0.018516214564442635


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 23 done, total train loss 0.011014521453678608.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 90.47%, total val loss 0.01899869181215763


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 24 done, total train loss 0.01039862813398242.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 91.03%, total val loss 0.01729872263967991


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 25 done, total train loss 0.010083133695870638.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 90.82%, total val loss 0.018989454954862595


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 26 done, total train loss 0.009768415744304657.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 91.68%, total val loss 0.016763459891080856


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 27 done, total train loss 0.00915154215618968.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 90.82%, total val loss 0.018354810774326324


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 28 done, total train loss 0.008768658512681722.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 90.86%, total val loss 0.018550671637058258


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 29 done, total train loss 0.008407255227565765.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 90.91%, total val loss 0.01833920180797577


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 30 done, total train loss 0.008191901792436838.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 91.41%, total val loss 0.016937391832470894


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 31 done, total train loss 0.00786987985894084.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 91.29%, total val loss 0.017832191661000252


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 32 done, total train loss 0.004115045924931764.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 93.18%, total val loss 0.013910542242228985


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Epoch 33 done, total train loss 0.0028368391978740694.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 93.25%, total val loss 0.014203106053173542


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 34 done, total train loss 0.002462821400910616.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 93.44%, total val loss 0.014122072607278824


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 35 done, total train loss 0.0021515168383717535.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 93.5%, total val loss 0.014266997575759888


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Epoch 36 done, total train loss 0.0020710361909121273.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 93.26%, total val loss 0.014725733548402786


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 37 done, total train loss 0.0018698702447116375.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 93.4%, total val loss 0.01465950720012188


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 38 done, total train loss 0.0017189810525625943.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 93.56%, total val loss 0.0145795326679945


HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0
Epoch 39 done, total train loss 0.001473579454049468.


HBox(children=(IntProgress(value=0, max=625), HTML(value='')))

Accuracy 93.61%, total val loss 0.01441584900021553
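
The drop in train loss between epochs 31 and 32 (0.0079 to 0.0041), together with the jump in accuracy from 91.29% to 93.18%, is consistent with `ReduceLROnPlateau(opt, factor=0.1, patience=4)` cutting the learning rate at the end of epoch 31. The sketch below replays the logged val losses through a simplified re-implementation of the plateau rule (minimum mode, relative threshold of 1e-4, no cooldown, which are assumed here to match the PyTorch defaults). It is an illustration written for this notebook, not PyTorch's code, and `plateau_reductions` is a made-up name. The scheduler in the loop received the un-normalized sum of losses, but dividing by a constant does not change which epochs count as improvements.

```python
# Val losses printed in the log above, epochs 0..39 (rounded to 5 decimals).
VAL_LOSSES = [
    0.07408, 0.05090, 0.04296, 0.03913, 0.03483, 0.03217, 0.02810, 0.02585,
    0.02722, 0.02410, 0.02224, 0.02117, 0.02344, 0.02062, 0.02109, 0.02040,
    0.01948, 0.02152, 0.01853, 0.01925, 0.02051, 0.01790, 0.01852, 0.01900,
    0.01730, 0.01899, 0.01676, 0.01835, 0.01855, 0.01834, 0.01694, 0.01783,
    0.01391, 0.01420, 0.01412, 0.01427, 0.01473, 0.01466, 0.01458, 0.01442,
]


def plateau_reductions(losses, patience=4, threshold=1e-4):
    """Return the epochs after which the learning rate would be reduced."""
    best = float("inf")
    bad_epochs = 0
    reductions = []
    for epoch, loss in enumerate(losses):
        if loss < best * (1.0 - threshold):  # counts as an improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # more than `patience` epochs without improvement
            reductions.append(epoch)
            bad_epochs = 0
    return reductions


print(plateau_reductions(VAL_LOSSES))  # [31, 37]
```

Under these assumptions the first reduction lands exactly where the log shows the sudden improvement. The second one, after epoch 37, would lower the learning rate by another factor of 10 and has no clearly visible effect in the last two epochs.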


In [28]:
list(model.parameters())


In [31]:
# reset the running totals so this cell does not add to values left over from the training loop
total_val_loss = 0
correct = 0

model.eval()
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        image_batch, label_batch = batch
        image_batch, label_batch = image_batch.to(device), label_batch.to(device)
        outputs = model(image_batch)
        loss = criterion(outputs, label_batch)
        total_val_loss += loss.item()
        predicted = torch.argmax(outputs, dim=1)
        correct += (predicted == label_batch).sum().item()
print("Accuracy {:.4}%, total val loss {}".format(100 * correct / len(val_dataset), total_val_loss / len(val_dataset)))

HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))

Accuracy 70.92%, total val loss 0.0002574552781879902


In [45]:
torch.save({'model': model.state_dict(), 'optimizer': opt.state_dict(), 'epoch': 5, 'loss': loss}, 'checkpoints/resnet50_pre_5epochs.pth')

In [31]:
for param_group in opt.param_groups:
    param_group['lr'] = 1e-06

In [22]:
n_iters_total

550001

In [6]:
transform_random = transforms.Compose([
    transforms.RandomChoice([transforms.RandomCrop(28), transforms.RandomHorizontalFlip(p=0.75)]),
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_random
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

print("len(train_dataset) =", len(train_dataset))

val_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print("len(val_dataset) =", len(val_dataset))

Files already downloaded and verified
len(train_dataset) = 50000
Files already downloaded and verified
len(val_dataset) = 10000


In [29]:
checkpoint = torch.load('checkpoints/resnet50_aug_adam_v2_40epochs.pth')
model.load_state_dict(checkpoint['model'])
opt.load_state_dict(checkpoint['optimizer'])

In [30]:
experiment_title = 'resnet50_cont_aug_adam_3_cifar10'
experiment_name = "{}@{}".format(experiment_title, datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
writer = SummaryWriter(log_dir=os.path.join("./tb_untr", experiment_name))

In [32]:
n_epochs = 46
n_epochs_init = 41
n_iters_total = 500000

for epoch in range(n_epochs_init, n_epochs):
    total_train_loss = 0
    total_val_loss = 0
    correct = 0
    
    model.train()
    for batch in tqdm(train_dataloader):
        # unpack batch
        image_batch, label_batch = batch
        image_batch, label_batch = image_batch.to(device), label_batch.to(device)
        
        # forward
        outputs = model(image_batch)
        loss = criterion(outputs, label_batch)
        total_train_loss += loss.item()
        
        # optimize
        opt.zero_grad()
        loss.backward()
        opt.step()
        # dump statistics
        writer.add_scalar("train/loss", loss.item(), global_step=n_iters_total)
        
        n_iters_total += 1
        
    print("Epoch {} done, total train loss {}.".format(epoch, total_train_loss / len(train_dataset)))
    
    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            image_batch, label_batch = batch
            image_batch, label_batch = image_batch.to(device), label_batch.to(device)
            outputs = model(image_batch)
            loss = criterion(outputs, label_batch)
            total_val_loss += loss.item()  # accumulate a float, not a GPU tensor
            predicted = torch.argmax(outputs, dim=1)
            correct += (predicted == label_batch).sum().item()
    print("Accuracy {:.4}%, total val loss {}".format(100 * correct / len(val_dataset), total_val_loss / len(val_dataset)))
    
    #scheduler.step(total_val_loss)
    
    if epoch % 2 == 0:
        torch.save({'model': model.state_dict(), 
                    'optimizer': opt.state_dict(), 
                    'epoch': epoch, 
                    'accuracy': 100 * correct / len(val_dataset)
                   }, 
                   'checkpoints/resnet50_aug_adam_v2_cont2_{}epochs.pth'.format(epoch))

HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))

Epoch 41 done, total train loss 0.006137742601714563.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))

Accuracy 92.19%, total val loss 0.0823383778333664


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))

Epoch 42 done, total train loss 0.005276410705707967.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))

Accuracy 91.78%, total val loss 0.08470028638839722


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))

Epoch 43 done, total train loss 0.005804890033230185.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))

Accuracy 92.05%, total val loss 0.08246500045061111


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))

Epoch 44 done, total train loss 0.005602258243680699.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))

Accuracy 92.21%, total val loss 0.07993371039628983


HBox(children=(IntProgress(value=0, max=12500), HTML(value='')))

Epoch 45 done, total train loss 0.005522879205830395.


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))

Accuracy 92.01%, total val loss 0.08301439136266708
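
Both training loops print the sum of per-batch mean losses divided by the number of samples, so the printed values shrink as the batch size grows and cannot be compared across runs with different batch sizes. The sketch below converts a printed value back to a mean per-batch loss. The helper name is hypothetical, and the batch size of 4 for the second run is an assumption inferred from its progress-bar totals (12500 train and 2500 val batches), not something the notebook states.

```python
def mean_batch_loss(printed_loss, n_samples, n_batches):
    """Undo the `/ len(dataset)` normalization and average over batches instead."""
    summed = printed_loss * n_samples  # sum of per-batch mean losses
    return summed / n_batches


# First run, epoch 39: 10000 val images in 625 batches (batch_size=16).
run1 = mean_batch_loss(0.01441584900021553, 10000, 625)
# Second run, epoch 45: 10000 val images in 2500 batches (assumed batch_size=4).
run2 = mean_batch_loss(0.08301439136266708, 10000, 2500)
print(round(run1, 4), round(run2, 4))  # 0.2307 0.3321
```

On this common scale the second run's val loss (about 0.33) is higher than the first run's (about 0.23), which agrees with its lower accuracy (92.01% versus 93.61%).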
