In [1]:
import sys
sys.path.insert(1, '../src/')
import torch
from tqdm.notebook import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets,transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import vector_to_parameters, parameters_to_vector
import matplotlib.pyplot as plt
import inference as infer
import time
import os
import copy
import utils
from torchsummary import summary
import models
import random
import torch.nn.functional as F
import shutil
from torchsummary import summary

In [2]:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

In [3]:
class args:
    data='mnist'
    bs = 128
    device = 'cuda:0'
    lr = 0.1
    moment = 0.9
    wd = 5e-4
    epochs=150
    nesterov = True
    base_class = 0 # airplane
    target_class = 2 # bird 
    corruption_frac = 0.05

In [4]:
train_dataset, val_dataset = utils.get_datasets(args.data)
train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, num_workers=2, pin_memory=False)
val_loader =  DataLoader(val_dataset, batch_size=args.bs, shuffle=False, num_workers=2, pin_memory=False)

In [5]:
model = models.get_model(args.data).to(args.device)
#model = nn.DataParallel(model)
parameters_bias = [p[1] for p in model.named_parameters() if 'bias' in p[0]]
parameters_scale = [p[1] for p in model.named_parameters() if 'scale' in p[0]]
parameters_others = [p[1] for p in model.named_parameters() if not ('bias' in p[0] or 'scale' in p[0])]
optimizer = optim.SGD([{'params': parameters_bias, 'lr': args.lr/10.},\
                    {'params': parameters_scale, 'lr': args.lr/10.},\
                    {'params': parameters_others}], lr=args.lr,\
                      momentum=0.9, weight_decay=args.wd,nesterov=args.nesterov)
criterion = nn.CrossEntropyLoss().to(args.device)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1, factor=0.1, verbose=True)
#summary(model, (3, 32, 32))

In [6]:
server_lr = nn.Parameter(torch.Tensor([1]), requires_grad=True)

optimizer2 = optim.SGD([server_lr], lr=1)
scheduler2 = optim.lr_scheduler.ReduceLROnPlateau(optimizer2, patience=1, factor=0.1, verbose=True)

In [12]:
optimizer2.param_groups

[{'params': [Parameter containing:
   tensor([1.], requires_grad=True)],
  'lr': 0.010000000000000002,
  'momentum': 0,
  'dampening': 0,
  'weight_decay': 0,
  'nesterov': False}]

In [7]:
#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=epochs*len(train_loader), max_lr=lr,\
             #pct_start=0.25, anneal_strategy='linear', div_factor=10.0)
#lrs, losses = infer.find_lr(model, optimizer, train_loader, args.device)
#plt.plot(lrs, losses)

In [8]:
shutil.rmtree('../logs/') 
writer = SummaryWriter('../logs/cifar10')
start_time, end_time = torch.cuda.Event(enable_timing=True),\
                        torch.cuda.Event(enable_timing=True)
start_time.record()

In [9]:
for rnd in tqdm(range(1, args.epochs+1)):
    model.train()
    train_loss, train_acc = 0.0, 0.0 
    for _, (inputs, labels) in enumerate(train_loader):
        # pass inputs to device, clear gradients
        inputs, labels = inputs.to(device=args.device, non_blocking=True),\
                        labels.to(device=args.device, non_blocking=True)
        optimizer.zero_grad()
        
        # forward-backward pass and update
        outputs = model(inputs)
        minibatch_loss = criterion(outputs, labels)
        minibatch_loss.backward()
        optimizer.step()
        #scheduler.step()
        
        # keep track of round loss/accuracy
        train_loss += minibatch_loss.item()*outputs.shape[0]
        _, pred_labels = torch.max(outputs, 1)
        train_acc += torch.sum(torch.eq(pred_labels.view(-1), labels)).item()
        
     # inference after round 
    train_loss, train_acc = train_loss/len(train_dataset), train_acc/len(train_dataset)       
    val_loss, (val_acc, val_per_class) = infer.get_loss_n_accuracy(model, criterion, val_loader, args)                                  
    #scheduler.step(val_loss)
    scheduler2.step(val_loss)
    print(server_lr)
    # log/print data
    writer.add_scalar('Validation/Loss', val_loss, rnd)
    writer.add_scalar('Validation/Accuracy', val_acc, rnd)
    writer.add_scalar('Training/Loss', train_loss, rnd)
    writer.add_scalar('Training/Accuracy', train_acc, rnd)
    print(f'|Train/Valid Loss: {train_loss:.3f} / {val_loss:.3f}|', end='--')
    print(f'|Train/Valid Acc: {train_acc:.3f} / {val_acc:.3f}|', end='\r')

HBox(children=(FloatProgress(value=0.0, max=150.0), HTML(value='')))

Parameter containing:
tensor([1.], requires_grad=True)
Parameter containing:304 / 0.061|--|Train/Valid Acc: 0.905 / 0.982|
tensor([1.], requires_grad=True)
Parameter containing:157 / 0.061|--|Train/Valid Acc: 0.954 / 0.982|
tensor([1.], requires_grad=True)
Parameter containing:134 / 0.045|--|Train/Valid Acc: 0.960 / 0.985|
tensor([1.], requires_grad=True)
Parameter containing:115 / 0.044|--|Train/Valid Acc: 0.966 / 0.987|
tensor([1.], requires_grad=True)
Parameter containing:107 / 0.041|--|Train/Valid Acc: 0.968 / 0.987|
tensor([1.], requires_grad=True)
Parameter containing:101 / 0.039|--|Train/Valid Acc: 0.970 / 0.987|
tensor([1.], requires_grad=True)
Parameter containing:100 / 0.039|--|Train/Valid Acc: 0.971 / 0.988|
tensor([1.], requires_grad=True)
Parameter containing:087 / 0.034|--|Train/Valid Acc: 0.975 / 0.988|
tensor([1.], requires_grad=True)
Parameter containing:088 / 0.033|--|Train/Valid Acc: 0.974 / 0.991|
tensor([1.], requires_grad=True)
Parameter containing:089 / 0.032|--|

KeyboardInterrupt: 

In [None]:
end_time.record()
torch.cuda.synchronize()
time_elapsed_secs = start_time.elapsed_time(end_time)/10**3
time_elapsed_mins = time_elapsed_secs/60
print(f'Training took {time_elapsed_secs:.2f} seconds / {time_elapsed_mins:.2f} minutes')
