In [1]:
"""
Evaluate performance of clean LinearProbe on noisy data 

Prereqs:
    Models: 
        - clean CLIP model (A)
        - finetuned CLIP model on noisy data (B)
    Linear Heads: 
        - linear fit on cleanCLIP embeddings (1)
        - linear fit on noisyCLIP embeddings (2)
We then evaluate {A1(clean), B1(noisy), A1(noisy)}, where e.g. B1(noisy) means
model B with linear head 1 evaluated on noisy data:
     A   B 
    -------
  1|   |   |
   |---|---|
  2|   |   |
    -------
"""
import torch 
import torch.nn as nn 
import numpy as np 
import os 
from torch.utils.data import DataLoader
from utils import * 
# imagenet 
from pytorch_lightning import Trainer, LightningDataModule,LightningModule, seed_everything
from pytorch_lightning.metrics import Accuracy
import glob
from noisy_clip_dataparallel import NoisyCLIP
from baselines import Baseline
from tabulate import tabulate

print()  # NOTE(review): only emits a blank line; looks like a leftover and can likely be removed




In [2]:
### Jupyter evals should be quick 
class HParam:
    """Lightweight stand-in for a hyperparameter namespace.

    Every keyword argument given to the constructor becomes an instance
    attribute of the same name, e.g. ``HParam(encoder='clip').encoder == 'clip'``.
    """

    def __init__(self, **kwargs):
        # Copy all keyword arguments straight into the instance namespace.
        vars(self).update(kwargs)

In [18]:
# Root of the ImageNet-100 dataset. Set the IMAGENET100_DIR environment variable to
# point elsewhere instead of editing the notebook; defaults to ~/datasets/ImageNet100/.
DATA_DIR = os.path.expanduser(os.environ.get('IMAGENET100_DIR', '~/datasets/ImageNet100/'))

def get_xform(distort_type, param=None):
    """Build the validation-time image transform for a given distortion.

    Args:
        distort_type: None for the undistorted CLIP validation transform, or one
            of 'randommask', 'gaussiannoise', 'gaussianblur'.
        param: distortion setting, ignored when ``distort_type`` is None.
            - 'randommask': passed as ``percent_missing`` (0.5/0.75/0.9 in this notebook).
            - 'gaussiannoise': passed as ``std``.
            - 'gaussianblur': indexable pair, ``param[0]`` -> ``kernel_size``,
              ``param[1]`` -> ``sigma``.

    Returns:
        ``ImageNetBaseTransformVal`` when ``distort_type`` is None, otherwise an
        ``ImageNetDistortVal`` built from the matching ``HParam``.

    Raises:
        ValueError: if ``distort_type`` is not a supported name. Previously this
            surfaced as an UnboundLocalError on ``hparam``.
    """
    if distort_type is None:
        return ImageNetBaseTransformVal(HParam(encoder='clip'))

    if distort_type == 'randommask':
        hparam = HParam(encoder='clip', distortion=distort_type,
                        percent_missing=param, fixed_mask=False)
    elif distort_type == 'gaussiannoise':
        hparam = HParam(encoder='clip', distortion=distort_type, std=param, fixed_mask=False)
    elif distort_type == 'gaussianblur':
        hparam = HParam(encoder='clip', distortion=distort_type,
                        kernel_size=param[0], sigma=param[1])
    else:
        raise ValueError(
            "Unknown distort_type %r; expected None, 'randommask', "
            "'gaussiannoise' or 'gaussianblur'" % (distort_type,))
    return ImageNetDistortVal(hparam)


def get_valset(distort_type, param=None, batch_size=128, num_workers=4):
    """Return an unshuffled ImageNet-100 validation DataLoader.

    Args:
        distort_type: forwarded to ``get_xform`` to choose the transform
            (None means undistorted).
        param: distortion setting forwarded to ``get_xform``.
        batch_size: images per batch (default 128, as before).
        num_workers: DataLoader worker processes (default 4, as before).

    Returns:
        A ``torch.utils.data.DataLoader`` over the 'val' split under ``DATA_DIR``.
    """
    dataset = ImageNet100(root=DATA_DIR, split='val',
                          transform=get_xform(distort_type, param))
    # shuffle=False keeps batch order fixed so repeated evaluations are comparable.
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                      pin_memory=True, shuffle=False)



    

In [4]:
# Smoke test: fetch a single batch from the undistorted validation loader.
clean_val_iterator = iter(get_valset(None))
test_batch = next(clean_val_iterator)

In [5]:
# Load the RN101_CLEAN_CLIP_LIN linear-probe baseline and split it into its
# feature extractor and its linear classification head.
BASELINE_CKPT_GLOB = '/tmp/Logs/Contrastive-Inversion/RN101_CLEAN_CLIP_LIN/checkpoints/*'
# glob() order depends on the filesystem; sort so the chosen checkpoint is reproducible.
baseline_ckpt_candidates = sorted(glob.glob(BASELINE_CKPT_GLOB))
if not baseline_ckpt_candidates:
    raise FileNotFoundError('No baseline checkpoint matches %s' % BASELINE_CKPT_GLOB)
baseline_ckpt = baseline_ckpt_candidates[0]

baseline = Baseline.load_from_checkpoint(baseline_ckpt)
clean_backbone = baseline.encoder.feature_extractor
clean_classifier = baseline.encoder.classifier

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [6]:
# Inspect the linear head taken from the clean baseline.
clean_classifier

Linear(in_features=512, out_features=100, bias=True)

In [7]:
# Move the clean linear head's parameters onto the current CUDA device.
clean_classifier = clean_classifier.to('cuda')

In [24]:
# Pick the alphabetically first random-mask checkpoint for a quick load test,
# then list every checkpoint in models_table1/.
rand_ckpts = sorted(glob.glob('models_table1/rand*'))
test_ckpt = rand_ckpts[0]
sorted(glob.glob('models_table1/*'))


['models_table1/blur21.ckpt',
 'models_table1/blur37.ckpt',
 'models_table1/noise01.ckpt',
 'models_table1/noise03.ckpt',
 'models_table1/noise05.ckpt',
 'models_table1/rand50.ckpt',
 'models_table1/rand75.ckpt',
 'models_table1/rand90.ckpt']

In [14]:
# Check that a single NoisyCLIP checkpoint loads.
# NOTE(review): `test_model` is not used by any later cell shown here.
test_model = NoisyCLIP.load_from_checkpoint(test_ckpt)

In [41]:
class TestComp(LightningModule):
    """Evaluation-only module that chains a feature backbone with a classifier.

    Accumulates top-1 and top-5 accuracy over a test run. When the test epoch
    ends, it logs both values and stores them as Python floats in
    ``self.output_dict``.
    """

    def __init__(self, backbone, classifier):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier

        # Running accuracy metrics, updated once per test batch.
        self.test_top_1 = Accuracy(top_k=1)
        self.test_top_5 = Accuracy(top_k=5)
        # Populated by test_epoch_end as {'top1': float, 'top5': float}.
        self.output_dict = {}

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        probs = self.forward(images).softmax(dim=-1)
        # Feed this batch into both running metrics (top-1 first, then top-5).
        for metric in (self.test_top_1, self.test_top_5):
            metric(probs, labels)

    def test_epoch_end(self, outputs):
        results = {'top1': self.test_top_1.compute(),
                   'top5': self.test_top_5.compute()}
        for key, value in results.items():
            self.log(key, value)
        self.test_top_1.reset()
        self.test_top_5.reset()
        self.output_dict = {key: value.item() for key, value in results.items()}
        
def eval_combo(backbone, classifier, data, gpus=None, **trainer_kwargs):
    """Evaluate ``classifier(backbone(x))`` on ``data`` and return its accuracy.

    Args:
        backbone: feature-extractor module applied to each image batch.
        classifier: head applied to the backbone output to produce logits.
        data: DataLoader yielding ``(images, labels)`` batches.
        gpus: forwarded to ``Trainer(gpus=...)``; defaults to ``[0]`` as before.
        **trainer_kwargs: any extra options for this Lightning version's ``Trainer``.

    Returns:
        dict ``{'top1': float, 'top5': float}`` filled in by ``TestComp``.
    """
    if gpus is None:
        gpus = [0]
    trainer = Trainer(gpus=gpus, **trainer_kwargs)
    comp = TestComp(backbone, classifier)
    trainer.test(model=comp, test_dataloaders=data)
    return comp.output_dict
    
        

In [27]:
# (checkpoint path, (distortion type, distortion parameter)) for each Table-1 model;
# the parameter tuple is unpacked into get_valset(distort_type, param).
MODEL_DISTORT_PAIRS = [
    ('models_table1/%s.ckpt' % stem, spec)
    for stem, spec in [
        ('blur21', ('gaussianblur', (21, 5))),
        ('blur37', ('gaussianblur', (37, 9))),
        ('noise01', ('gaussiannoise', 0.1)),
        ('noise03', ('gaussiannoise', 0.3)),
        ('noise05', ('gaussiannoise', 0.5)),
        ('rand50', ('randommask', 0.50)),
        ('rand75', ('randommask', 0.75)),
        ('rand90', ('randommask', 0.90)),
    ]
]

In [48]:
# For each Table-1 checkpoint, evaluate its noisy visual encoder with the clean
# linear head on (a) data distorted the way listed in MODEL_DISTORT_PAIRS and
# (b) undistorted data. Each OUTPUTS entry is
# (checkpoint path, distorted-data metrics, clean-data metrics).
clean_valset = get_valset(None)  # loop-invariant: build the clean loader only once
OUTPUTS = []
for model_name, distort_param in MODEL_DISTORT_PAIRS:
    encoder = NoisyCLIP.load_from_checkpoint(model_name).noisy_visual_encoder

    eval_distorted = eval_combo(encoder, clean_classifier, get_valset(*distort_param))
    eval_clean = eval_combo(encoder, clean_classifier, clean_valset)
    OUTPUTS.append((model_name, eval_distorted, eval_clean))

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.7832000255584717, 'top5': 0.9462000131607056}
--------------------------------------------------------------------------------


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.010200000368058681, 'top5': 0.04879999905824661}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.7307999730110168, 'top5': 0.9064000248908997}
--------------------------------------------------------------------------------


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.010400000028312206, 'top5': 0.054999999701976776}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.7724000215530396, 'top5': 0.9422000050544739}
--------------------------------------------------------------------------------


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.41440001130104065, 'top5': 0.6615999937057495}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.7570000290870667, 'top5': 0.9318000078201294}
--------------------------------------------------------------------------------


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.39100000262260437, 'top5': 0.6398000121116638}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.7052000164985657, 'top5': 0.8985999822616577}
--------------------------------------------------------------------------------


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.33000001311302185, 'top5': 0.5640000104904175}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.8014000058174133, 'top5': 0.9495999813079834}
--------------------------------------------------------------------------------


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.6033999919891357, 'top5': 0.8348000049591064}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.76419997215271, 'top5': 0.9314000010490417}
--------------------------------------------------------------------------------


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.36880001425743103, 'top5': 0.5983999967575073}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.7811999917030334, 'top5': 0.9430000185966492}
--------------------------------------------------------------------------------


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'top1': 0.03720000013709068, 'top5': 0.10939999669790268}
--------------------------------------------------------------------------------


In [49]:
# Raw results: (checkpoint path, distorted-data metrics, clean-data metrics) per model.
OUTPUTS

[('models_table1/blur21.ckpt',
  {'top1': 0.7832000255584717, 'top5': 0.9462000131607056},
  {'top1': 0.010200000368058681, 'top5': 0.04879999905824661}),
 ('models_table1/blur37.ckpt',
  {'top1': 0.7307999730110168, 'top5': 0.9064000248908997},
  {'top1': 0.010400000028312206, 'top5': 0.054999999701976776}),
 ('models_table1/noise01.ckpt',
  {'top1': 0.7724000215530396, 'top5': 0.9422000050544739},
  {'top1': 0.41440001130104065, 'top5': 0.6615999937057495}),
 ('models_table1/noise03.ckpt',
  {'top1': 0.7570000290870667, 'top5': 0.9318000078201294},
  {'top1': 0.39100000262260437, 'top5': 0.6398000121116638}),
 ('models_table1/noise05.ckpt',
  {'top1': 0.7052000164985657, 'top5': 0.8985999822616577},
  {'top1': 0.33000001311302185, 'top5': 0.5640000104904175}),
 ('models_table1/rand50.ckpt',
  {'top1': 0.8014000058174133, 'top5': 0.9495999813079834},
  {'top1': 0.6033999919891357, 'top5': 0.8348000049591064}),
 ('models_table1/rand75.ckpt',
  {'top1': 0.76419997215271, 'top5': 0.93140

In [57]:
# Top-1 accuracy per checkpoint: distortion-matched data vs. clean data.
top1s = [
    [path.rsplit('/', 1)[-1].partition('.')[0], distorted['top1'], clean['top1']]
    for path, distorted, clean in OUTPUTS
]
print(tabulate(top1s, headers=['model', 'distort_acc', 'clean_acc'], floatfmt=".3f"))

model      distort_acc    clean_acc
-------  -------------  -----------
blur21           0.783        0.010
blur37           0.731        0.010
noise01          0.772        0.414
noise03          0.757        0.391
noise05          0.705        0.330
rand50           0.801        0.603
rand75           0.764        0.369
rand90           0.781        0.037


In [59]:
# Top-5 accuracy per checkpoint: distortion-matched data vs. clean data.
top5s = [
    [path.rsplit('/', 1)[-1].partition('.')[0], distorted['top5'], clean['top5']]
    for path, distorted, clean in OUTPUTS
]
print(tabulate(top5s, headers=['model', 'distort_acc', 'clean_acc'], floatfmt=".3f"))

model      distort_acc    clean_acc
-------  -------------  -----------
blur21           0.946        0.049
blur37           0.906        0.055
noise01          0.942        0.662
noise03          0.932        0.640
noise05          0.899        0.564
rand50           0.950        0.835
rand75           0.931        0.598
rand90           0.943        0.109
