In [1]:
%env CUDA_VISIBLE_DEVICES = '6'

env: CUDA_VISIBLE_DEVICES='6'


In [2]:
from custom_models import *

In [3]:
from collections import OrderedDict

import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
from torchvision.datasets.utils import download_url

In [4]:
# Make the local SLIP code importable.
# `sys` is imported explicitly here; no visible cell imports it, so the original
# only worked if `from custom_models import *` happened to re-export it.
import sys

# Guard the append so re-running this cell does not pile up duplicate entries.
if 'custom/slip_codebase' not in sys.path:
    sys.path.append('custom/slip_codebase')

# These are expected to resolve from custom/slip_codebase (appended above).
import models
import utils
from tokenizer import SimpleTokenizer

In [5]:
# Download URL of each released SLIP checkpoint, keyed by model label.
# NOTE(review): the two '-CC12M' entries are labelled ViT-L, yet their URLs
# point at '*_base_cc12m_*' files -- confirm which architecture is intended.
slip_model_weights = {
    # ViT-Small
    'ViT-S-SimCLR': 'https://dl.fbaipublicfiles.com/slip/simclr_small_25ep.pt',
    'ViT-S-CLIP': 'https://dl.fbaipublicfiles.com/slip/clip_small_25ep.pt',
    'ViT-S-SLIP': 'https://dl.fbaipublicfiles.com/slip/slip_small_25ep.pt',
    'ViT-S-SLIP-Ep100': 'https://dl.fbaipublicfiles.com/slip/slip_small_100ep.pt',
    # ViT-Base
    'ViT-B-SimCLR': 'https://dl.fbaipublicfiles.com/slip/simclr_base_25ep.pt',
    'ViT-B-CLIP': 'https://dl.fbaipublicfiles.com/slip/clip_base_25ep.pt',
    'ViT-B-SLIP': 'https://dl.fbaipublicfiles.com/slip/slip_base_25ep.pt',
    'ViT-B-SLIP-Ep100': 'https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt',
    # ViT-Large
    'ViT-L-SimCLR': 'https://dl.fbaipublicfiles.com/slip/simclr_large_25ep.pt',
    'ViT-L-CLIP': 'https://dl.fbaipublicfiles.com/slip/clip_large_25ep.pt',
    'ViT-L-SLIP': 'https://dl.fbaipublicfiles.com/slip/slip_large_25ep.pt',
    'ViT-L-SLIP-Ep100': 'https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt',
    # CC12M-pretrained variants
    'ViT-L-CLIP-CC12M': 'https://dl.fbaipublicfiles.com/slip/clip_base_cc12m_35ep.pt',
    'ViT-L-SLIP-CC12M': 'https://dl.fbaipublicfiles.com/slip/slip_base_cc12m_35ep.pt',
}

# Local path of each checkpoint: the URL's file name placed under custom/slip_weights.
slip_weight_paths = {
    label: os.path.join('custom/slip_weights', url.rsplit('/', 1)[-1])
    for label, url in slip_model_weights.items()
}

In [6]:
# Show the available model labels (iterating a dict yields its keys in order).
list(slip_model_weights)

['ViT-S-SimCLR',
 'ViT-S-CLIP',
 'ViT-S-SLIP',
 'ViT-S-SLIP-Ep100',
 'ViT-B-SimCLR',
 'ViT-B-CLIP',
 'ViT-B-SLIP',
 'ViT-B-SLIP-Ep100',
 'ViT-L-SimCLR',
 'ViT-L-CLIP',
 'ViT-L-SLIP',
 'ViT-L-SLIP-Ep100',
 'ViT-L-CLIP-CC12M',
 'ViT-L-SLIP-CC12M']

In [7]:
# Run configuration: which image set to analyse and which checkpoint to use.
target_imageset = 'oasis'
model_name = 'ViT-B-SimCLR'
train_type = 'slip'

# Identifier of the form '<model_name>_<train_type>'.
model_string = model_name + '_' + train_type

# Model descriptor passed on to the regression step.
model_option = dict(model_name=model_name, train_type=train_type)

In [8]:
# Fetch the selected checkpoint into the local weights directory.
download_url(
    url=slip_model_weights[model_name],
    root='custom/slip_weights',
)

Using downloaded and verified file: custom/slip_weights/simclr_base_25ep.pt


In [9]:
# Fallback for the download step: if the expected checkpoint file is still
# missing locally, fetch it with wget into the same weights directory.
if not os.path.exists(slip_weight_paths[model_name]):
    !wget -P custom/slip_weights {slip_model_weights[model_name]}

In [10]:
# Load the pretrained checkpoint onto the CPU.
# NOTE: torch.load unpickles arbitrary Python objects (this checkpoint also
# carries the training `args`), so only load files from a trusted source.
ckpt_path = slip_weight_paths[model_name]
ckpt = torch.load(ckpt_path, map_location='cpu')

# Strip only a *leading* 'module.' from each parameter name (presumably left by
# a DataParallel-style wrapper at training time). A blanket str.replace would
# also rewrite keys that merely contain 'module.' somewhere inside the name.
state_dict = OrderedDict()
for k, v in ckpt['state_dict'].items():
    if k.startswith('module.'):
        k = k[len('module.'):]
    state_dict[k] = v

# create model
old_args = ckpt['args']
print("=> creating model: {}".format(old_args.model))
model = getattr(models, old_args.model)(rand_embed=False,
    ssl_mlp_dim=old_args.ssl_mlp_dim, ssl_emb_dim=old_args.ssl_emb_dim)
# strict=True: fail loudly if the checkpoint and the architecture disagree.
model.load_state_dict(state_dict, strict=True)
# nn.Module instances start in training mode and nothing visible here changes
# that; the model is only used for feature extraction afterwards, so switch to
# inference mode explicitly.
model.eval()
print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))

=> creating model: SIMCLR_VITB16
=> loaded resume checkpoint 'custom/slip_weights/simclr_base_25ep.pt' (epoch 25)


In [11]:
# Load the stimulus table and the response data for the selected image set.
# Both helpers are not defined in the visible cells (presumably they come from
# `from custom_models import *` -- confirm).
image_data = load_image_data(target_imageset)
response_data = load_response_data(target_imageset)

In [22]:
# Preprocessing applied to every stimulus image: resize, 224x224 center crop,
# conversion to RGB, tensor conversion, then per-channel normalization.
image_transforms = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=224),
    lambda image: image.convert('RGB'),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [23]:
# Build the stimulus loader from the image paths and the preprocessing pipeline.
# (get_stimulus_loader is a project helper not defined in the visible cells.)
stimulus_loader = get_stimulus_loader(image_data.image_path, image_transforms)

In [None]:
# Feed the stimulus loader through the model's `visual` submodule and keep what
# the project helper returns (its exact output format is not visible here).
stimulus_features = get_all_feature_maps(
    model.visual,
    inputs=stimulus_loader,
)

In [None]:
# Post-process the extracted features with the project's SRP helper
# (presumably sparse random projection -- confirm in the helper's source).
# NOTE(review): the result is stored under the same name as the input, so
# re-running this cell without re-running the extraction cell feeds the
# already-processed features back into the helper.
stimulus_features = get_feature_map_srps(stimulus_features, delete_originals = True)

In [None]:
# Regress the response data on the stimulus features for this model.
# A single alpha value (1000) is passed; presumably the regularization
# strength of the regression -- confirm in the helper.
reg_results = get_regression_results(
    model_option,
    stimulus_features,
    response_data,
    alpha_values=[1000],
)

In [None]:
# Keep only the rows whose image_type is 'Combo' and display the result of
# max_transform grouped by measurement and image type.
max_transform(
    reg_results[reg_results['image_type'] == 'Combo'],
    group_vars=['measurement', 'image_type'],
)