In [None]:

# Gradient statistics over the full dataset:
# for each weight, compute the mean and standard deviation.

import os
import sys 
# Work from the ASL checkout in $HOME and make both the repository root and its
# `src` folder importable, so the project-local imports below resolve.
asl_root = os.path.join(os.getenv('HOME'), 'ASL')
os.chdir(asl_root)
sys.path.insert(0, asl_root)
sys.path.append(asl_root + '/src')

from torchvision import transforms as tf
import numpy as np
import pickle
import torch
import torch.nn.functional as F
from utils_asl import file_path, load_yaml
from models_asl import FastSCNN

In [None]:
# SETUP MODEL
# Reads the evaluation config and the workstation-specific environment config,
# builds the network and resolves the checkpoint path.

eval_cfg_path = "cfg/eval/eval.yml"
# The environment config is selected by the ENV_WORKSTATION_NAME variable;
# a KeyError here means that variable is not exported in the kernel's environment.
env_cfg_path = os.path.join('cfg/env', os.environ['ENV_WORKSTATION_NAME'] + '.yml')
env_cfg = load_yaml(env_cfg_path)
eval_cfg = load_yaml(eval_cfg_path)

# Prefer the GPU but fall back to the CPU so the notebook still runs on hosts
# without CUDA instead of failing at `model.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = FastSCNN(**eval_cfg['model']['cfg'])

# Checkpoint location: `checkpoint_load` from the eval config, joined onto the
# environment's `base` directory.
p = os.path.join(env_cfg['base'], eval_cfg['checkpoint_load'])
print(p)

def load(model, p):
    """Restore the weights of ``model`` from the checkpoint file ``p``.

    The checkpoint must contain a ``'state_dict'`` mapping. Only entries whose
    key *starts with* ``'model.'`` are used, and that prefix is stripped before
    the weights are loaded with ``strict=True``. Keys that merely contain
    ``'model.'`` somewhere in the middle are ignored; slicing a fixed six
    characters off such keys would produce garbage names.

    Args:
        model: module exposing ``load_state_dict``.
        p: path to the checkpoint file.

    Returns:
        The same ``model`` object, with the restored weights.

    Raises:
        Exception: if ``p`` is not an existing file.
        RuntimeError: from ``load_state_dict`` if the filtered keys do not match
            the model exactly.
    """
    if not os.path.isfile(p):
        raise Exception('Checkpoint not a file')

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    # map_location='cpu' keeps loading independent of the device the checkpoint
    # was saved from; load_state_dict copies the values onto the model's device.
    checkpoint = torch.load(p, map_location='cpu')

    prefix = 'model.'
    new_statedict = {
        key[len(prefix):]: value
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(prefix)
    }
    res = model.load_state_dict(new_statedict, strict=True)
    print('Restoring weights: ' + str(res))
    return model
# Restore the checkpoint at `p` into the freshly constructed network.
model = load(model,p)

# Move the model to `device`. Left as the cell's last expression, so the notebook
# also displays whatever `.to()` returns (presumably the module itself — TODO confirm).
model.to(device)

In [None]:
from datasets_asl import get_dataset
import torch
from torchvision import transforms as tf

# SETUP DATALOADER
# Dataset is built from the `dataset` section of the eval config, without any
# output transform.
dataset_test = get_dataset(**eval_cfg['dataset'], env=env_cfg, output_trafo=None)

# Deterministic order, loaded in the main process, two samples per batch;
# an incomplete trailing batch is dropped.
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)

# Image paths as exposed by the dataset (presumably indexed by global sample
# index — TODO confirm against the dataset implementation).
globale_idx_to_image_path = dataset_test.image_pths

from visu import Visualizer

# Visualizer writing below $HOME/tmp, with no logger attached and storing disabled.
visu = Visualizer(
    os.getenv('HOME')+'/tmp',
    logger=None,
    epoch=0,
    store=False,
    num_classes=41,
)

In [None]:
from task import TaskCreator
from datasets_asl import get_dataloader_train, eval_lists_into_dataloaders

def get_mutiple_dataloaders(exp, env, train=True):
    """Build dataloaders for the tasks produced by ``TaskCreator``.

    Args:
        exp: experiment configuration; ``exp['task_generator']`` and
            ``exp['dataset']['output_size']`` parameterize the ``TaskCreator``.
        env: environment configuration forwarded to the dataloader factories.
        train: if True, collect one training dataloader per task. If False,
            return the evaluation dataloaders of the first task yielded.

    Returns:
        With ``train=True``: a list with one training dataloader per task.
        With ``train=False``: the result of ``eval_lists_into_dataloaders`` for
        the first task (the original comment describes this as all validation
        dataloaders), or an empty list if no task is produced.
    """
    tc = TaskCreator(**exp['task_generator'], output_size=exp['dataset']['output_size'])
    ret_list = []
    for task, eval_lists in tc:
        if not train:
            # RETURNS ALL VALIDATION DATALOADERS
            # Return before touching the training data: building a training
            # dataloader only to discard it is wasted work.
            return eval_lists_into_dataloaders(eval_lists, env=env, exp=exp)

        # The second return value (buffer dataloader) is not needed here.
        dataloader_train, _ = get_dataloader_train(
            d_train=task.dataset_train_cfg, env=env, exp=exp)
        ret_list.append(dataloader_train)
    return ret_list
        
# train=False: get_mutiple_dataloaders returns the evaluation dataloaders rather
# than the list of per-task training dataloaders.
dataloader_list = get_mutiple_dataloaders(exp = eval_cfg , env = env_cfg, train=False )

In [None]:
# Checkpoints of one training run, one per task (task0 .. task3).
# NOTE(review): absolute, machine-specific paths; adjust when running elsewhere.
model_paths = [
    '/media/scratch1/jonfrey/models/cluster/2021-03-19T19:20:20_check_val_every_2_random/task0-epoch=29--step=016259.ckpt',
    '/media/scratch1/jonfrey/models/cluster/2021-03-19T19:20:20_check_val_every_2_random/task1-epoch=63--step=030846.ckpt',
    '/media/scratch1/jonfrey/models/cluster/2021-03-19T19:20:20_check_val_every_2_random/task2-epoch=95--step=045503.ckpt',
    '/media/scratch1/jonfrey/models/cluster/2021-03-19T19:20:20_check_val_every_2_random/task3-epoch=135--step=060290.ckpt',
]
print(model_paths)

# Show which dataset backs each of the evaluation dataloaders.
for f in dataloader_list:
    print(f.dataset)

In [None]:
# TESTED AND WORKS
import copy

def get_grad( named_params ):
    """Flatten all available parameter gradients into a single 1-D tensor.

    Args:
        named_params: iterable of ``(name, parameter)`` pairs, e.g.
            ``model.named_parameters()``. Parameters whose ``.grad`` is ``None``
            are skipped.

    Returns:
        A new 1-D tensor with the gradients concatenated in iteration order.
        It does not share memory with the ``.grad`` tensors, so later
        ``zero_grad`` calls do not affect it. If no parameter has a gradient,
        an empty tensor is returned.
    """
    # reshape (not view) so gradients stored with non-contiguous strides are
    # flattened in logical order instead of raising.
    flat_grads = [
        p.grad.detach().reshape(-1)
        for _, p in named_params
        if p.grad is not None
    ]
    if not flat_grads:
        # torch.cat rejects an empty list; report "no gradients" as an empty tensor.
        return torch.empty(0)
    return torch.cat( flat_grads )

def write_back_grad( grad, named_params ):
    """Copy a flat gradient vector back into the parameters' ``.grad`` tensors.

    Inverse of ``get_grad``: consecutive slices of ``grad`` are written, in
    iteration order, into every parameter whose ``.grad`` is not ``None``.

    Args:
        grad: tensor holding the new gradient values; it is flattened before use.
        named_params: iterable of ``(name, parameter)`` pairs, e.g.
            ``model.named_parameters()``.

    Raises:
        ValueError: if ``grad`` does not have exactly as many elements as the
            existing ``.grad`` tensors together. The check runs before anything
            is written, so a mismatch never leaves gradients half-updated.
    """
    targets = [p.grad for _, p in named_params if p.grad is not None]
    flat = grad.reshape(-1)

    expected = sum(t.numel() for t in targets)
    if flat.numel() != expected:
        raise ValueError(
            'grad has ' + str(flat.numel()) + ' elements but the parameters hold '
            + str(expected) + ' gradient elements')

    count = 0
    with torch.no_grad():
        for target in targets:
            c = target.numel()  # works for non-contiguous grads too, unlike view(-1)
            target.copy_(flat[count:count + c].reshape(target.shape))
            count += c

def project(g: torch.Tensor, g_ref: torch.Tensor) -> torch.Tensor:
    """Remove from ``g`` its component along ``g_ref``.

    Computes ``g - (<g, g_ref> / <g_ref, g_ref>) * g_ref``, so the result is
    orthogonal to ``g_ref``.

    Args:
        g: 1-D gradient vector to project.
        g_ref: 1-D reference vector of the same length.

    Returns:
        A new tensor with the projected gradient. If ``g_ref`` has zero norm
        there is no direction to remove, so a copy of ``g`` is returned instead
        of the NaNs a 0/0 division would produce.
    """
    ref_sq_norm = torch.dot(g_ref, g_ref)
    if ref_sq_norm == 0:
        return g.clone()

    corr = torch.dot(g, g_ref) / ref_sq_norm
    return g - corr * g_ref


In [None]:
# When True, `trafo` is applied to every image batch before the forward pass.
augmentation = True

# Colour jitter followed by per-channel normalization.
trafo = tf.Compose(
    [
        tf.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
        tf.Normalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
    ]
)

# Start from the first checkpoint in `model_paths`.
model = load(model, model_paths[0])


In [None]:
# Gradient projection on the first two batches:
#   batch 0 -> reference gradient g_ref
#   batch 1 -> gradient g, which is projected to be orthogonal to g_ref
# The projected gradient g_tilde is finally written back into the model's .grad.
# NOTE(review): the model's train/eval mode is not set here — confirm which is intended.
flat_grads = []

# zip(range(2), ...) stops after two batches without fetching (and augmenting) a third one.
for j, batch in zip(range(2), dataloader_test):
    # START EVALUATION
    images = batch[0].to(device)
    target = batch[1].to(device)
    # Remaining batch entries are unpacked for inspection only; they are not used below.
    ori_img = batch[2]
    replayed = batch[3]
    BS = images.shape[0]
    global_idx = batch[4]

    if augmentation:
        images = trafo(images)

    # Clear gradients *before* the backward pass so that each flattened gradient
    # belongs to exactly one batch, and so that the .grad tensors of the last
    # batch are still populated when g_tilde is written back after the loop.
    model.zero_grad()
    ret = model(images)
    loss = F.cross_entropy(ret[0], target, ignore_index=-1)
    loss.backward()
    flat_grads.append(get_grad(model.named_parameters()))

if len(flat_grads) < 2:
    raise RuntimeError('dataloader_test must provide at least two batches')

g_ref, g = flat_grads
g_tilde = project( g, g_ref)
write_back_grad( g_tilde, model.named_parameters() )

In [None]:
class Test():
    """Scratch class for checking how `hasattr` and `locals()` behave inside a method."""
    def __init__(self):
        """Print "init" and set the instance attribute `b` to 2."""
        super().__init__()
        
        print("init")
        self.b = 2
    def test(self):
        """Print the results of a few attribute / local-variable lookups; returns None."""
        # `b` is set in __init__, so this prints True for instances created normally.
        print ( hasattr(self, 'b') )
        # No attribute `a` is assigned in this class; the `a` below is only a local variable.
        print ( hasattr(self, 'a') )
        a = self.b
        # At this point the local namespace holds `self` and `a`.
        print (locals())
        # There is no local called "a_variable" in this method, so this prints False.
        print( "a_variable" in locals() )
# Instantiate (prints "init") and run the checks once.
tei = Test()
tei.test()


In [None]:
# NOTE(review): `gradient_buffer` is not assigned in any cell visible here; unless it is
# defined elsewhere, this line raises NameError on a fresh kernel — confirm or remove.
gradient_buffer