In [None]:
# So we can load files from other sub-directories, e.g. datasets.
import os
import sys
module_path = os.path.abspath(os.path.join('unlabeled_extrapolation'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from unlabeled_extrapolation.models import imnet_resnet
from unlabeled_extrapolation.models import imnet_models
from torch import nn
import torch
import torchvision
import numpy as np
import glob
import pickle
import unlabeled_extrapolation.utils.utils as utils
import json

from unlabeled_extrapolation.datasets import fmow
from unlabeled_extrapolation.baseline_train import preprocess_config
from unlabeled_extrapolation.extract_features import get_features_labels

import importlib
importlib.reload(imnet_resnet)
importlib.reload(imnet_models)


In [None]:
# Scratch check: evaluates .tolist() on a scalar tensor. The value is only
# displayed as the cell output; it is not assigned to any name.
import torch
torch.tensor(3.0).tolist()

# Get weight distance between resnets (e.g. full ft vs random vs lpft)

In [None]:
def get_param_weights_counts(net):
    """Collect the named parameters of ``net`` and their element counts.

    Args:
        net: Any object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs.

    Returns:
        A ``(weight_dict, count_dict)`` tuple. Both dicts are keyed by
        parameter name: ``weight_dict`` holds the parameter tensors themselves
        (not copies) and ``count_dict`` holds the product of each tensor's
        shape, i.e. its number of elements.
    """
    named = list(net.named_parameters())
    weight_dict = dict(named)
    count_dict = {
        name: np.prod(np.array(list(tensor.shape)))
        for name, tensor in named
    }
    return weight_dict, count_dict

def get_l2_dist(weight_dict1, weight_dict2, count_dict, ignore='.fc.'):
    """Mean squared difference between two sets of named weights.

    Despite the name, the returned value is not a norm: it is the sum of
    squared element-wise differences over all compared parameters, divided by
    the total number of compared elements.

    Args:
        weight_dict1: Mapping from parameter name to tensor. Its keys decide
            which parameters are compared.
        weight_dict2: Mapping with (at least) the same keys and matching
            tensor shapes.
        count_dict: Mapping from parameter name to its number of elements.
        ignore: Parameters whose name contains this substring are skipped
            (by default the final ``.fc.`` layer).

    Returns:
        The mean squared difference as a Python float, or ``nan`` when no
        parameter is compared (every key ignored, or an empty dict).

    Raises:
        KeyError: If a compared key is missing from ``weight_dict2`` or
            ``count_dict``.
    """
    sq_sum = None
    count = 0.0
    # No gradients are needed for a diagnostic; avoids building a graph over
    # every parameter of both networks.
    with torch.no_grad():
        for key, w1 in weight_dict1.items():
            if ignore in key:
                continue
            # Accumulate on whatever device the weights live on instead of
            # forcing CUDA, so this also works on CPU-only machines.
            w2 = weight_dict2[key].to(w1.device)
            term = torch.sum(torch.square(w1 - w2))
            sq_sum = term if sq_sum is None else sq_sum + term
            count += count_dict[key]
    if sq_sum is None or count == 0:
        return float('nan')
    return float(sq_sum / float(count))
            
def load_model(config_path, checkpoint_path):
    """Build a ResNet50 from a run's config and checkpoint, placed on the GPU.

    Prints ``config_path``, reads the JSON config, constructs
    ``imnet_resnet.ResNet50()``, calls ``new_last_layer`` with the config's
    ``num_classes``, and passes the network to ``utils.load_ckp`` together
    with ``checkpoint_path`` (presumably restoring the saved weights --
    confirm against ``utils``).

    Args:
        config_path: Path to the run's JSON config file.
        checkpoint_path: Path handed to ``utils.load_ckp``.

    Returns:
        The network after ``.cuda()``.
    """
    print(config_path)
    with open(config_path) as config_file:
        run_config = json.load(config_file)
    model = imnet_resnet.ResNet50()
    model.new_last_layer(run_config['num_classes'])
    utils.load_ckp(checkpoint_path, model)
    return model.cuda()

# Load datasets.
def load_test_dataset(config_path, idx, batch_size=64, num_workers=2):
    """Instantiate one of a run's test datasets and a loader over it.

    The JSON config is read, passed through ``preprocess_config``, and the
    entry ``config['test_datasets'][idx]`` is selected. Its ``classname`` is
    prefixed with the package name when the package name does not already
    appear in it, then the dataset is created with ``utils.init_dataset``.

    Args:
        config_path: Path to the run's JSON config file.
        idx: Index into ``config['test_datasets']``.
        batch_size: Batch size for the DataLoader.
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(dataset, loader)``; the loader iterates without shuffling.
    """
    with open(config_path) as config_file:
        config = json.load(config_file)
    preprocess_config(config, config_path)

    dataset_config = config['test_datasets'][idx]
    classname = dataset_config['classname']
    if 'unlabeled_extrapolation' not in classname:
        dataset_config['classname'] = 'unlabeled_extrapolation.' + classname
    # A fallback to the config's default test transforms is disabled:
    #     if 'transforms' not in test_config:
    #         test_config['transforms'] = config['default_test_transforms']

    dataset = utils.init_dataset(dataset_config)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return dataset, loader

In [None]:
# Reference network for the distance comparisons in the cells below
# (pretrain_style='mocov2').
# NOTE(review): the checkpoint path is absolute and machine-specific.
mocov2_model = imnet_resnet.ResNet50(
    pretrained=True,
    pretrain_style='mocov2',
    checkpoint_path='/u/scr/ananya/simclr_weights/moco_v2_800ep_pretrain.pth.tar',
)
mocov2_model.cuda()

In [60]:
# Compare weight distance between checkpoints

def load_ckp(path, num_classes):
    """Build a ResNet50 with a ``num_classes`` last layer from a checkpoint.

    Constructs ``imnet_resnet.ResNet50()``, calls ``new_last_layer`` with
    ``num_classes``, and hands the network to ``utils.load_ckp`` with ``path``.
    Unlike ``load_model``, no config file is read and ``.cuda()`` is not
    called on the result.

    Args:
        path: Checkpoint path passed to ``utils.load_ckp``.
        num_classes: Argument for ``new_last_layer``.

    Returns:
        The constructed network.
    """
    model = imnet_resnet.ResNet50()
    model.new_last_layer(num_classes)
    utils.load_ckp(path, model)
    return model

# Run-directory templates; both placeholders are filled with the seed index.
model_path_fmts = {
    'ft_living17_best_path_fmt': '../logs/full_ft_living17_resnet50/optimizer.args.lr-0.001_seed-{}_run{}/',
    'ft_living17_samelr_path_fmt': '../logs/full_ft_living17_resnet50/optimizer.args.lr-0.0001_seed-{}_run{}/',
    'lpft_living17_path_fmt': '../logs/lp_then_ft_valmode_living17_resnet50/optimizer.args.lr-0.0001_seed-{}_use_net_val_mode-True_run{}/',
}
checkpoint_ext = 'checkpoints/ckp_best_source_val_living'
config_ext = 'config.json'
num_classes = 17

# One list of per-seed weight distances for every entry of model_path_fmts.
distances = {key: [] for key in model_path_fmts}

# Weights and element counts of the reference model; every fine-tuned
# checkpoint is compared against these.
wd, cd = get_param_weights_counts(mocov2_model)

for i in range(3):
    for key, path_fmt in model_path_fmts.items():
        run_dir = path_fmt.format(i, i)
        ckp_path = run_dir + checkpoint_ext
        config_path = run_dir + config_ext
        net = load_model(config_path, ckp_path)
        net_weights, _ = get_param_weights_counts(net)
        dist = get_l2_dist(wd, net_weights, cd)
        distances[key].append(dist)

../logs/full_ft_living17_resnet50/optimizer.args.lr-0.001_seed-0_run0/config.json
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.0001_seed-0_run0/config.json
../logs/lp_then_ft_valmode_living17_resnet50/optimizer.args.lr-0.0001_seed-0_use_net_val_mode-True_run0/config.json
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.001_seed-1_run1/config.json
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.0001_seed-1_run1/config.json
../logs/lp_then_ft_valmode_living17_resnet50/optimizer.args.lr-0.0001_seed-1_use_net_val_mode-True_run1/config.json
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.001_seed-2_run2/config.json
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.0001_seed-2_run2/config.json
../logs/lp_then_ft_valmode_living17_resnet50/optimizer.args.lr-0.0001_seed-2_use_net_val_mode-True_run2/config.json


In [61]:
# For each model family: name, mean distance over seeds, then the standard
# error of the mean scaled by 1.645 (presumably the z-score of a 90% normal
# interval -- confirm).
for key, dists in distances.items():
    print(key)
    print(np.mean(dists))
    std_err = np.std(dists) / np.sqrt(len(dists))
    print(std_err * 1.645)

ft_living17_best_path_fmt
2.2578324963736427e-06
6.152266444806382e-09
ft_living17_samelr_path_fmt
2.4158858498897945e-07
2.8288501226835223e-09
lpft_living17_path_fmt
6.773888827638075e-08
1.2904584535951586e-09


In [79]:
def add_features(id_val_loader, ood_val_loader, ood_features_dict, id_features_dict, name, net):
    """Extract features of ``net`` on both loaders and store them under ``name``.

    ``get_features_labels`` is called on the ID loader first, then on the OOD
    loader; only the first element of each returned pair is kept.

    Args:
        id_val_loader: Loader whose features go into ``id_features_dict``.
        ood_val_loader: Loader whose features go into ``ood_features_dict``.
        ood_features_dict: Dict updated in place with ``name -> OOD features``.
        id_features_dict: Dict updated in place with ``name -> ID features``.
        name: Key under which the features are stored.
        net: Model passed to ``get_features_labels``.
    """
    extracted = []
    for loader in (id_val_loader, ood_val_loader):
        features, _labels = get_features_labels(net, loader)
        extracted.append(features)
    # Store only after both extractions have finished.
    id_features_dict[name], ood_features_dict[name] = extracted

# Per-seed mean squared feature difference between each fine-tuned model and
# the mocov2 reference model. Steps for every (seed, model) pair:
#   1. derive the config and checkpoint paths and load the model,
#   2. build loaders for test datasets 0 and 1 (assumed to be the ID and OOD
#      sets, per the variable names -- confirm against the configs),
#   3. on the first model of each seed, extract the mocov2 features,
#   4. extract this model's features and record both distances.
# NOTE(review): the original plan also listed "set model to val mode"; no
# eval() call is visible in this cell -- confirm get_features_labels does it.
id_dists = {key: [] for key in model_path_fmts}
ood_dists = {key: [] for key in model_path_fmts}

for i in range(3):
    # Feature caches are reset for every seed, so the reference features are
    # recomputed once per seed.
    ood_features, id_features = {}, {}
    for key, path_fmt in model_path_fmts.items():
        run_dir = path_fmt.format(i, i)
        ckp_path = run_dir + checkpoint_ext
        config_path = run_dir + config_ext
        net = load_model(config_path, ckp_path)

        print('loading data')
        _, id_val_loader = load_test_dataset(config_path, idx=0)
        _, ood_val_loader = load_test_dataset(config_path, idx=1)

        print('computing features')
        if 'mocov2' not in id_features:
            add_features(id_val_loader, ood_val_loader, ood_features, id_features,
                         name='mocov2', net=mocov2_model)
        add_features(id_val_loader, ood_val_loader, ood_features, id_features,
                     name=key, net=net)

        id_sq_diff = np.square(id_features[key] - id_features['mocov2'])
        id_dists[key].append(np.mean(id_sq_diff))
        ood_sq_diff = np.square(ood_features[key] - ood_features['mocov2'])
        ood_dists[key].append(np.mean(ood_sq_diff))

../logs/full_ft_living17_resnet50/optimizer.args.lr-0.001_seed-0_run0/config.json
loading data
computing features
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.0001_seed-0_run0/config.json
loading data
computing features
../logs/lp_then_ft_valmode_living17_resnet50/optimizer.args.lr-0.0001_seed-0_use_net_val_mode-True_run0/config.json
loading data
computing features
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.001_seed-1_run1/config.json
loading data
computing features
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.0001_seed-1_run1/config.json
loading data
computing features
../logs/lp_then_ft_valmode_living17_resnet50/optimizer.args.lr-0.0001_seed-1_use_net_val_mode-True_run1/config.json
loading data
computing features
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.001_seed-2_run2/config.json
loading data
computing features
../logs/full_ft_living17_resnet50/optimizer.args.lr-0.0001_seed-2_run2/config.json
loading data
computing features
../logs/lp_then_f

In [76]:
# The per-seed loop that fills id_dists / ood_dists already appends one entry
# per (model, seed). This cell previously appended the same quantity again,
# computed from the id_features / ood_features left over from the last seed.
# On a top-to-bottom run that double-counts the last seed, which skews the
# mean and shrinks the reported interval. It now only checks the bookkeeping
# and never mutates the result lists, so re-running it is harmless.
for key in model_path_fmts:
    assert len(id_dists[key]) == len(ood_dists[key]), key

In [95]:
# For each model family print its name, then one line for the ID distances
# and one for the OOD distances: mean and 1.645 * standard error over seeds,
# both scaled by 100 for readability.
for key in id_dists:
    print(key)
    for dists in (id_dists[key], ood_dists[key]):
        mean_scaled = 100*np.mean(dists)
        half_width_scaled = 100*np.std(dists) / np.sqrt(len(dists)) * 1.645
        print('{:.2f} ({:.2f})'.format(mean_scaled, half_width_scaled))
    

ft_living17_best_path_fmt
1.88 (0.01)
1.67 (0.01)
ft_living17_samelr_path_fmt
1.33 (0.02)
1.03 (0.01)
lpft_living17_path_fmt
0.11 (0.01)
0.09 (0.01)


In [81]:
# Display the raw OOD feature distances collected per model family.
ood_dists

{'ft_living17_best_path_fmt': [0.016827703, 0.016777152, 0.016641773],
 'ft_living17_samelr_path_fmt': [0.010265077, 0.010212453, 0.010457616],
 'lpft_living17_path_fmt': [0.00088527414, 0.00077893224, 0.0010297422]}