In [None]:
from mmcv import Config
import argparse
from openselfsup.models import build_model
from openselfsup.datasets import build_dataset
import os
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
import pdb
import torch
import pickle
from tqdm import tqdm
import torch.nn as nn
from sklearn import svm
import pandas as pd
from torchvision import transforms
import seaborn as sns
import matplotlib.pyplot as plt

from openselfsup.analysis.local_paths_from_func import get_model_kwargs_from_setting_func
import openselfsup.analysis.response_extractor as response_extractor

import dobs.tools as tools
import dobs.folder as folder
import dobs.folder_list as folder_list
    
# Report which accelerator the kernel sees. Guarded so that a CPU-only kernel
# can still run the cells that only load saved results and plot them, instead
# of raising inside the import cell.
device_name = (
    torch.cuda.get_device_name(0)
    if torch.cuda.is_available()
    else 'cpu (CUDA not available)'
)
print(device_name)

In [None]:
def l2_normalize(z):
    """Divide `z` by its L2 norm taken along dimension 1.

    The norm is computed with ``keepdim=True`` so it broadcasts back over
    `z`, and 1e-10 is added to it so an all-zero slice does not divide by zero.
    """
    norms = torch.norm(z, p=2, dim=1, keepdim=True)
    return z / (norms + 1e-10)
    
# CHENGXU transform function (from tools.extract_face_img_states)
def get_eval_transforms():
    """Build the deterministic evaluation preprocessing pipeline.

    Steps, in order: resize to 256, center-crop to 224, convert to a tensor,
    then normalize each channel with the fixed mean/std below.

    Returns:
        A `transforms.Compose` wrapping those four steps.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

# Bind the deterministic evaluation pipeline to the notebook-level name
# `transform`.
# NOTE(review): the next cell assigns `transform` again, so after a top-to-bottom
# run this binding is no longer the one `get_acts` sees — confirm which
# pipeline is intended for feature extraction.
transform = get_eval_transforms()

In [None]:
# DOBS transform function (from Katherina)
#
# NOTE(review): this cell rebinds the notebook-level `transform` that the
# previous cell set to the deterministic eval pipeline. The pipeline below
# contains RandomCrop and RandomGrayscale, so anything that reads `transform`
# after this cell gets stochastic preprocessing — confirm that is intended.

# image preprocessing steps
IMAGE_RESIZE = 256
IMAGE_SIZE = 224
GRAYSCALE_PROBABILITY = 0.2

resize_transform = transforms.Resize(IMAGE_RESIZE)
random_crop_transform = transforms.RandomCrop(IMAGE_SIZE)
center_crop_transform = transforms.CenterCrop(IMAGE_SIZE)
grayscale_transform = transforms.RandomGrayscale(p=GRAYSCALE_PROBABILITY)
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# A vertical flip applied with probability 1, i.e. always.
invert = transforms.RandomVerticalFlip(p=1.0)

# Steps shared by the upright and the inverted pipeline.
_dobs_steps = [
    resize_transform,
    random_crop_transform,
    grayscale_transform,
    transforms.ToTensor(),
    normalize,
]

transform = transforms.Compose(list(_dobs_steps))

# Same steps, followed by the always-on vertical flip.
invert_transform = transforms.Compose(list(_dobs_steps) + [invert])

In [None]:
def get_acts(model_file = 'configs/new_pplns/supervised/in_is224.py:r50_ep100_dobs_face_s0', foor='faces',
             max_batches=100, batch_size=10):
    """Extract pooled, L2-normalized activations for the selected layers.

    Builds the model described by `model_file`, selects every module whose
    name contains both '4.2' and 'relu', runs the test images of the chosen
    category through a `ResponseExtractor`, global-average-pools each
    activation map to 1x1, L2-normalizes it along dim 1 and flattens it.

    Args:
        model_file: ``'<config path>:<setting name>'`` string passed to
            `get_model_kwargs_from_setting_func`.
        foor: category selector. 'car' in `foor` selects the cars set,
            otherwise 'face' in `foor` selects faces, otherwise objects.
        max_batches: stop after this many batches; ``None`` reads the whole
            loader.
        batch_size: batch size of the data loader.

    Returns:
        dict mapping layer name to a 2-D numpy array with one row per image.
        (Previously a single-batch run returned torch tensors while a
        multi-batch run returned numpy arrays; it is now always numpy.)

    Raises:
        ValueError: if the data loader yields no batches.

    Note:
        Images are preprocessed with the notebook-level `transform`. As the
        notebook is written, the last assignment to that name is the DOBS
        pipeline with RandomCrop/RandomGrayscale, so extracted features are
        not deterministic across runs — NOTE(review): confirm this is intended.
    """
    # 'car' takes precedence over 'face'; anything else falls back to objects.
    _dat = 'faces' if 'face' in foor else 'objects'
    _dat = 'cars' if 'car' in foor else _dat

    md_kwgs = get_model_kwargs_from_setting_func(model_file)
    # The model is only built here to enumerate module names.
    model = build_model(md_kwgs['loaded_cfg'].model)

    needs_backbone_prefix = 'moco' in model_file or 'byol' in model_file
    layers = []
    for name, _module in model.named_modules():
        if '4.2' in name and 'relu' in name:
            if needs_backbone_prefix:
                # Replace the first two dotted components with 'backbone'.
                # Presumably these models nest the backbone one level deeper
                # than the extractor expects — TODO confirm.
                name = '.'.join(['backbone'] + name.split('.')[2:])
            layers.append(name)

    extractor = response_extractor.ResponseExtractor(
                layers=layers,
                **md_kwgs)

    test_data_dir=['/om2/group/nklab/shared/datasets/dobs_objface1000/%s_1000/test/'%_dat]

    ImageFolder = folder_list.ImageFolder

    # Arguments are forwarded unchanged to the project-local ImageFolder.
    dataset = ImageFolder(root=test_data_dir,
                          max_samples={'%s_1000'%_dat: 10},
                          maxout=True,
                          read_seed=None,
                          transform=transform,
                          includePaths=False)

    # shuffle=False keeps the dataset's own ordering in the output rows.
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    all_activations = []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(data_loader, desc='act/grad')):
            if max_batches is not None and step == max_batches:
                break
            x, _labels = batch
            all_activations.append(extractor.get_activations(x))

    print(len(all_activations))
    if not all_activations:
        raise ValueError('no batches were read from %s' % test_data_dir[0])

    ### Run activations through avg pool ###
    m = nn.AdaptiveAvgPool2d((1, 1))
    pooled = {layer: [] for layer in layers}
    with torch.no_grad():
        for batch_acts in all_activations:
            for layer, act in batch_acts.items():
                # as_tensor accepts both numpy arrays and tensors without the
                # copy (and warning) that torch.tensor() incurs on a tensor.
                z = l2_normalize(m(torch.as_tensor(act)))
                # Move to host memory explicitly so stacking also works when
                # the extractor returns CUDA tensors.
                pooled.setdefault(layer, []).append(z.detach().cpu().numpy())

    # Concatenate once per layer instead of re-stacking after every batch.
    act_dict = {}
    for layer, chunks in pooled.items():
        stacked = np.concatenate(chunks, axis=0)
        act_dict[layer] = stacked.reshape(stacked.shape[0], -1)

    return act_dict

In [None]:
def run_svm(act_dict, num_ids=100, num_reps_id=10):
    """Cross-validated linear-SVM identity decoding, one score set per layer.

    Fold ``k`` holds out row ``k`` of every block of `num_reps_id` consecutive
    rows, trains a `LinearSVC` on the remaining rows and scores the fraction
    of held-out rows whose predicted label equals the block index.

    Assumes the rows of each activation matrix are grouped by identity, with
    `num_reps_id` consecutive rows per identity — verify against the loader.
    Rows beyond ``num_ids * num_reps_id`` are ignored.

    Args:
        act_dict: mapping of layer name to a 2-D array (rows = samples).
        num_ids: number of identities (classes).
        num_reps_id: number of samples per identity; also the fold count.

    Returns:
        dict mapping layer name to an array of `num_reps_id` fold accuracies.

    Raises:
        ValueError: if a layer has fewer than ``num_ids * num_reps_id`` rows.
    """
    perf_dict = {}

    # The banner used to read the notebook globals `model_file` / `save_name`
    # directly, which raised NameError when they were not defined yet. Look
    # them up defensively so the output is unchanged when they do exist.
    notebook_globals = globals()
    print('=========== starting %s %s ===================' % (
        notebook_globals.get('model_file', '<unknown model_file>'),
        notebook_globals.get('save_name', '<unknown save_name>')))

    # Index bookkeeping is the same for every layer, so build it once.
    num_samples = num_ids * num_reps_id
    indTest = np.arange(0, num_samples, num_reps_id)   # first row of each block
    indAll = np.arange(0, num_samples)
    x = np.arange(0, num_ids)                          # one label per identity
    # setdiff1d returns sorted indices, so each identity contributes
    # num_reps_id - 1 consecutive training rows, matching this label vector.
    trainCat = np.repeat(x, num_reps_id - 1)

    for layer, act in act_dict.items():
        if act.shape[0] < num_samples:
            raise ValueError(
                'layer %s has %d rows but num_ids*num_reps_id = %d'
                % (layer, act.shape[0], num_samples))

        perf_fold = np.zeros(shape=(num_reps_id,))

        for iFold in tqdm(range(num_reps_id)):
            held_out = indTest + iFold
            indTrain = np.setdiff1d(indAll, held_out)

            clf = svm.LinearSVC(dual='auto')
            clf.fit(act[indTrain, :], trainCat)

            dec = clf.predict(act[held_out, :])

            # Fraction of held-out rows decoded as their own identity.
            perf_fold[iFold] = np.mean(dec == x)

        perf_dict[layer] = perf_fold
        print(layer, np.mean(perf_fold))

    return perf_dict

In [None]:
foor='car'

model_list = ['configs/new_pplns/supervised/in_is224.py:r50_ep100_mini_%s_s0'%foor,
'configs/new_pplns/simclr/r50.py:r50_ep100_mini_%s'%foor,
]

# The cell used to `break` unconditionally right after `get_acts`, which made
# the SVM decoding and the pickle dump below unreachable. The early exit is
# now an explicit switch: set it to True to only extract activations for the
# first model (handy for inspecting `act_dict`) without running the SVM or
# writing any result file.
EXTRACT_FIRST_MODEL_ONLY = False

# Checked in this order; the first match is prepended to the file name.
MODEL_PREFIXES = ('moco', 'simclr', 'byol', 'dino', 'supervised')

for model_file in model_list:
    # Setting name after the ':' in '<config path>:<setting name>'.
    save_name = model_file.split(':')[1]

    act_dict = get_acts(model_file, foor)
    if EXTRACT_FIRST_MODEL_ONLY:
        break
    perf_dict = run_svm(act_dict)

    for prefix in MODEL_PREFIXES:
        if prefix in model_file:
            save_name = '%s_%s' % (prefix, save_name)
            break

    # NOTE: overwrites any existing result file with the same name.
    with open('/om2/user/amarvi/FACE/saved_models/mini_svm_perf/%s.pkl'%save_name, 'wb') as f:
        pickle.dump(perf_dict, f)

In [None]:
# Show the shape of the features stored under 'backbone.layer4.2.relu'.
# Relies on `act_dict` left in the notebook namespace by the loop cell above.
act_dict['backbone.layer4.2.relu'].shape

In [None]:
# Collect the saved per-model SVM results into one DataFrame with one row per
# result file. Only the last layer stored in each pickle is kept.
root = '/om2/user/amarvi/FACE/saved_models/mini_svm_perf/'
cols = ['model', 'dataset', 'size', 'perf', 'layer']

records = []
# sorted() makes the row order reproducible; os.listdir order is arbitrary.
for sub_dir in sorted(os.listdir(root)):
    # File names look like '<model>_..._<dataset>[_<suffix>].pkl'; the model
    # family is the first '_'-separated token.
    model = sub_dir.split('_')[0]
    if model not in ['moco', 'byol', 'simclr', 'supervised']:
        continue
    # Dataset is the fifth token, with any file extension stripped.
    dataset = sub_dir.split('_')[4].split('.')[0]
    pkl_path = os.path.join(root, sub_dir)  # was `dir`, which shadowed the builtin
    # NOTE: pickle.load executes arbitrary code; only load files you produced.
    with open(pkl_path, 'rb') as f:
        dat = pickle.load(f)
    last_layer = list(dat)[-1]
    print(sub_dir, last_layer)
    records.append({
        'model': model,
        'dataset': dataset,
        'size': '100',
        'perf': np.array(dat[last_layer]),
        'layer': last_layer,
    })

# Build the frame once instead of concatenating inside the loop.
df = pd.DataFrame(records, columns=cols)

In [None]:
# Display the results table assembled in the previous cell.
df

In [None]:
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# NOTE(review): this rebinds `df`, discarding the frame assembled from the
# per-model result files above, and loads a previously saved table instead.
# The file name suggests `perf` is already exploded to one value per row —
# confirm, since the cells below read `df` from here.
# df.to_pickle('/om2/user/amarvi/FACE/saved_models/svm_mini_all.pkl')
df = pd.read_pickle('/om2/user/amarvi/FACE/saved_models/svm_mini_exploded.pkl')


In [None]:
# Model families to keep; their list position doubles as the sort rank.
models = ['supervised', 'simclr']
custom_dict = dict(zip(models, range(len(models))))

# Keep only the selected models at size '100'.
keep_rows = df['model'].isin(models) & df['size'].isin(['100'])
filtered_df = df.loc[keep_rows]
print("Filtered DataFrame:\n", filtered_df)

# One row per entry of `perf`, cast to float.
exploded_df = (
    filtered_df
    .explode('perf')
    .assign(perf=lambda frame: frame['perf'].astype(float))
)
print("Exploded DataFrame:\n", exploded_df)

# Mean and standard deviation per (model, dataset).
grouped_df = (
    exploded_df
    .groupby(['model', 'dataset'])['perf']
    .agg(['mean', 'std'])
    .reset_index()
)
print("Grouped and Aggregated DataFrame:\n", grouped_df)

# Order rows by the model rank defined in `custom_dict`.
sorted_df = grouped_df.sort_values(
    by=['model'],
    key=lambda column: column.map(custom_dict),
)
print("Sorted DataFrame:\n", sorted_df)

In [None]:
# Grouped bar plot of mean SVM accuracy per dataset and model, with +/- one
# standard deviation drawn as a vertical line on each bar.
c1 = (1,0,0)
c2 = (1,0.65,0)
c3 = (0.42,0.00,0.50)
custom_dict = {'supervised': 0, 'simclr': 1}

adict = {2: 1, 4: 0.7, 6: 0.5}
models = ['supervised', 'simclr']
order = ['obj', 'face', 'car']

grouped_df = (
    df.loc[(df['model'].isin(models)) & (df['size'].isin(['100']))]
    .groupby(['model', 'dataset'])['perf']
    .agg(['mean', 'std'])
    .reset_index()
    .sort_values(by=['model'], key=lambda x: x.map(custom_dict))
)

# hue_order is passed explicitly so the bar layout no longer depends on the
# row order of `grouped_df`.
ax = sns.barplot(data=grouped_df, x='dataset', y='mean', hue='model',
                 order=order, hue_order=models, edgecolor='gray')

sv = len(order)
# One colour per dataset, in the same order as `order`.
dataset_colors = [c2, c1, c3]
# Bars are laid out hue-major: the first `sv` patches belong to models[0],
# the next `sv` to models[1]; within a model they follow `order`.
bars = ax.patches[:len(models) * sv]

for i, bar in enumerate(bars):
    hue_idx, x_idx = divmod(i, sv)
    color = dataset_colors[x_idx]
    if hue_idx == 0:
        # First model: solid colour with hatching.
        bar.set_fc(color)
        bar.set_hatch('//')
    else:
        # Second model: same colour at 70% alpha, no hatching.
        bar.set_fc((color, 0.7))

plt.legend().remove()
# plt.xticks(range(sv), ['objects', 'faces', 'cars'])
plt.xticks([], [])
plt.xlabel('')
plt.ylim([0, 1])
plt.yticks(np.arange(0,1,0.2), [])
plt.ylabel('')
sns.despine(offset=5, bottom=True, trim=True)

# Error bars. The std is looked up by (model, dataset) rather than by row
# position: `grouped_df` rows come out of groupby in alphabetical dataset
# order (car, face, obj) while the bars follow `order` (obj, face, car), so
# positional pairing drew the obj and car deviations on each other's bars.
std_lookup = grouped_df.set_index(['model', 'dataset'])['std']
for i, bar in enumerate(bars):
    hue_idx, x_idx = divmod(i, sv)
    key = (models[hue_idx], order[x_idx])
    if key not in std_lookup.index:
        continue
    err = std_lookup.loc[key]
    center = bar.get_x() + bar.get_width() / 2
    h = bar.get_height()
    ax.vlines(center, h - err, h + err, color='gray')

plt.title('')
plt.savefig('/om2/user/amarvi/FACE/figs/svm_cars_fixed.png', dpi = 1000, bbox_inches = 'tight', format='png', transparent=True)