In [1]:
import sys
sys.path.insert(1, '../src/')
import torch
from tqdm.notebook import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets,transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import vector_to_parameters, parameters_to_vector
import matplotlib.pyplot as plt
import time
import os
import copy
import utils
from torchsummary import summary
import models
import random
import torch.nn.functional as F
import shutil
import copy
from torch import autograd
from torchviz import make_dot, make_dot_from_trace
import higher

In [2]:
# Reproducibility: seed every RNG this notebook can draw from (Python, NumPy,
# PyTorch) before any data loader or model is created, so that
# Restart Kernel -> Run All starts from the same random state each time.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Turn cuDNN on and let it benchmark convolution algorithms for the input
# shapes it encounters.
# NOTE: benchmark mode may select different algorithms from run to run, so
# results can still differ slightly between runs despite the fixed seed.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

In [3]:
class args:
    """Static hyper-parameter namespace for this run.

    The class is used directly as a bag of attributes (``args.lr``,
    ``args.device``, ...) and is passed around as-is; it is not instantiated
    in this notebook.
    """

    # --- data / hardware ---
    data = 'cifar10'      # dataset key handed to utils.get_datasets / models.get_model
    device = 'cuda:0'     # device that model and batches are moved to
    bs = 512              # batch size used by every DataLoader

    # --- SGD optimiser ---
    lr = 0.1              # base learning rate (bias/scale groups use lr/10)
    moment = 0.9          # momentum
    nesterov = True       # use Nesterov momentum
    wd = 5e-4             # weight decay

    # --- schedule ---
    epochs = 100          # number of passes over the training set

In [4]:
# Fetch the train / validation splits for the configured dataset.
train_dataset, val_dataset = utils.get_datasets(args.data)

# The meta set is an independent deep copy of the training set, cut down to
# its first 512 samples and labels. train_dataset itself is left untouched,
# so those samples are still part of the training data as well.
meta_dataset = copy.deepcopy(train_dataset)
meta_dataset.data, meta_dataset.targets = (
    meta_dataset.data[:512],
    meta_dataset.targets[:512],
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Training batches: reshuffled on every pass, loaded by 2 worker processes
# into pinned host memory.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.bs,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Meta batches: shuffled, loaded in the main process (DataLoader defaults for
# workers / pinning).
meta_loader = DataLoader(
    meta_dataset,
    batch_size=args.bs,
    shuffle=True,
)

# Validation batches: fixed order, same worker / pinning settings as training.
val_loader = DataLoader(
    val_dataset,
    batch_size=args.bs,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [6]:
# Pull a single batch from the meta loader and keep it resident on the target
# device; this one batch is reused as the meta-validation data for every
# training step. With the 512-sample meta set and bs=512 configured above,
# that first batch is the entire meta set.
# `next(iter(...))` replaces `next(enumerate(...))`: the batch index was
# never used and only overwrote the `_` name.
meta_inputs, meta_labels = next(iter(meta_loader))
meta_inputs = meta_inputs.to(device=args.device, non_blocking=True)
meta_labels = meta_labels.to(device=args.device, non_blocking=True)

# The loader is not needed once its batch has been extracted.
del meta_loader

In [7]:
# Build the network for the configured dataset and move it to the device.
model = models.get_model(args.data).to(args.device)

# Bucket parameters by name: anything whose name contains 'bias', anything
# whose name contains 'scale', and everything else.
parameters_bias = [param for name, param in model.named_parameters() if 'bias' in name]
parameters_scale = [param for name, param in model.named_parameters() if 'scale' in name]
parameters_others = [
    param
    for name, param in model.named_parameters()
    if 'bias' not in name and 'scale' not in name
]

# SGD with per-group learning rates: the bias and scale groups run at a tenth
# of the base rate, the remaining parameters use the optimiser-wide default.
opt = optim.SGD(
    [
        {'params': parameters_bias, 'lr': args.lr/10.},
        {'params': parameters_scale, 'lr': args.lr/10.},
        {'params': parameters_others},
    ],
    lr=args.lr,
    momentum=args.moment,
    weight_decay=args.wd,
    nesterov=args.nesterov,
)

# Cross-entropy loss (default reduction) and a plateau-based LR scheduler that
# is stepped with the validation loss in the training loop.
criterion = nn.CrossEntropyLoss().to(args.device)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=5, verbose=True)
#summary(model, (3, 32, 32))

In [8]:
# Two CUDA events with timing enabled bracket the training run.
start_time = torch.cuda.Event(enable_timing=True)
end_time = torch.cuda.Event(enable_timing=True)

# TensorBoard writer for this run's log directory.
writer = SummaryWriter('logs/centralized')

# Mark the start of the timed region.
start_time.record()

In [9]:
# Training loop with per-example re-weighting.
# For every mini-batch:
#   1. take one differentiable optimiser step on a copy of the model, using the
#      per-example training losses scaled by a small perturbation `eps`;
#   2. differentiate the loss on the fixed meta batch w.r.t. `eps`;
#   3. turn the negated, clamped gradients into normalised example weights;
#   4. update the real model on the weighted training loss.
# After each epoch the model is evaluated on the validation loader and the
# plateau scheduler is stepped with the validation loss.
for rnd in tqdm(range(1, args.epochs+1)):
    model.train()
    train_loss, train_acc = 0.0, 0.0

    # The batch index is never used, so the loader is iterated directly.
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device=args.device, non_blocking=True),\
                        labels.to(device=args.device, non_blocking=True)
        opt.zero_grad()

        # `higher` supplies a patched copy of the model and a differentiable
        # counterpart of the optimiser, so the step taken inside this context
        # stays in the autograd graph and can be differentiated w.r.t. `eps`.
        with higher.innerloop_ctx(model, opt) as (meta_model, meta_opt):
            # 1. Update meta model on training data
            meta_train_outputs = meta_model(inputs)
            criterion.reduction = 'none'  # per-example losses are needed here
            meta_train_loss = criterion(meta_train_outputs, labels)
            # One tiny random perturbation per example (uniform in [0, 1e-6));
            # created without grad, then flagged as a leaf requiring grad.
            eps = torch.rand(meta_train_loss.size(), requires_grad=False, device=args.device).div(1e6)
            eps.requires_grad = True
            meta_train_loss = torch.sum(eps * meta_train_loss)
            meta_opt.step(meta_train_loss)

            # 2. Compute grads of eps on meta validation data
            meta_val_outputs = meta_model(meta_inputs)
            criterion.reduction = 'mean'
            meta_val_loss = criterion(meta_val_outputs, meta_labels)
            # The deprecated `only_inputs=True` argument is dropped: it is
            # ignored by autograd and True is already the default.
            eps_grads = torch.autograd.grad(meta_val_loss, eps)[0].detach()

        # 3. Compute weights for current training batch
        # Keep only examples whose eps-gradient is negative, then normalise
        # the weights to sum to 1 (all-zero weights are left as they are).
        w_tilde = torch.clamp(-eps_grads, min=0)
        l1_norm = torch.sum(w_tilde)
        if l1_norm != 0:
            w = w_tilde / l1_norm
        else:
            w = w_tilde

        # 4. Train model on weighted batch
        outputs = model(inputs)
        # NOTE(review): `criterion.reduction` is still 'none' when `criterion`
        # is handed to utils.get_loss_n_accuracy below; that helper is not
        # visible here -- confirm it handles or resets the reduction mode.
        criterion.reduction = 'none'
        minibatch_loss = criterion(outputs, labels)
        minibatch_loss = torch.sum(w * minibatch_loss)
        minibatch_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
        opt.step()

        # keep track of epoch loss/accuracy
        # (the weighted loss is scaled by the batch size here and divided by
        # the dataset size after the epoch)
        train_loss += minibatch_loss.item()*outputs.shape[0]
        _, pred_labels = torch.max(outputs, 1)
        train_acc += torch.sum(torch.eq(pred_labels.view(-1), labels)).item()

    # inference after epoch
    with torch.no_grad():
        train_loss, train_acc = train_loss/len(train_dataset), train_acc/len(train_dataset)
        val_loss, (val_acc, val_per_class) = utils.get_loss_n_accuracy(model, criterion, val_loader, args)
        scheduler.step(val_loss)
        # log/print data
        # (TensorBoard logging is currently disabled; only the console line is printed)
        #writer.add_scalar('Validation/Loss', val_loss, rnd)
        #writer.add_scalar('Validation/Accuracy', val_acc, rnd)
        #writer.add_scalar('Training/Loss', train_loss, rnd)
        #writer.add_scalar('Training/Accuracy', train_acc, rnd)
        print(f'|Train/Valid Loss: {train_loss:.3f} / {val_loss:.3f}|', end='--')
        print(f'|Train/Valid Acc: {train_acc:.3f} / {val_acc:.3f}|', end='\r')

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

|Train/Valid Loss: 2.025 / 1.873|--|Train/Valid Acc: 0.256 / 0.312|

KeyboardInterrupt: 

In [None]:
# Close the timed region and wait for all queued GPU work to finish, so that
# both events have been recorded before the elapsed time is read.
end_time.record()
torch.cuda.synchronize()

# elapsed_time() reports milliseconds; convert to seconds, then to minutes.
time_elapsed_secs = start_time.elapsed_time(end_time) / 1000
time_elapsed_mins = time_elapsed_secs / 60
print(f'Training took {time_elapsed_secs:.2f} seconds / {time_elapsed_mins:.2f} minutes')