In [1]:
import argparse
import gc
import json
import os
import pickle
import pprint
import re
import sys
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.optimize
import torch
import torch.nn as nn
from mpl_toolkits.mplot3d import Axes3D
from numpy.linalg import svd
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances
from tensorboardX import SummaryWriter
from tqdm import tqdm

sys.path = ['..'] + sys.path
from algorithm_trainer.models import gated_conv_net_original, resnet, resnet_2, resnet_12, wide_resnet
from algorithm_trainer.algorithm_trainer import Generic_adaptation_trainer, Classical_algorithm_trainer
from algorithm_trainer.algorithms.algorithm import SVM, ProtoNet, Finetune, ProtoCosineNet, ProtoCosineNetCorrected2
from algorithm_trainer.utils import accuracy
from data_layer.dataset_managers import MetaDataManager, ClassicalDataManager
from analysis.objectives import var_reduction_disc, var_reduction_disc_perp, var_reduction

%matplotlib inline

### load tl and ml checkpoints

In [2]:
def load_model(checkpoint, lambd):
    """Build a ResNet-12, load weights from `checkpoint`, and prepare it for evaluation.

    Only checkpoint entries whose (de-prefixed) key also exists in the freshly
    built model's state dict are loaded; every other model entry keeps its
    initial value and is listed under "missed" in the printed report.

    Args:
        checkpoint: path to a file readable by `torch.load`. It may hold the
            state dict directly or nested under the key 'model'.
        lambd: value assigned to the `lambd` attribute of the model's `fc` layer.

    Returns:
        The model wrapped in `torch.nn.DataParallel` over all visible CUDA
        devices, moved to CUDA and switched to eval mode.
    """
    model_dc = resnet_12.resnet12(
        avg_pool=True, drop_rate=0.1, dropblock_size=5,
        classifier_type='avg-classifier', num_classes=64)

    print(f"loading from {checkpoint}")
    model_dict = model_dc.state_dict()
    # Deserialize onto the CPU so loading does not depend on the GPU layout the
    # checkpoint was saved from; the model is moved to CUDA further below.
    chkpt_state_dict = torch.load(checkpoint, map_location='cpu')
    if 'model' in chkpt_state_dict:
        chkpt_state_dict = chkpt_state_dict['model']
    # remove "module." from key, possibly present as it was dumped by data-parallel
    # (plain substring replacement: no regex, hence no escape-sequence pitfalls)
    chkpt_state_dict = {
        key.replace('module.', ''): value
        for key, value in chkpt_state_dict.items()
    }
    # keep only the entries the model actually has a slot for
    chkpt_state_dict = {k: v for k, v in chkpt_state_dict.items() if k in model_dict}
    model_dict.update(chkpt_state_dict)
    updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
    print(f"Updated {len(updated_keys)} keys using chkpt")
    print("Following keys updated :", "\n".join(sorted(updated_keys)))
    missed_keys = set(model_dict).difference(set(chkpt_state_dict))
    print(f"Missed {len(missed_keys)} keys")
    print("Following keys missed :", "\n".join(sorted(missed_keys)))
    model_dc.load_state_dict(model_dict)

    model_dc = torch.nn.DataParallel(
        model_dc, device_ids=list(range(torch.cuda.device_count())))
    model_dc.cuda()
    model_dc.eval()

    # set lambd
    model_dc.module.fc.lambd = lambd
    return model_dc

In [5]:
# Path of the "ml" checkpoint to analyse; the "tl" counterpart stays disabled.
checkpoint_ml = '../train_dir/minmax_MI_r12_n64_s3_q128/classical_resnet_120.pt'
# checkpoint_tl = '../train_dir/classical_miniimagenet_r12_classifier_bs128_parametric_tfl_gradanalysis/classical_resnet_399.pt'

# Load the "ml" model with lambd = 0.
model_ml = load_model(checkpoint=checkpoint_ml, lambd=0.)
# model_tl = load_model(checkpoint_tl, 1.)

loading from ../train_dir/minmax_MI_r12_n64_s3_q128/classical_resnet_120.pt
Updated 98 keys using chkpt
Following keys updated : fc.scale_factor
layer1.0.bn1.bias
layer1.0.bn1.num_batches_tracked
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.weight
layer1.0.bn2.bias
layer1.0.bn2.num_batches_tracked
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.weight
layer1.0.bn3.bias
layer1.0.bn3.num_batches_tracked
layer1.0.bn3.running_mean
layer1.0.bn3.running_var
layer1.0.bn3.weight
layer1.0.conv1.weight
layer1.0.conv2.weight
layer1.0.conv3.weight
layer1.0.downsample.0.weight
layer1.0.downsample.1.bias
layer1.0.downsample.1.num_batches_tracked
layer1.0.downsample.1.running_mean
layer1.0.downsample.1.running_var
layer1.0.downsample.1.weight
layer2.0.bn1.bias
layer2.0.bn1.num_batches_tracked
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.weight
layer2.0.bn2.bias
layer2.0.bn2.num_batches_tracked
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
lay

### interpolate between tl and ml checkpoints (note: the cells shown below compute sharpness of the loaded checkpoint)

In [6]:
def func(x, n_param, model):
    """Evaluate the negated mean training cross-entropy and its gradient at `x`.

    The flat vector `x` is written into the trainable parameters of `model`
    (in `model.parameters()` order), overwriting their current values. One
    pass is then made over the augmented training loader. For every training
    batch, a batch from the non-augmented loader is first fed through the
    model and handed to `model.module.fc.update_L`; after that the negated
    cross-entropy of the training batch is back-propagated. The return value
    is the `(f, g)` pair that `scipy.optimize.fmin_l_bfgs_b` expects when no
    separate `fprime` is supplied.

    Note: the data loaders are rebuilt on every call.

    Args:
        x: 1-D array-like of length `n_param` holding the parameter values.
        n_param: total number of trainable scalars in `model`.
        model: `torch.nn.DataParallel`-wrapped network whose wrapped module
            exposes an `fc` head with an `update_L` method.

    Returns:
        (total_loss, gradient): the batch-averaged negated loss as a float and
        the batch-averaged gradient as a float64 array of length `n_param`.

    Raises:
        ValueError: if `x` or `n_param` does not match the model's trainable
            parameter count.
        RuntimeError: if the training loader yields no batches.
    """
    print("Accessing function value")

    # dataloaders
    image_size = 84
    dataset_path = '../data/filelists/miniImagenet'
    train_file = os.path.join(dataset_path, 'base.json')
    classical_train_datamgr = ClassicalDataManager(image_size, batch_size=128)
    classical_train_loader = classical_train_datamgr.get_data_loader(train_file, aug=True)
    aux_datamgr = ClassicalDataManager(image_size, batch_size=128)
    aux_loader = aux_datamgr.get_data_loader(train_file, aug=False)
    iterator = tqdm(enumerate(classical_train_loader, start=1),
                    leave=False, file=sys.stdout, position=0)
    aux_iterator = iter(aux_loader)

    # validate sizes before touching the model so a bad call cannot leave it
    # half-overwritten
    x = np.asarray(x).reshape(-1)
    trainable = [param for param in model.parameters() if param.requires_grad]
    total = sum(param.numel() for param in trainable)
    if x.size != total or n_param != total:
        raise ValueError(
            f"expected {total} parameter values, got len(x)={x.size}, n_param={n_param}")

    # write x into the model, keeping each parameter on its current device
    curr = 0
    for param in trainable:
        param_len = param.numel()
        param.data = torch.tensor(
            x[curr:curr + param_len].astype(np.float32)
        ).to(param.device).reshape(param.shape)
        curr += param_len

    # compute f, grad
    loss_func = torch.nn.CrossEntropyLoss()
    gradient = np.zeros(n_param)
    n_batches = 0
    total_loss = 0.
    for _, (batch_x, batch_y) in iterator:

        # update L; restart the auxiliary loader if it runs out first
        try:
            aux_batch_x, aux_batch_y = next(aux_iterator)
        except StopIteration:
            aux_iterator = iter(aux_loader)
            aux_batch_x, aux_batch_y = next(aux_iterator)
        aux_batch_x = aux_batch_x.cuda()
        aux_batch_y = aux_batch_y.cuda()
        aux_features_x = model(aux_batch_x, features_only=True)
        model.module.fc.update_L(aux_features_x, aux_batch_y)

        # loss gradient
        batch_x = batch_x.cuda()
        batch_y = batch_y.cuda()
        features_x = model(batch_x, features_only=True)
        output_x = model.module.fc(features_x)
        # negated so that a minimizer of this function maximizes the loss
        loss = -loss_func(output_x, batch_y)
        total_loss += loss.item()
        model.zero_grad()
        loss.backward()

        # statistics
        curr = 0
        for param in trainable:
            param_len = param.numel()
            # a parameter that took no part in the loss has no gradient and
            # contributes zero
            if param.grad is not None:
                gradient[curr:curr + param_len] += (
                    param.grad.detach().flatten().cpu().numpy())
            curr += param_len
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("training loader produced no batches")
    gradient /= n_batches
    total_loss /= n_batches
    print(f"loss: {total_loss}")
    return total_loss, gradient

In [7]:
def compute_sharpness(model, eps=0.0005):
    """Estimate how much the training loss can grow inside a small box around `model`.

    Let x0 be the flattened trainable parameters. L-BFGS-B minimizes `func`
    (the negated loss) subject to the per-coordinate box
    ``x0 - eps * (|x0| + 1) <= x <= x0 + eps * (|x0| + 1)``, and the result is
    reported as ``(max_loss - loss(x0)) * 100 / (1 + loss(x0))``.
    NOTE(review): this looks like the epsilon-sharpness of Keskar et al.
    (2017) — confirm.

    `func` overwrites the model's parameters on every evaluation, so the
    original values are written back before returning, even on error or
    interruption.

    Args:
        model: network accepted by `func`.
        eps: relative half-width of the search box.

    Returns:
        The sharpness value as a float (it is also printed).

    Raises:
        ValueError: if `model` has no trainable parameters.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]
    if not trainable:
        raise ValueError("model has no trainable parameters")
    n_param = sum(param.numel() for param in trainable)
    print(f"no of parameters: {n_param}")

    # x0: current parameters flattened in model.parameters() order
    x0 = np.concatenate(
        [param.detach().flatten().cpu().numpy() for param in trainable]
    ).astype(np.float64)

    # bounds: one (low, high) row per coordinate
    radius = eps * (np.abs(x0) + 1)
    bounds = np.stack([x0 - radius, x0 + radius], axis=1)

    try:
        # f_x0 (func returns the negated loss)
        f_x0, _ = func(x0, n_param, model)
        f_x0 = -f_x0

        # L-BFGS-B optimizer
        x_optimal, y_opt, opt_info = scipy.optimize.fmin_l_bfgs_b(
            func=func, x0=x0, args=(n_param, model), bounds=bounds)
    finally:
        # put the original weights back; otherwise the model stays at the
        # last point the optimizer evaluated
        curr = 0
        with torch.no_grad():
            for param in trainable:
                param_len = param.numel()
                param.copy_(
                    torch.from_numpy(x0[curr:curr + param_len]).reshape(param.shape))
                curr += param_len

    sharpness = (-y_opt - f_x0) * 100. / (1 + f_x0)
    print(f"sharpness : {sharpness}")
    print(x_optimal, y_opt, opt_info)
    return sharpness

In [9]:
# Sharpness of the "ml" checkpoint with the default eps.
compute_sharpness(model=model_ml)

no of parameters: 12465283
Accessing function value
loss: -0.012343934960663319
Accessing function value
loss: -0.012672632249693077
Accessing function value
loss: -1.1614664324124655
Accessing function value
loss: -6.457420837084452
Accessing function value
loss: -6.537908771832784
Accessing function value
loss: -6.542549200057984
Accessing function value
loss: -6.678083092371622
Accessing function value
loss: -6.64824382464091 
Accessing function value
loss: -6.667147625287374
Accessing function value
loss: -6.604393127759297
Accessing function value
loss: -6.5696873299280805
Accessing function value
loss: -6.575246748924255
Accessing function value
loss: -6.565065883000692
Accessing function value
loss: -6.585975462595622
sharpness : 649.3476476342391
[ 1.00000000e+00 -1.80262337e-05 -5.32375868e-05 ...  8.35407682e-02
 -1.17586164e-03 -1.08974662e-02] -6.585975462595622 {'grad': array([ 0.00000000e+00,  2.82697728e-08,  2.73237390e-08, ...,
       -1.77062696e-04, -7.49200914e-07, 

In [None]:
no of parameters: 12465283
Accessing function value
loss: -0.010902153588831424
Accessing function value
loss: -0.010902153626084328
Accessing function value
loss: -2.16028910279274 
Accessing function value
loss: -29.300055592854818
Accessing function value
loss: -35.17557498931885
Accessing function value
loss: -39.13319180806478
Accessing function value
loss: -40.946407178243  
Accessing function value
loss: -42.0526579284668 
Accessing function value
loss: -42.95348759969075
Accessing function value
loss: -44.69223894755046
Accessing function value
loss: -45.28219554901123
Accessing function value
loss: -45.68622080485026
Accessing function value
loss: -46.00641686757405
Accessing function value
loss: -46.649851417541505
Accessing function value
loss: -47.029435335795085
Accessing function value
loss: -47.5339617284139 
Accessing function value
loss: -47.75037869771322
Accessing function value
loss: -47.889085896809895
Accessing function value
loss: -48.011613032023114
Accessing function value
loss: -48.76787291208903
Accessing function value
loss: -49.161722094217936
Accessing function value
loss: -49.52843022664388
Accessing function value
loss: -50.08370721181234
Accessing function value
loss: -50.514968897501625
Accessing function value

In [None]:
# ml
no of parameters: 12465283
Accessing function value
loss: -0.012894223866363366
Accessing function value
loss: -0.012551092219849428
Accessing function value
loss: -3.9413750688234965
Accessing function value
loss: -6.741963555018107
Accessing function value
loss: -6.796368522644043
Accessing function value
loss: -6.719416332244873
Accessing function value
loss: -6.719123921394348
Accessing function value
loss: -6.73290487130483 
Accessing function value
loss: -6.703963958422343
Accessing function value
loss: -6.8107650820414225
Accessing function value
loss: -6.808843242327372
Accessing function value
loss: -6.690572388966879
sharpness : 659.2670792030835

In [None]:
# tl

In [None]:
# NOTE(review): scratch arithmetic with no recorded purpose — TODO confirm what it checks, or remove
66 /1.01

In [None]:
# Toy check of scipy's L-BFGS-B interface: recover slope and intercept of a
# noiseless line, first unconstrained and then with the slope capped at 2.
x_true = np.arange(0,10,0.1)
m_true = 2.5
b_true = 1.0
y_true = m_true*x_true + b_true


def linear_fit_sse(params, *args):
    """Return the sum of squared residuals of the line `m*x + b` against `y`.

    Deliberately not named `func`: that name belongs to the objective that
    `compute_sharpness` looks up at call time, and redefining it here would
    break any later sharpness run.

    Args:
        params: pair `(m, b)` with slope and intercept.
        *args: `(x, y)` arrays of sample locations and target values.

    Returns:
        The sum over all samples of `(y - (m*x + b)) ** 2`.
    """
    print("calling f")
    x, y = args[0], args[1]
    m, b = params
    residual = y - (m*x + b)
    return np.sum(residual**2)


initial_values = np.array([1.0, 0.0])
mybounds = [(None,2), (None,None)]

# Bind both results so the first one is not silently discarded.
unbounded_fit = scipy.optimize.fmin_l_bfgs_b(
    linear_fit_sse, x0=initial_values, args=(x_true, y_true), approx_grad=True)
bounded_fit = scipy.optimize.fmin_l_bfgs_b(
    linear_fit_sse, x0=initial_values, args=(x_true, y_true), bounds=mybounds,
    approx_grad=True)
unbounded_fit, bounded_fit