In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
import sys
sys.path.append('/home/hrushikesh/')
from torch_functions import *
from vgg_batchout_1 import VGG

# Run configuration: the checkpoint to resume from and the number of epochs it
# already contains (the training cell names checkpoints model_reg_<epochs done>.pt).
args = dict(model='/content/model_reg_156.pt', start_epoch=156)
print(args)

batch_size = 128

# Per-channel statistics used to normalise CIFAR-10 inputs.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# RandomChoice applies exactly one of these to each training image.
_augmentations = (
    transforms.RandomHorizontalFlip(p=1),                    # always flip when chosen
    transforms.RandomCrop(32, padding=4),                    # pad by 4, crop back to 32x32
    transforms.RandomAffine(degrees=0, translate=(0.1, 0)),  # horizontal shift, up to 10%
    transforms.RandomAffine(degrees=0, translate=(0, 0.1)),  # vertical shift, up to 10%
)
aug = transforms.RandomChoice(_augmentations)

# Shared tail of both pipelines: PIL image -> tensor -> normalised tensor.
_to_normalized_tensor = [transforms.ToTensor(), transforms.Normalize(cifar10_mean, cifar10_std)]
train_transforms = transforms.Compose([aug] + _to_normalized_tensor)
test_transforms = transforms.Compose(list(_to_normalized_tensor))

_data_root = "/home/hrushikesh/torch/data"
cifar_train = datasets.CIFAR10(_data_root, train=True, download=True, transform=train_transforms)
cifar_test = datasets.CIFAR10(_data_root, train=False, download=True, transform=test_transforms)

# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
# To force CPU execution, set device = torch.device("cpu") here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# NOTE(review): the meaning of VGG's 0.2 argument is defined in vgg_batchout_1
# (presumably a dropout/batchout rate) — confirm there.
model = VGG(0.2)
model.to(device)

opt = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)

# k must start at 0 for correct plotting: the training loop passes it to
# plot_fig and switches it to 1 after the first epoch.
k = 0

if args['model']:
    # map_location places the checkpoint tensors directly on the selected device,
    # so a checkpoint written on a GPU also loads on a CPU-only machine.
    model.load_state_dict(torch.load(args['model'], map_location=device))
    # NOTE(review): only the model weights are restored. The SGD momentum state in
    # `opt` starts fresh, so a resumed run is not an exact continuation.
    print("Model loaded.. here we go")




{'model': '/content/model_reg_156.pt', 'start_epoch': 156}
Files already downloaded and verified
Files already downloaded and verified
cuda
Model loaded.. here we go


In [21]:
print('Started Training')

# Resume immediately after the checkpointed epoch and run NUM_EPOCHS more.
# A checkpoint is written every SAVE_EVERY epochs AND after the final epoch, so
# a completed run never loses trailing epochs that fall between save points.
NUM_EPOCHS = 35
SAVE_EVERY = 3
first_epoch = args['start_epoch'] + 1
last_epoch = args['start_epoch'] + NUM_EPOCHS

for i in range(first_epoch, last_epoch + 1):

    # Training pass (optimizer supplied, reg=False), then a test-set pass without an optimizer.
    # NOTE(review): assumes epoch_perturbation switches model.train()/model.eval()
    # itself — confirm in torch_functions.
    acc, loss = epoch_perturbation(train_loader, model, device, opt, reg=False)
    test_acc, test_loss = epoch_perturbation(test_loader, model, device)

    plot_fig(acc, loss, test_acc, test_loss, startAt=args['start_epoch'],
             path='./progress_reg.png', jsonPath='./progress_reg.json', k=k)

    # Columns: train acc, test acc, train loss, test loss.
    print("Epoch number:{}".format(i), *("{:.3f}".format(j) for j in (acc, test_acc, loss, test_loss)), sep="\t")

    # The checkpoint name is the number of completed epochs, e.g. model_reg_5.pt
    # after 5 epochs. To resume from it, set args['start_epoch'] = 5 (not 6).
    if i % SAVE_EVERY == 0 or i == last_epoch:
        torch.save(model.state_dict(), "model_reg_{}.pt".format(i))
        print("Model saved")

    # From the second epoch on, pass k=1 so plot_fig does not rewrite the JSON file.
    k = 1



Started Training
Epoch number:157	0.968	0.866	0.108	0.519
Epoch number:158	0.978	0.871	0.075	0.512
Epoch number:159	0.981	0.870	0.064	0.517
Model saved
Epoch number:160	0.982	0.870	0.058	0.504
Epoch number:161	0.985	0.870	0.051	0.502
Epoch number:162	0.987	0.872	0.045	0.516
Model saved
Epoch number:163	0.988	0.871	0.043	0.501
Epoch number:164	0.989	0.871	0.037	0.517
Epoch number:165	0.990	0.873	0.036	0.506
Model saved
Epoch number:166	0.989	0.871	0.036	0.500
Epoch number:167	0.991	0.869	0.033	0.515
Epoch number:168	0.991	0.874	0.032	0.513
Model saved
Epoch number:169	0.992	0.875	0.027	0.517
Epoch number:170	0.992	0.871	0.028	0.538
Epoch number:171	0.992	0.876	0.027	0.538
Model saved
Epoch number:172	0.993	0.875	0.026	0.524
Epoch number:173	0.992	0.869	0.026	0.524
Epoch number:174	0.994	0.879	0.023	0.527
Model saved
Epoch number:175	0.994	0.875	0.023	0.538
Epoch number:176	0.994	0.872	0.022	0.551
Epoch number:177	0.994	0.868	0.022	0.592
Model saved
Epoch number:178	0.995	0.875	0.020	0.5

KeyboardInterrupt: ignored