In [1]:
# Turn off Jedi-based tab completion for this IPython session.
%config Completer.use_jedi = False
# Autoreload is currently disabled; uncomment the two magics below to have
# edited local modules re-imported automatically before each cell runs.
# %load_ext autoreload
# %autoreload 2

In [2]:
import os
import random
from tqdm import tqdm

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision

from src import utils
from src import pytorch_utils as ptu
from config import cfg

import warnings
warnings.filterwarnings("ignore")

In [3]:
# Override two settings on the shared project `cfg` object for this notebook.
# NOTE(review): presumably these control progress-bar and output behaviour
# inside the project helpers — confirm against the config module.
cfg.tqdm_bar = True
cfg.prints = 'display'

In [4]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('device', device)

device cpu


In [25]:
# CIFAR-10 test split (train=False) loaded without any transform. Later in the
# notebook only its `classes` list is read, to turn integer labels into class
# names; eval_model builds its own transformed copy of the split internally.
dataset = torchvision.datasets.CIFAR10(root=cfg.data_path, train=False)

In [5]:
# Names of the two trained model versions to compare; each is passed as
# `version=` to ptu.load_model inside eval_model.
versions = [
    'no_rotation_resnet34_adam_lr0.0003_bs32',
    'rotation_resnet34_adam_lr0.0003_bs32',
]
# Checkpoint epoch forwarded to ptu.load_model and recorded in the log.
# NOTE(review): -1 presumably selects the latest checkpoint — confirm in ptu.load_model.
epoch = -1

In [6]:
# Columns of the results log, matching the keys of the rows eval_model writes.
LOG_COLUMNS = ['model', 'model_epoch', 'img', 'label', 'pred', 'score', 'loss', 'augment', 'angle']

# Resume from previously saved results when the CSV exists; otherwise start
# from an empty log so the notebook also runs on a fresh checkout.
if os.path.exists('results_log.csv'):
    log = pd.read_csv('results_log.csv')
else:
    log = pd.DataFrame(columns=LOG_COLUMNS)

In [7]:
# Show the results log as loaded from disk.
log

Unnamed: 0,model,model_epoch,img,label,pred,score,loss,augment,angle
0,rotation_resnet34_adam_lr0.0003_bs32,-1,0,3,3,1,1.237721,rotate,0
1,rotation_resnet34_adam_lr0.0003_bs32,-1,1,8,8,1,0.417182,rotate,0
2,rotation_resnet34_adam_lr0.0003_bs32,-1,2,8,8,1,0.278200,rotate,0
3,rotation_resnet34_adam_lr0.0003_bs32,-1,3,0,3,0,2.661825,rotate,0
4,rotation_resnet34_adam_lr0.0003_bs32,-1,4,6,6,1,0.718933,rotate,0
...,...,...,...,...,...,...,...,...,...
199995,no_rotation_resnet34_adam_lr0.0003_bs32,-1,9995,8,0,0,1.072774,rotate,90
199996,no_rotation_resnet34_adam_lr0.0003_bs32,-1,9996,3,5,0,2.098405,rotate,90
199997,no_rotation_resnet34_adam_lr0.0003_bs32,-1,9997,5,5,1,1.231478,rotate,90
199998,no_rotation_resnet34_adam_lr0.0003_bs32,-1,9998,1,5,0,2.610535,rotate,90


In [8]:
def eval_model(log, versions, epoch, angles, device, batch_size=1):
    """Evaluate model versions on the CIFAR-10 test split at fixed rotation angles.

    For every (version, angle) pair the test split is rebuilt with a
    ``utils.RotateAngle`` transform for that single angle and passed through
    the loaded checkpoint one image at a time. One row per image is added to
    ``log``, and the accumulated log is written to ``results_log.csv`` after
    each angle so partial results survive an interrupted run.

    Args:
        log: DataFrame of earlier results; new rows are concatenated after it.
        versions: iterable of model version names passed to ``ptu.load_model``.
        epoch: checkpoint epoch passed to ``ptu.load_model`` and stored in
            the ``model_epoch`` column.
        angles: iterable of rotation angles; each is evaluated on its own.
        device: torch device the model is moved to and batches are run on.
        batch_size: must be 1, because each logged row reads element ``[0]``
            of the batch results.

    Returns:
        The input log extended with one row per evaluated image.

    Raises:
        NotImplementedError: if ``batch_size`` is not 1.
    """
    if batch_size != 1:
        raise NotImplementedError('eval_model only supports batch_size=1')
    for version in versions:
        checkpoint = ptu.load_model(device, version=version, models_dir=cfg.models_dir, epoch=epoch)
        checkpoint.model.eval()
        checkpoint.model.to(device)
        with torch.no_grad():
            for angle in angles:
                transforms = torchvision.transforms.Compose([
                    utils.RotateAngle(angles=(angle, )),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(utils.cifar10_mean, utils.cifar10_std),
                ])

                dataset = torchvision.datasets.CIFAR10(root=cfg.data_path, train=False, transform=transforms)

                # shuffle=False keeps the loader index `i` aligned with the
                # image's position in the test split.
                loader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=batch_size,
                                                     num_workers=cfg.num_workers,
                                                     shuffle=False,
                                                     drop_last=True)

                pbar = tqdm(loader)
                pbar.set_description(f'version={version}, angle={angle}')
                # Collect rows in a plain list and build the frame once per
                # angle: growing a DataFrame row by row copies the whole log
                # on every image, and DataFrame.append no longer exists in
                # pandas 2.x.
                rows = []
                for i, batch in enumerate(pbar):
                    loss, results, _ = checkpoint.batch_pass(device, batch)
                    rows.append({
                        'model': checkpoint.version,
                        'model_epoch': epoch,
                        'img': i,
                        'label': results['trues'][0],
                        'pred': results['preds'][0],
                        'score': int(results['preds'][0] == results['trues'][0]),
                        'loss': float(loss.data),
                        'augment': 'rotate',
                        'angle': angle,
                    })
                if rows:
                    log = pd.concat([log, pd.DataFrame(rows)], ignore_index=True)
                # Checkpoint the log to disk after every angle.
                log.to_csv('results_log.csv', index=False)
    return log

In [9]:
# Rotation angles to evaluate: 0, 10, ..., 90.
# NOTE(review): presumably degrees, as interpreted by utils.RotateAngle — confirm.
angles = tuple(range(0, 100, 10))
# Single-angle alternative for a quick run:
# angles = (0, )

In [10]:
# Evaluate every version at every angle and extend `log` with the new rows.
# NOTE(review): this rebinds `log` to its own extended version, so re-running
# the cell (or running it on a CSV that already holds these results) adds the
# same rows again.
log = eval_model(log, versions, epoch, angles, device, batch_size=1)

version=no_rotation_resnet34_adam_lr0.0003_bs32, angle=0: 100%|██████████| 10000/10000 [02:29<00:00, 66.99it/s]
version=no_rotation_resnet34_adam_lr0.0003_bs32, angle=10: 100%|██████████| 10000/10000 [02:51<00:00, 58.42it/s]
version=no_rotation_resnet34_adam_lr0.0003_bs32, angle=20: 100%|██████████| 10000/10000 [03:06<00:00, 53.72it/s]
version=no_rotation_resnet34_adam_lr0.0003_bs32, angle=30: 100%|██████████| 10000/10000 [03:21<00:00, 49.52it/s]
version=no_rotation_resnet34_adam_lr0.0003_bs32, angle=40: 100%|██████████| 10000/10000 [03:43<00:00, 44.82it/s]
version=no_rotation_resnet34_adam_lr0.0003_bs32, angle=50: 100%|██████████| 10000/10000 [03:47<00:00, 43.95it/s]
version=no_rotation_resnet34_adam_lr0.0003_bs32, angle=60: 100%|██████████| 10000/10000 [03:52<00:00, 42.94it/s]
version=no_rotation_resnet34_adam_lr0.0003_bs32, angle=70: 100%|██████████| 10000/10000 [04:08<00:00, 40.30it/s]
version=no_rotation_resnet34_adam_lr0.0003_bs32, angle=80: 100%|██████████| 10000/10000 [04:26<00

In [None]:
def f(string):
    """Map a model version name to the flag stored in the `rotations_train` column.

    Args:
        string: model version name as stored in the log's `model` column.

    Returns:
        True for 'rotation_resnet34_adam_lr0.0003_bs32' and False for
        'no_rotation_resnet34_adam_lr0.0003_bs32'.

    Raises:
        ValueError: if `string` is neither of the two known version names.
    """
    if string == 'rotation_resnet34_adam_lr0.0003_bs32':
        return True
    if string == 'no_rotation_resnet34_adam_lr0.0003_bs32':
        return False
    # A bare `raise` with no active exception only produces an unrelated
    # RuntimeError; name the offending value instead.
    raise ValueError(f'unknown model version: {string!r}')
# Flag every row with the boolean that f assigns to its model version.
rotation_flags = log['model'].apply(f)
log['rotations_train'] = rotation_flags
# Replace integer labels by the matching entry of the dataset's class list.
label_names = log['label'].apply(lambda label_idx: dataset.classes[label_idx])
log['class'] = label_names

In [33]:
# Loss and score per rotation angle, one column per `rotations_train` value
# (pivot_table aggregates with the mean by default, so `score` is accuracy).
log.pivot_table(values=['loss', 'score'], index='angle', columns='rotations_train')

Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
angle,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
0,1.532003,1.88711,0.484,0.3977
10,1.819066,1.711054,0.3889,0.4149
20,2.042534,1.687543,0.2923,0.4234
30,2.138657,1.693869,0.2665,0.4113
40,2.16346,1.695644,0.2612,0.4142
50,2.208765,1.703492,0.2467,0.4127
60,2.272597,1.717968,0.2278,0.4089
70,2.30471,1.714091,0.2212,0.4036
80,2.283663,1.72286,0.2579,0.4083
90,2.168132,1.818375,0.2849,0.3819


In [34]:
# Loss and score per class over all angles together, one column per
# `rotations_train` value (pivot_table aggregates with the mean by default).
log.pivot_table(values=['loss', 'score'], index='class', columns='rotations_train')

Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,2.123839,1.557869,0.3315,0.4765
automobile,2.051601,1.779763,0.2673,0.3446
bird,2.191143,1.97333,0.2275,0.3062
cat,2.047516,1.755063,0.2465,0.3891
deer,2.171195,1.956255,0.1941,0.3116
dog,2.31894,1.844591,0.2217,0.3344
frog,1.514352,1.60988,0.5785,0.5216
horse,2.441635,1.707918,0.2074,0.4291
ship,2.399218,1.577353,0.2113,0.5035
truck,1.674149,1.589984,0.4456,0.4603


In [39]:
# One class-by-model comparison table for each evaluated angle.
for angle in angles:
    print(f'comparison on angle {angle} only')
    angle_rows = log[log['angle'] == angle]
    display(angle_rows.pivot_table(values=['loss', 'score'], index='class', columns='rotations_train'))

comparison on angle 0 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,1.559209,1.856608,0.514,0.411
automobile,1.399124,2.005408,0.507,0.283
bird,1.590236,2.108033,0.441,0.3
cat,1.95177,1.774098,0.267,0.376
deer,1.614953,2.030758,0.457,0.333
dog,1.412443,1.802734,0.528,0.403
frog,1.435945,1.754875,0.562,0.509
horse,1.629953,1.670099,0.464,0.464
ship,1.336867,2.067672,0.569,0.456
truck,1.389529,1.800818,0.531,0.442


comparison on angle 10 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,2.005715,1.697559,0.39,0.407
automobile,1.570046,1.680444,0.41,0.351
bird,2.237139,2.091847,0.21,0.283
cat,2.0734,1.777841,0.245,0.385
deer,1.95046,1.939978,0.359,0.309
dog,1.613132,1.8125,0.431,0.351
frog,1.720813,1.70409,0.509,0.487
horse,1.790364,1.509837,0.395,0.494
ship,2.095101,1.58893,0.315,0.498
truck,1.134486,1.30751,0.625,0.584


comparison on angle 20 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,2.081824,1.519196,0.321,0.472
automobile,2.121861,1.693991,0.191,0.365
bird,2.429868,1.996378,0.138,0.279
cat,2.26504,1.777679,0.174,0.389
deer,2.14593,1.907097,0.185,0.333
dog,2.442902,1.896883,0.161,0.328
frog,1.368254,1.449292,0.639,0.597
horse,2.250328,1.721853,0.222,0.423
ship,2.053533,1.397535,0.284,0.576
truck,1.265805,1.51553,0.608,0.472


comparison on angle 30 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,1.958482,1.385009,0.35,0.54
automobile,2.111715,1.786256,0.207,0.345
bird,2.459344,1.848542,0.148,0.335
cat,2.184093,1.700698,0.212,0.417
deer,2.171518,1.903456,0.151,0.323
dog,2.776035,1.837853,0.092,0.33
frog,1.482592,1.60312,0.611,0.499
horse,2.442716,1.812054,0.198,0.367
ship,2.294589,1.435653,0.191,0.528
truck,1.505482,1.626047,0.505,0.429


comparison on angle 40 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,2.015148,1.440099,0.306,0.533
automobile,1.802202,1.651803,0.307,0.421
bird,2.227999,1.855186,0.22,0.347
cat,2.136244,1.691641,0.247,0.407
deer,2.371543,2.04356,0.084,0.261
dog,2.964338,1.744664,0.07,0.371
frog,1.536367,1.556821,0.591,0.529
horse,2.696247,1.795934,0.157,0.383
ship,2.389177,1.569476,0.148,0.461
truck,1.495331,1.607256,0.482,0.429


comparison on angle 50 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,2.248444,1.537175,0.235,0.469
automobile,1.860183,1.646432,0.317,0.416
bird,2.152221,1.776871,0.229,0.376
cat,2.004337,1.740738,0.286,0.388
deer,2.396249,2.075996,0.074,0.258
dog,2.914944,1.774316,0.071,0.363
frog,1.521686,1.5911,0.595,0.511
horse,2.694702,1.783641,0.165,0.389
ship,2.578057,1.532327,0.125,0.515
truck,1.716829,1.57632,0.37,0.442


comparison on angle 60 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,2.398006,1.428608,0.264,0.533
automobile,2.138331,1.778473,0.233,0.348
bird,2.346363,1.889392,0.163,0.332
cat,1.934139,1.759655,0.291,0.394
deer,2.41331,1.955635,0.07,0.308
dog,2.738788,1.907758,0.081,0.293
frog,1.420177,1.587112,0.631,0.519
horse,2.573678,1.732604,0.141,0.43
ship,2.83797,1.485322,0.095,0.508
truck,1.92521,1.655123,0.309,0.424


comparison on angle 70 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,2.508587,1.474447,0.253,0.479
automobile,2.360215,1.729162,0.168,0.342
bird,2.500081,2.053359,0.109,0.249
cat,1.960247,1.760698,0.238,0.39
deer,2.503818,1.877081,0.064,0.348
dog,2.534287,1.963668,0.099,0.266
frog,1.310046,1.516147,0.655,0.549
horse,2.682368,1.732774,0.137,0.431
ship,2.723774,1.450033,0.125,0.518
truck,1.963681,1.583539,0.364,0.464


comparison on angle 80 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,2.363115,1.584517,0.352,0.467
automobile,2.607536,1.767318,0.142,0.319
bird,2.300267,2.080532,0.219,0.271
cat,1.962306,1.76121,0.25,0.389
deer,2.176079,1.852744,0.237,0.34
dog,2.031201,1.915097,0.275,0.271
frog,1.722365,1.672774,0.494,0.5
horse,2.701211,1.562264,0.124,0.497
ship,3.031121,1.57788,0.083,0.494
truck,1.941429,1.454261,0.403,0.535


comparison on angle 90 only


Unnamed: 0_level_0,loss,loss,score,score
rotations_train,False,True,False,True
class,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
airplane,2.099855,1.655473,0.33,0.454
automobile,2.544794,2.058349,0.191,0.256
bird,1.667915,2.03316,0.398,0.29
cat,2.003581,1.806367,0.255,0.356
deer,1.968093,1.976246,0.26,0.303
dog,1.761334,1.790437,0.409,0.368
frog,1.62527,1.663468,0.498,0.516
horse,2.954783,1.75812,0.071,0.413
ship,2.651989,1.668699,0.178,0.481
truck,2.403705,1.773433,0.259,0.382
