## Exploring linear probing classification using the DINO class token as the feature space

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import f1_score, classification_report, confusion_matrix

from model_lora_vit import get_vit, load_lora_vit_from_dino_ckpt
from data_transforms import get_random_transform, get_deterministic_transform
from dataloader_tmed import TMED2

# point torch.hub's cache directory at ../pretrained_weights (relative to the
# working directory) so hub checkpoints are looked up / stored there
torch.hub.set_dir("../pretrained_weights")

In [2]:
# configure the GPU: prefer GPU index 3, but clamp the index to the GPUs that
# are actually visible so the notebook does not crash on smaller machines
gpu_index = 3
if torch.cuda.is_available():
    device = torch.device("cuda", min(gpu_index, torch.cuda.device_count() - 1))
else:
    device = torch.device("cpu")

# seed torch and numpy so that probe initialisation and torch/numpy-driven
# sampling start from a known state on every Restart & Run All
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# configure batch size for training
batch_size = 16

# configure linear probe training
embedding_dim = 384 # change this for other architectures
hidden_dim = 256 # >0 for optimizing a second linear layer
num_classes_AS = 3
lr = 0.0001
num_epochs = 30

In [3]:
# load the backbone model, ensure params are consistent with ckpt
# experiment name -> (checkpoint path, LoRA rank); a path of None means no
# checkpoint is loaded on top of what get_vit returns
experiment_configs = {
    'imagenet': (None, 0),
    'full': ('../logs/training_base/checkpoint.pth', 0),
    'lora4': ('../logs/training_1/checkpoint0009.pth', 4),
}
experiment = 'imagenet'
if experiment not in experiment_configs:
    raise ValueError(
        f"unknown experiment {experiment!r}; expected one of {sorted(experiment_configs)}"
    )
ckpt_path, lora_rank = experiment_configs[experiment]

arch = 'vit_small'
patch_size = 8
if ckpt_path is None:
    # load the default DINO model
    model = get_vit(arch, patch_size, lora_rank=0)
else:
    # build the architecture with the matching LoRA rank, then load the checkpoint into it
    model = get_vit(arch, patch_size, lora_rank)
    load_lora_vit_from_dino_ckpt(model, ckpt_path)
# the backbone is only used as a feature extractor, so keep it in eval mode
model.to(device).eval()

Using cache found in ../pretrained_weights/facebookresearch_dino_main


Initialized without LoRA


Lora_vit(
  (base): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(8, 8), stride=(8, 8))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): LayerNorm((384,

In [4]:
# load the dataset
# keyword arguments shared by the train and validation splits
shared_dataset_kwargs = dict(parasternal_only=True, label_scheme_name='tufts')

# training split: random transform, batches drawn through the dataset's own sampler
transform = get_random_transform()
tr_dataset = TMED2(split="train", transform=transform, **shared_dataset_kwargs)  # train/val/test/all/unlabeled
tr_sampler = tr_dataset.class_sampler()
tr_dataloader = torch.utils.data.DataLoader(
    tr_dataset,
    batch_size=batch_size,
    sampler=tr_sampler,
)

# validation split: deterministic transform, fixed iteration order
va_transform = get_deterministic_transform()
va_dataset = TMED2(split="val", transform=va_transform, **shared_dataset_kwargs)  # train/val/test/all/unlabeled
va_dataloader = torch.utils.data.DataLoader(
    va_dataset,
    batch_size=batch_size,
    shuffle=False,
)

[ 780  622 2444]
[0.00128205 0.00160772 0.00040917]


In [5]:
# instantiate trainable parameters, loss and optimizer

class Heads(nn.Module):
    """Two classification heads that share one input embedding.

    Args:
        embedding_dim: size of the input feature vector.
        hidden_dim: 0 builds each head as a single linear layer; any other
            value builds Linear -> ReLU -> Linear with this hidden width.
        num_classes_AS: number of outputs of ``as_head``.
        num_classes_view: number of outputs of ``view_head`` (default 2,
            the value that used to be hard-coded).
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, embedding_dim, hidden_dim, num_classes_AS, num_classes_view=2, **kwargs):
        super().__init__(**kwargs)
        # as_head is built before view_head so parameter creation order
        # (and therefore seeded initialisation) stays as before
        self.as_head = self._make_head(embedding_dim, hidden_dim, num_classes_AS)
        self.view_head = self._make_head(embedding_dim, hidden_dim, num_classes_view)

    @staticmethod
    def _make_head(embedding_dim, hidden_dim, num_outputs):
        """Build one head: a linear layer, or a one-hidden-layer MLP when hidden_dim != 0."""
        if hidden_dim == 0:
            return nn.Linear(embedding_dim, num_outputs)
        return nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_outputs),
        )

    def forward(self, z):
        """Return the tuple ``(as_head(z), view_head(z))``."""
        return self.as_head(z), self.view_head(z)
        
# build the probe and move it to the configured device
linear_probe = Heads(embedding_dim, hidden_dim, num_classes_AS)
linear_probe = linear_probe.to(device)

# test the linear probe: push one random batch of embeddings through both heads
dummy_features = torch.randn(batch_size, embedding_dim).to(device)
print(linear_probe(dummy_features))

# cross-entropy (mean over the batch) and Adam over the probe's parameters only
loss_fcn = nn.CrossEntropyLoss(reduction="mean")
tuner = torch.optim.Adam(linear_probe.parameters(), lr=lr)

(tensor([[-0.1183, -0.0241, -0.2758],
        [ 0.1444, -0.0861, -0.1124],
        [-0.0420,  0.0587,  0.0451],
        [ 0.2020, -0.0809, -0.2038],
        [-0.3101, -0.3292, -0.1161],
        [ 0.5297, -0.0333, -0.1563],
        [ 0.3890, -0.0400,  0.1799],
        [ 0.1314, -0.0626,  0.1683],
        [-0.2073,  0.2356,  0.0367],
        [ 0.2517,  0.1855, -0.2147],
        [ 0.2597,  0.0150, -0.4146],
        [-0.1966, -0.0734, -0.0324],
        [ 0.1652, -0.3199, -0.2715],
        [ 0.0758,  0.1185,  0.0153],
        [ 0.0391,  0.0069,  0.0543],
        [-0.0601, -0.1089,  0.0041]], device='cuda:3',
       grad_fn=<AddmmBackward0>), tensor([[ 0.2966,  0.1669],
        [-0.0119,  0.1034],
        [-0.1613,  0.3144],
        [ 0.2857,  0.0116],
        [ 0.0330,  0.1729],
        [-0.2105,  0.0838],
        [ 0.0660,  0.3678],
        [-0.2163,  0.0780],
        [-0.3005,  0.0747],
        [-0.3021,  0.2450],
        [-0.0512, -0.1120],
        [-0.1931,  0.2655],
        [-0.1051,  

In [6]:
def forward(x):
    """Run the backbone without gradients, then the probe heads on its output.

    Returns the pair of logits produced by ``linear_probe``; gradients can
    flow into the probe but not into ``model``.
    """
    # the backbone call sits inside no_grad, so no graph is built for it
    with torch.no_grad():
        embeddings = model(x)
    as_logits, view_logits = linear_probe(embeddings)
    return as_logits, view_logits
    
def train_batch(batch, view_loss_weight=0.1):
    """Run one optimisation step of the probe on a single batch.

    Args:
        batch: unpacked as ``x, [y, y_v]`` -- the inputs and the two label
            tensors scored against the two heads' logits.
        view_loss_weight: multiplier applied to the view-head loss before it
            is added to the first head's loss (default 0.1, the value that
            used to be hard-coded).

    Returns:
        The combined loss of this batch as a Python float.
    """
    x, [y, y_v] = batch
    x = x.to(device)
    y = y.to(device)
    y_v = y_v.to(device)

    tuner.zero_grad()
    logits_as, logits_v = forward(x)
    # the view loss is an auxiliary term scaled by view_loss_weight
    loss = loss_fcn(logits_as, y) + view_loss_weight * loss_fcn(logits_v, y_v)
    loss.backward()
    tuner.step()
    return loss.item()

def val_batch(batch):
    """Predict both heads' classes for one validation batch.

    Args:
        batch: unpacked as ``x, [y, y_v]`` -- the inputs and the two label tensors.

    Returns:
        dict of numpy arrays: ``'y_as'``/``'y_v'`` hold the labels and
        ``'p_as'``/``'p_v'`` hold the argmax predictions of the two heads.
    """
    x, [y, y_v] = batch
    x = x.to(device)
    y = y.cpu().numpy()
    y_v = y_v.cpu().numpy()

    # nothing is back-propagated during validation, so do not let the probe
    # build an autograd graph either (forward() only covers the backbone)
    with torch.no_grad():
        logits_as, logits_v = forward(x)
    pred_as = torch.argmax(logits_as, dim=1).cpu().numpy()
    pred_v = torch.argmax(logits_v, dim=1).cpu().numpy()
    return {'y_as':y, 'y_v':y_v, 'p_as':pred_as, 'p_v':pred_v}
            

In [7]:
for i in range(num_epochs):
    # train the model: one pass over the training dataloader
    linear_probe.train()
    tr_loss = []
    for batch in tqdm(tr_dataloader):
        tr_loss.append(train_batch(batch))

    # validate the model: gather per-batch label/prediction arrays
    linear_probe.eval()
    cache = {k: [] for k in ['y_as', 'y_v', 'p_as', 'p_v']}
    for batch in tqdm(va_dataloader):
        for k, batch_values in val_batch(batch).items():
            cache[k].append(batch_values)
    # join the batches in one vectorised step; reshape(-1) yields a 1-d array
    # even when only a single sample was collected (squeeze() would give 0-d)
    cache = {k: np.concatenate(chunks).reshape(-1) for k, chunks in cache.items()}

    # evaluate acc and f1
    acc_as = np.mean(cache['y_as'] == cache['p_as'])
    acc_v = np.mean(cache['y_v'] == cache['p_v'])
    f1_as = f1_score(cache['y_as'], cache['p_as'], average='macro')
    f1_v = f1_score(cache['y_v'], cache['p_v'], average='macro')
    print("Epoch %3d: tr_loss %.3f, as acc/f1 %.3f/%.3f, view acc/f1 %.3f/%.3f" % (i, np.mean(tr_loss), acc_as, f1_as, acc_v, f1_v))
    print(confusion_matrix(cache['y_as'], cache['p_as']))
    print(confusion_matrix(cache['y_v'], cache['p_v']))

100%|█████████████████████████████████████████| 241/241 [00:39<00:00,  6.16it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:17<00:00,  4.81it/s]


Epoch   0: tr_loss 0.990, as acc/f1 0.536/0.457, view acc/f1 0.914/0.885
[[ 93  90 103]
 [ 22 100  65]
 [123 205 509]]
[[925  40]
 [ 73 272]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:45<00:00,  5.25it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:17<00:00,  4.64it/s]


Epoch   1: tr_loss 0.890, as acc/f1 0.496/0.453, view acc/f1 0.921/0.898
[[176  47  63]
 [ 77  76  34]
 [315 124 398]]
[[909  56]
 [ 48 297]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:45<00:00,  5.31it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:11<00:00,  7.20it/s]


Epoch   2: tr_loss 0.840, as acc/f1 0.498/0.456, view acc/f1 0.916/0.881
[[168  72  46]
 [ 55  88  44]
 [232 209 396]]
[[955  10]
 [100 245]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:44<00:00,  5.36it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:16<00:00,  5.06it/s]


Epoch   3: tr_loss 0.810, as acc/f1 0.564/0.487, view acc/f1 0.913/0.876
[[119  79  88]
 [ 24  90  73]
 [114 193 530]]
[[957   8]
 [106 239]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:43<00:00,  5.56it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:11<00:00,  7.26it/s]


Epoch   4: tr_loss 0.790, as acc/f1 0.582/0.504, view acc/f1 0.915/0.880
[[168  36  82]
 [ 37  68  82]
 [189 122 526]]
[[955  10]
 [101 244]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:45<00:00,  5.29it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:18<00:00,  4.51it/s]


Epoch   5: tr_loss 0.736, as acc/f1 0.631/0.518, view acc/f1 0.924/0.900
[[146  34 106]
 [ 24  55 108]
 [119  92 626]]
[[930  35]
 [ 64 281]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:40<00:00,  5.97it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:12<00:00,  6.75it/s]


Epoch   6: tr_loss 0.716, as acc/f1 0.656/0.520, view acc/f1 0.922/0.891
[[129  31 126]
 [ 15  48 124]
 [ 88  66 683]]
[[956   9]
 [ 93 252]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:44<00:00,  5.38it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:16<00:00,  4.86it/s]


Epoch   7: tr_loss 0.713, as acc/f1 0.619/0.524, view acc/f1 0.937/0.914
[[134  47 105]
 [ 14  75  98]
 [110 125 602]]
[[952  13]
 [ 70 275]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:41<00:00,  5.83it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:12<00:00,  6.33it/s]


Epoch   8: tr_loss 0.691, as acc/f1 0.643/0.515, view acc/f1 0.934/0.914
[[161  18 107]
 [ 28  41 118]
 [141  56 640]]
[[933  32]
 [ 54 291]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:43<00:00,  5.56it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:16<00:00,  4.98it/s]


Epoch   9: tr_loss 0.692, as acc/f1 0.560/0.505, view acc/f1 0.906/0.863
[[170  56  60]
 [ 31  89  67]
 [183 179 475]]
[[960   5]
 [118 227]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:39<00:00,  6.13it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:14<00:00,  5.58it/s]


Epoch  10: tr_loss 0.657, as acc/f1 0.616/0.522, view acc/f1 0.928/0.902
[[165  30  91]
 [ 33  62  92]
 [157 100 580]]
[[946  19]
 [ 75 270]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:43<00:00,  5.48it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:14<00:00,  5.66it/s]


Epoch  11: tr_loss 0.657, as acc/f1 0.587/0.514, view acc/f1 0.930/0.903
[[144  63  79]
 [ 20  86  81]
 [116 182 539]]
[[955  10]
 [ 82 263]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:40<00:00,  5.94it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:13<00:00,  6.21it/s]


Epoch  12: tr_loss 0.638, as acc/f1 0.561/0.508, view acc/f1 0.927/0.899
[[206  36  44]
 [ 47  78  62]
 [227 159 451]]
[[956   9]
 [ 86 259]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:39<00:00,  6.03it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:12<00:00,  6.46it/s]


Epoch  13: tr_loss 0.608, as acc/f1 0.635/0.531, view acc/f1 0.931/0.904
[[140  41 105]
 [ 18  67 102]
 [ 93 119 625]]
[[954  11]
 [ 80 265]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:42<00:00,  5.70it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:14<00:00,  5.76it/s]


Epoch  14: tr_loss 0.602, as acc/f1 0.627/0.532, view acc/f1 0.929/0.902
[[146  39 101]
 [ 20  71  96]
 [113 119 605]]
[[951  14]
 [ 79 266]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:41<00:00,  5.79it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:11<00:00,  7.10it/s]


Epoch  15: tr_loss 0.585, as acc/f1 0.598/0.478, view acc/f1 0.927/0.898
[[ 90  73 123]
 [  9  71 107]
 [ 55 160 622]]
[[954  11]
 [ 85 260]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:45<00:00,  5.33it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:17<00:00,  4.70it/s]


Epoch  16: tr_loss 0.583, as acc/f1 0.556/0.503, view acc/f1 0.939/0.918
[[178  52  56]
 [ 37  87  63]
 [205 168 464]]
[[947  18]
 [ 62 283]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:40<00:00,  5.92it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:11<00:00,  7.38it/s]


Epoch  17: tr_loss 0.575, as acc/f1 0.639/0.538, view acc/f1 0.931/0.910
[[140  49  97]
 [ 14  70 103]
 [ 80 130 627]]
[[930  35]
 [ 55 290]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:47<00:00,  5.06it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:19<00:00,  4.21it/s]


Epoch  18: tr_loss 0.565, as acc/f1 0.626/0.531, view acc/f1 0.933/0.910
[[155  49  82]
 [ 18  67 102]
 [107 132 598]]
[[940  25]
 [ 63 282]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:38<00:00,  6.29it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:11<00:00,  7.36it/s]


Epoch  19: tr_loss 0.547, as acc/f1 0.634/0.528, view acc/f1 0.934/0.910
[[154  44  88]
 [ 19  59 109]
 [114 105 618]]
[[948  17]
 [ 70 275]]


100%|███████████████████████████████████████████████████████████████████████████████| 241/241 [00:47<00:00,  5.07it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 82/82 [00:18<00:00,  4.48it/s]


Epoch  20: tr_loss 0.540, as acc/f1 0.597/0.516, view acc/f1 0.934/0.909
[[157  52  77]
 [ 23  72  92]
 [123 161 553]]
[[951  14]
 [ 73 272]]


  0%|                                                                                         | 0/241 [00:00<?, ?it/s]


KeyboardInterrupt: 