In [None]:
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.models as models
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from livelossplot import PlotLosses
from train_model_it import train_model
%matplotlib inline

In [None]:
# Augmentation pipelines. Only 'train' and 'val' are used below; 'aug1' / 'aug2'
# are kept for experimentation.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor()
    ]),
    'aug1': transforms.Compose([
        transforms.RandomHorizontalFlip(1),  # p=1: always flip
        transforms.ToTensor()
    ]),
    'aug2': transforms.Compose([
        transforms.RandomRotation(15),
        transforms.ToTensor()
    ]),
    'val': transforms.Compose([
        transforms.ToTensor()
    ]),
}

DATA_ROOT = '/disk/LHC/images/71_60to70'

# The noisy copy of the training set gets the same (random flip) augmentation as the clean one.
train_datasets  = datasets.ImageFolder(os.path.join(DATA_ROOT, 'train'),       data_transforms['train'])
augmt1_datasets = datasets.ImageFolder(os.path.join(DATA_ROOT, 'train_noise'), data_transforms['train'])
valid_datasets  = datasets.ImageFolder(os.path.join(DATA_ROOT, 'val'),         data_transforms['val'])

# ImageFolder derives label indices independently for each root (from its sorted
# sub-directory names). If the three roots do not contain exactly the same class
# folders, the same integer label would mean different classes across the
# concatenated training set and the validation set.
assert (
    train_datasets.class_to_idx == augmt1_datasets.class_to_idx == valid_datasets.class_to_idx
), 'train / train_noise / val class folders differ; labels would be misaligned'

# Training set = clean images followed by their noisy counterparts.
concat = torch.utils.data.ConcatDataset([
    train_datasets,
    augmt1_datasets,
])

batch_size = 1000      # training batch size (also passed to train_model)
VAL_BATCH_SIZE = 100
# 100 workers per loader was requested originally; cap it at the CPU count so smaller
# hosts are not oversubscribed (DataLoader warns when workers exceed available CPUs).
NUM_WORKERS = min(100, os.cpu_count() or 1)
# Pinned host memory speeds up host->GPU copies; it is pointless without CUDA.
PIN_MEMORY = torch.cuda.is_available()

dataloaders = {
    'train' : torch.utils.data.DataLoader(concat, batch_size=batch_size, shuffle=True,
                                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY),
    'val'   : torch.utils.data.DataLoader(valid_datasets, batch_size=VAL_BATCH_SIZE, shuffle=False,
                                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY),
}

dataset_sizes = {
    'train' : len(concat),
    'val'   : len(valid_datasets)
}

In [None]:
# Build a ResNet-18 adapted to small input images, then initialise it from a previously
# trained checkpoint file. models.resnet18() itself is constructed WITHOUT pretrained
# torchvision weights; all pretrained values come from PRETRAINED_WEIGHTS.
PRETRAINED_WEIGHTS = '/disk/LHC/models/nips_resnet18_64_noise_pretrained_256_with_val_noise_weight.pt'
# NOTE(review): 200 outputs is inherited from the tiny-imagenet head; confirm it is >= the
# number of class folders in the ImageFolder datasets (CrossEntropyLoss fails otherwise).
NUM_CLASSES = 200
MAX_GPUS = 8

model_ft = model_ft_base = models.resnet18()
# Small-input stem: 3x3 stride-1 first conv and no max-pool keep spatial resolution
# (presumably for the small tiny-imagenet-style images the checkpoint was trained on).
model_ft.avgpool = nn.AdaptiveAvgPool2d(1)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, NUM_CLASSES)
model_ft.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
model_ft.maxpool = nn.Identity()
del model_ft_base

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = model_ft.to(device)
# Multi-GPU: use up to MAX_GPUS of the visible devices. Hard-coding ids 0-7 raises on
# hosts with fewer GPUs. Without CUDA, DataParallel only wraps the module, but its
# state-dict keys still carry the 'module.' prefix.
gpu_ids = list(range(min(MAX_GPUS, torch.cuda.device_count())))
model_ft = torch.nn.DataParallel(model_ft, device_ids=gpu_ids or None)

# map_location lets a checkpoint saved on GPU load on whatever device we are using.
pretrained_dict = torch.load(PRETRAINED_WEIGHTS, map_location=device)
model_ft_dict = model_ft.state_dict()

# Bring checkpoint keys into the DataParallel ('module.'-prefixed) form in case the
# file was saved from an unwrapped model.
pretrained_dict = {
    (k if k.startswith('module.') else 'module.' + k): v
    for k, v in pretrained_dict.items()
}
# Keep tensors whose name AND shape match this model. Skip the first conv: conv1 was
# replaced above and is intentionally left at its fresh initialisation.
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_ft_dict
    and not k.startswith('module.conv1.')
    and v.shape == model_ft_dict[k].shape
}
if not pretrained_dict:
    raise RuntimeError(f'No tensors in {PRETRAINED_WEIGHTS} match the model; check checkpoint keys.')
not_loaded = sorted(k for k in model_ft_dict
                    if k not in pretrained_dict and not k.startswith('module.conv1.'))
if not_loaded:
    print('Not initialised from checkpoint:', not_loaded)

# Load pretrained weights for every layer except conv1.
model_ft_dict.update(pretrained_dict)
model_ft.load_state_dict(model_ft_dict)


# Loss function
criterion = nn.CrossEntropyLoss()
# All parameters (including the freshly initialised conv1) are optimised.
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001)

# Multiply the LR by 0.1 on every scheduler step (step_size=1).
# NOTE(review): if train_model steps the scheduler once per epoch, the 5-epoch run in the
# training cell goes from 1e-3 down to 1e-7 - confirm this aggressive decay is intended.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=1, gamma=0.1)

In [None]:
# Train: hand the model, both loaders, loss, optimizer and LR scheduler to the
# project-local train_model (from train_model_it) for 5 epochs. Its per-epoch
# behaviour is defined there; it returns the model that is saved in the next cell.
model_ft = train_model(
    model_ft,
    dataloaders,
    dataset_sizes,
    criterion,
    optimizer_ft,
    exp_lr_scheduler,
    batch_size=batch_size,
    num_epochs=5,
)

In [None]:
# Persist the fine-tuned model: the whole pickled module plus its state dict.
# Assuming train_model returns the DataParallel wrapper it was given, state-dict keys
# carry the 'module.' prefix - the same form the weight-loading cell expects.
# The 'bt' tag is derived from batch_size; the previously hard-coded 'bt2000'
# contradicted batch_size = 1000.
os.makedirs("./models", exist_ok=True)  # avoid losing a finished run to a missing directory
run_tag = f"nips_resnet18_71_noise_aug_pretrained_64_noise_bt{batch_size}"
torch.save(model_ft, f"./models/{run_tag}.pt")
torch.save(model_ft.state_dict(), f"./models/{run_tag}_weight.pt")