In [1]:
import sys 
sys.path.append('..')
from src.clip import get_image_features, define_model, feature_dim
from src.build_classifier import get_classifier
from src.train_clf import train

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import clip
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as FT
import torch

In [3]:
sys.path.append('..')
from cifar.cifarRawCorrupted import get_original_loaders, get_corrupt_loaders

In [4]:
device = 'cuda:0'
model = define_model(device=device)

In [5]:
c_loader = get_corrupt_loaders(model_name='imagebind', severity=1)


In [6]:
u,kt = next(iter(c_loader))

In [7]:
get_image_features(model, u).shape

torch.Size([64, 512])

In [8]:
clip_clf = get_classifier(feature_dim, output_classes=10, n_layers=1).to(device)
train_loader, val_loader, test_loader = get_original_loaders(batch_size=1024, model_name='blip') 
test_corrupt_loader = get_corrupt_loaders(batch_size=1024, model_name='blip')

Files already downloaded and verified
Files already downloaded and verified


In [9]:
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(clip_clf.parameters(), lr=0.001)
n_epochs = 2

In [10]:
losses, accs, val_losses, val_accs = train(model, clip_clf, optim=optim, loss_fn=loss_fn,
                                           train_loader=train_loader, val_loader=val_loader,
                                           feature_fn=get_image_features, epochs=n_epochs, device=device) #TODO resize im in clip transforms

  input = module(input)


initial loss 2.3039612054824827 and initial accuracy 0.025597894564270973
 train loss: 1.7583475351333617, val loss: 1.5312160491943358, Train accuracy 0.868701159954071, val accuracy 0.9476920962333679 
 train loss: 1.5214557319879531, val loss: 1.5154974460601807, Train accuracy 0.9508056640625, val accuracy 0.9553990364074707 


In [11]:
torch.save(clip_clf.state_dict(), '../saved_models/clip_clf_vitb.pth')


In [12]:
def get_acc(gt, preds = None):
    if preds is not None: 
        return ((preds.argmax(1)==gt).sum()/len(preds)).cpu().numpy()
        
    
    return ((preds.argmax(1)==gt).sum()/len(preds)).cpu().numpy()
    

def get_test_acc(emb_model, model, test_loader, feature_fn, device='cuda'):
    eval_acc = []
    eval_losses = []
    for eval_batch in test_loader:
        if len(eval_batch)>2:
            _, ims, labels = eval_batch
        else: 
            ims, labels = eval_batch
        ims, labels = ims.to(device), labels.to(device)
        with torch.no_grad():
            features = feature_fn(emb_model, ims).squeeze()
            preds = model(features)
            val_acc = get_acc(labels.view(-1,), preds)
        
        eval_acc.append(val_acc)
    
    return np.mean(eval_acc)
            # 
test_acc_orig = racc =  get_test_acc(model, clip_clf, test_loader, get_image_features, device=device,)
 
print(test_acc_orig)

0.9466179


In [14]:
corrupts_dict = {}
corrupt_g_acc = []
for cr in ['gaussian_noise', 'speckle_noise', 'impulse_noise', 'shot_noise', ]:
    corrupts_dict[cr] = {}
    for sev in [1, 2, 3, 4, 5]:
        test_loader_corrupt = get_corrupt_loaders(batch_size=1024, corruption_type=cr, severity=sev, model_name='blip')
        acc =  get_test_acc(model, clip_clf, test_loader_corrupt, get_image_features, device=device,)
                                

        corrupts_dict[cr][sev]=acc

In [16]:
corrupts_dict

{'gaussian_noise': {1: 0.8200175,
  2: 0.6643335,
  3: 0.5043985,
  4: 0.42891026,
  5: 0.37114358},
 'speckle_noise': {1: 0.8763732,
  2: 0.7686045,
  3: 0.69716597,
  4: 0.55485094,
  5: 0.4382573},
 'impulse_noise': {1: 0.92210215,
  2: 0.8777124,
  3: 0.8338388,
  4: 0.7114158,
  5: 0.5900191},
 'shot_noise': {1: 0.8765625,
  2: 0.803388,
  3: 0.6125757,
  4: 0.53303176,
  5: 0.39992028}}