In [None]:
# Turn off Jedi-based tab completion for this IPython session.
%config Completer.use_jedi = False
# Optional: uncomment the two magics below to have edited modules reloaded
# automatically before each cell executes.
# %load_ext autoreload
# %autoreload 2

In [2]:
import os
import random
from tqdm import tqdm

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision

from src import utils
from src import pytorch_utils as ptu
from config import cfg

import warnings
warnings.filterwarnings("ignore")

In [3]:
# Notebook-session overrides of the imported config:
#   * prints   -> 'display' selects the rich `display(...)` branch when the
#                 checkpoint log is shown further down in this notebook.
#   * tqdm_bar -> forwarded to checkpoint.train(...) as its `tqdm_bar` argument.
cfg.prints = 'display'
cfg.tqdm_bar = True

In [4]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('device', device)

device cuda


In [5]:
# Per-image preprocessing, applied in list order:
#   1. utils.RotateAngle configured with cfg.angles
#   2. conversion to a tensor
#   3. normalisation with the CIFAR-10 mean/std exposed by `utils`
preprocessing_steps = [
    utils.RotateAngle(angles=cfg.angles),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=utils.cifar10_mean, std=utils.cifar10_std),
]
transforms = torchvision.transforms.Compose(preprocessing_steps)

In [7]:
# CIFAR-10 train/test splits, both using the `transforms` pipeline defined above.
# The data is expected to already exist under cfg.data_path; pass download=True
# to fetch it on first use.
train_dataset = torchvision.datasets.CIFAR10(root=cfg.data_path, train=True, transform=transforms)  # download=True ,
test_dataset = torchvision.datasets.CIFAR10(root=cfg.data_path, train=False, transform=transforms)  # download=True ,

# Training batches: reshuffled every epoch, last incomplete batch dropped.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=cfg.bs,
                                           num_workers=cfg.num_workers,
                                           shuffle=True,
                                           drop_last=True)
# Evaluation batches: fixed order, and the final partial batch is kept so the
# reported loss/score cover every test image. With drop_last=True the trailing
# len(test_dataset) % cfg.bs samples were silently excluded from evaluation.
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=cfg.bs,
                                          num_workers=cfg.num_workers,
                                          shuffle=False,
                                          drop_last=False)

In [8]:
# Optional per-session overrides of the imported config; uncomment as needed.
# cfg.load = None
# cfg.save = False
# cfg.optimizer = 'adam'

In [9]:
# Restore a saved checkpoint for cfg.version when loading is requested and the
# file exists on disk; otherwise build a fresh model, optimizer, loss and
# (optional) LR scheduler and wrap them in a new checkpoint object.
print(f'Loads {cfg.version}')

# `model_epoch` is not assigned anywhere in the visible notebook, so on a fresh
# kernel the load condition raised NameError whenever cfg.load was set. Keep a
# value defined earlier in the session if there is one, else fall back to cfg.load.
# NOTE(review): assumes cfg.load names the epoch/tag to restore — confirm.
try:
    model_epoch
except NameError:
    model_epoch = cfg.load

# Only resolve the checkpoint path when loading was requested.
checkpoint_path = None
if cfg.load is not None:
    checkpoint_path = os.path.join(cfg.models_dir,
                                   cfg.version,
                                   ptu.naming_scheme(cfg.version, epoch=model_epoch)) + '.pth'

if checkpoint_path is not None and os.path.exists(checkpoint_path):
    checkpoint = ptu.load_model(device, version=cfg.version, models_dir=cfg.models_dir, epoch=model_epoch)
    # Show the 20 log rows with the highest index, using the configured output style.
    if cfg.prints in ('display', 'print'):
        recent_log = checkpoint.log.sort_index(ascending=False).head(20)
        if cfg.prints == 'display':
            display(recent_log)
        else:
            print(recent_log)
else:
    num_classes = len(train_dataset.classes)

    if cfg.feature_extraction:
        # NOTE(review): `.tensors` is a TensorDataset attribute; the CIFAR10 dataset
        # built in this notebook does not appear to provide it, so this branch
        # presumably expects a dataset of precomputed features — confirm before enabling.
        model = nn.Linear(train_loader.dataset.tensors[0].shape[1], num_classes, bias=cfg.bias)
    else:
        # Pretrained backbone with every existing weight frozen; only the newly
        # created final `fc` layer keeps requires_grad=True.
        model = getattr(torchvision.models, cfg.backbone)(pretrained=True)
        for p in model.parameters():
            p.requires_grad = False
        model.fc = nn.Linear(model.fc.in_features, num_classes, bias=cfg.bias)
    model.to(device)

    # Hand the optimizer only the parameters that are actually trainable.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if cfg.optimizer == 'sgd':
        optimizer = torch.optim.SGD(trainable_params,
                                    lr=cfg.lr,
                                    momentum=cfg.optimizer_momentum,
                                    weight_decay=cfg.wd)
    elif cfg.optimizer == 'adam':
        optimizer = torch.optim.Adam(trainable_params,
                                     lr=cfg.lr,
                                     weight_decay=cfg.wd)
    else:
        raise NotImplementedError(f"Unsupported cfg.optimizer {cfg.optimizer!r}; expected 'sgd' or 'adam'")

    criterion = nn.CrossEntropyLoss().to(device)

    # Cosine annealing over cfg.epochs down to cfg.min_lr, only when cfg.cos is truthy.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                              T_max=cfg.epochs,
                                                              eta_min=cfg.min_lr) if cfg.cos else None

    checkpoint = utils.MyCheckpoint(version=cfg.version,
                                    model=model,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler,
                                    criterion=criterion,
                                    score=utils.accuracy_score,
                                    models_dir=cfg.models_dir,
                                    best_policy=cfg.best_policy,
                                    save=cfg.save,
                                   )

# Report parameter counts for whichever model ended up inside the checkpoint.
ptu.params(checkpoint.model)

Loads rotation_resnet34_adam_lr0.0003_bs32
Number of parameters 21289802 trainable 5130


In [None]:
# Number of epochs still to run: the configured budget minus what
# checkpoint.get_log() reports, clamped at zero so the count is never negative.
# NOTE(review): presumes get_log() returns the count of epochs already logged — confirm.
remaining_epochs = int(max(0, cfg.epochs - checkpoint.get_log()))

# The test loader is passed as the validation loader; the same iteration cap
# (cfg.max_iterations) is used for both training and validation.
checkpoint.train(
    train_loader=train_loader,
    val_loader=test_loader,
    train_epochs=remaining_epochs,
    optimizer_params=cfg.optimizer_params,
    prints=cfg.prints,
    epochs_save=cfg.epochs_save,
    epochs_evaluate_train=cfg.epochs_evaluate_train,
    epochs_evaluate_validation=cfg.epochs_evaluate_validation,
    max_iterations_train=cfg.max_iterations,
    max_iterations_val=cfg.max_iterations,
    device=device,
    tqdm_bar=cfg.tqdm_bar,
    save=cfg.save,
    save_log=cfg.save_log,
)

In [None]:
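# # Evaluate over all classes at each rotation angle (0 to 90 degrees in steps of 10).
# # NOTE: `dl` and `df` are not defined in the visible notebook; define them before uncommenting.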
# for angle in range(0, 91, 10):
#     print(f'Angle {angle}')
#     test_loader = dl.test_loader(data_dir=cfg.data_dir,
#                                  batch_size=cfg.bs,
#                                  augment=True,
#                                  angles=[angle])
#     loss, score, results = checkpoint.evaluate(loader=test_loader,
#                                                device=device,
#                                                tqdm_bar=True)
#     df = df.append({'model': 'randomRotation', 'augment': 'rotation', 'class': 'all', 'angle': angle, 'loss': loss, 'score': score},
#                    ignore_index=True)

In [None]:
# # Evaluate each class separately at each rotation angle.
# for class_name in utils.classDict.keys():
#     for angle in range(0, 91, 10):
#         print(f'Class {class_name}, Angle {angle}')
#         test_loader = dl.test_loader(data_dir=cfg.data_dir,
#                                      batch_size=cfg.bs,
#                                      augment=True,
#                                      angles=[angle],
#                                      class_name=class_name
#                                     )
#         loss, score, results = checkpoint.evaluate(loader=test_loader,
#                                                    device=device,
#                                                    tqdm_bar=True)
#         df = df.append({'model': 'randomRotation', 'augment': 'rotation', 'class': class_name, 'angle': angle, 'loss': loss, 'score': score},
#                        ignore_index=True)

In [None]:
# fig, axes = plt.subplots(figsize=(20,10),
#                          nrows=2, ncols=6)
# for (val, group), ax in zip(df[df['model'] == 'randomRotation'].groupby('class'), axes.flatten()):
#     group.plot(x='angle', y='loss', kind='bar', ax=ax, title=val, ylim=(0, 7))

In [None]:
# # df.groupby('class').plot.bar(x='angle', y='score', ylim=(0, 1), subplots=True)
# fig, axes = plt.subplots(figsize=(20,10),
#                          nrows=2, ncols=6)
# for (val, group), ax in zip(df[df['model'] == 'randomRotation'].groupby('class'), axes.flatten()):
#     group.plot(x='angle', y='score', kind='bar', ax=ax, title=val, ylim=(0, 1))