In [1]:
import sys
sys.path.append('/om2/user/amarvi/dino/')

import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import svm
import pandas as pd
import pickle


import dobs.tools as tools
import dobs.folder as folder
import dobs.folder_list as folder_list
# from dobs.utils import dobs_tranform

import utils
import vision_transformer as vits

  warn(f"Failed to load image Python extension: {e}")


In [2]:
'''
DOBS transform function (from dobs.utils)
'''

# Preprocessing hyper-parameters.
IMAGE_RESIZE = 256            # size handed to transforms.Resize
IMAGE_SIZE = 224              # side length of the crop fed to the model
GRAYSCALE_PROBABILITY = 0.2   # probability used by RandomGrayscale

# Individual preprocessing stages, kept as module-level names so other cells
# can reuse them on their own.
resize_transform = transforms.Resize(IMAGE_RESIZE)
random_crop_transform = transforms.RandomCrop(IMAGE_SIZE)
center_crop_transform = transforms.CenterCrop(IMAGE_SIZE)
grayscale_transform = transforms.RandomGrayscale(p=GRAYSCALE_PROBABILITY)
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Always-applied vertical flip (p=1.0); not part of `transform` below.
invert = transforms.RandomVerticalFlip(p=1.0)

# Full pipeline: resize -> random crop -> random grayscale -> tensor -> normalize.
# NOTE(review): RandomCrop and RandomGrayscale make this pipeline stochastic,
# so outputs differ between runs unless the RNG is seeded.
transform = transforms.Compose(
    [
        resize_transform,
        random_crop_transform,
        grayscale_transform,
        transforms.ToTensor(),
        normalize,
    ]
)

In [6]:
# Run the whole pipeline for one dataset tag: extract activations, score them
# with the per-layer SVM, then pickle the resulting {layer: per-fold accuracy}
# dict.
# NOTE(review): get_activations and run_svm must already be defined in the
# kernel when this cell executes.
dat = 'obj'
all_acts = get_activations(dat=dat)
perf_dict = run_svm(all_acts)

out_path = '/om2/user/amarvi/FACE/saved_models/svm_perf/dino_%s.pkl'%dat
with open(out_path, 'wb') as f:
    pickle.dump(perf_dict, f)

In [3]:
def get_activations(dat='face', arch='vit_small', patches=16):
    """Collect per-layer class-token activations from a DINO ViT checkpoint.

    Builds the model named by ``arch``, loads the ``'student'`` weights from
    the ``{dat}400_dino`` checkpoint, runs the ``{dat}s_1000`` test images
    through it and keeps, for each of the 12 intermediate layer outputs
    requested, the token at sequence position 0.

    Args:
        dat: Dataset tag (e.g. ``'face'`` or ``'obj'``). It selects both the
            checkpoint directory and the test-image directory.
        arch: Key into ``vision_transformer.__dict__`` naming the model
            constructor.
        patches: Patch size passed to the model constructor and to the
            weight loader.

    Returns:
        A float64 ``np.ndarray`` of shape ``(num_images, num_layers, embed_dim)``
        whose rows follow the (unshuffled) data-loader order. At most
        ``100`` batches of ``10`` images are processed.

    Raises:
        ValueError: If the data loader yields no batches.
    """
    n_layers = 12        # number of intermediate layer outputs requested
    batch_size = 10
    max_batches = 100    # hard cap on the number of batches processed

    # `dat` must drive both paths; it was previously overwritten with 'face'
    # here, which made the argument a no-op.
    ckpt_pth = f'/om2/user/amarvi/dino/saved_models/{dat}400_dino/checkpoint.pth'
    test_data_dir = [f'/om2/group/nklab/shared/datasets/dobs_objface1000/{dat}s_1000/test/']

    # NOTE(review): max_samples presumably caps the images taken per class at
    # 10 -- confirm against dobs.folder_list.ImageFolder.
    dataset = folder_list.ImageFolder(root=test_data_dir,
                                      max_samples={f'{dat}s_1000': 10},
                                      maxout=True,
                                      read_seed=None,
                                      transform=transform,
                                      includePaths=False)

    # shuffle=False keeps rows in dataset order, which downstream code relies on.
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    # load in model
    model = vits.__dict__[arch](patch_size=patches, num_classes=0)
    model.cuda()
    model.eval()
    utils.load_pretrained_weights(model, ckpt_pth, 'student', arch, patches)

    # Accumulate one (batch, layer, dim) array per step instead of writing into
    # a fixed-size buffer, so a short dataset or a partial last batch never
    # leaves uninitialised rows in the result.
    batch_activations = []
    for step, batch in enumerate(tqdm(data_loader, desc='act/grad')):
        if step == max_batches:
            break
        x, _ = batch
        x = x.cuda()
        with torch.no_grad():
            out = model.get_intermediate_layers(x, n=n_layers)
            # Token 0 of every layer output, stacked along a new layer axis.
            cls_tokens = torch.stack([layer[:, 0, :] for layer in out], dim=1)
        batch_activations.append(cls_tokens.detach().cpu().numpy())

    if not batch_activations:
        raise ValueError(f'no batches were produced for dat={dat!r}')

    # float64 matches the dtype of the buffer this function used to return.
    return np.concatenate(batch_activations, axis=0).astype(np.float64)

In [4]:
def run_svm(acts, num_ids=100, num_reps_id=10):
    """Score identity decoding per layer with a leave-one-exemplar-out linear SVM.

    For every layer, ``num_reps_id`` folds are run. Fold ``k`` holds out the
    ``k``-th exemplar of each identity as the test set, fits a ``LinearSVC``
    on the remaining exemplars, and records the fraction of held-out
    exemplars classified correctly.

    The code assumes rows of ``acts`` are grouped by identity: rows
    ``[i * num_reps_id, (i + 1) * num_reps_id)`` all belong to identity ``i``.
    Only the first ``num_ids * num_reps_id`` rows are used.

    Args:
        acts: Array indexed as ``acts[sample, layer, feature]``.
        num_ids: Number of identities (classes).
        num_reps_id: Number of exemplars per identity; also the fold count.

    Returns:
        dict mapping layer index to a 1-D array of length ``num_reps_id``
        holding the accuracy of each fold.

    Raises:
        ValueError: If ``num_reps_id < 2`` (no training data would remain) or
            ``acts`` has fewer than ``num_ids * num_reps_id`` rows.
    """
    num_samples = num_ids * num_reps_id
    if num_reps_id < 2:
        raise ValueError('num_reps_id must be at least 2 to leave one exemplar out')
    if acts.shape[0] < num_samples:
        raise ValueError(
            f'acts has {acts.shape[0]} rows but num_ids*num_reps_id = {num_samples}'
        )

    # These arrays do not depend on the layer, so build them once.
    id_labels = np.arange(0, num_ids)
    # setdiff1d returns sorted indices, so the training rows stay grouped by
    # identity and each label simply repeats num_reps_id - 1 times.
    train_labels = np.repeat(id_labels, num_reps_id - 1)
    first_exemplar_idx = np.arange(0, num_samples, num_reps_id)
    all_idx = np.arange(0, num_samples)

    perf_dict = {}

    print('=========== starting SVM ===================')
    for layer in range(acts.shape[1]):
        act = acts[:, layer, :]
        perf_fold = np.zeros(shape=(num_reps_id,))

        for iFold in tqdm(range(num_reps_id)):
            # Hold out the iFold-th exemplar of every identity.
            test_idx = first_exemplar_idx + iFold
            train_idx = np.setdiff1d(all_idx, test_idx)

            clf = svm.LinearSVC(dual='auto')
            clf.fit(act[train_idx, :], train_labels)

            predictions = clf.predict(act[test_idx, :])

            # Test rows are ordered by identity, so the expected labels are
            # exactly id_labels.
            perf_fold[iFold] = (predictions == id_labels).sum() / num_ids

        perf_dict[layer] = perf_fold
        print(layer, np.mean(perf_fold))

    return perf_dict