In [2]:
import os
import sys
import torch
import random
import configs
import numpy as np
import transformers
import torch.nn as nn
from PIL import Image
import tensorflow as tf
from tqdm.auto import tqdm
from training_utils import val_loss_accuracy, get_test_accuracy

2023-11-16 12:46:25.035579: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-16 12:46:25.035689: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-16 12:46:25.035727: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-16 12:46:25.045468: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  metric = datasets.load_metric("accuracy")


In [3]:
configs.set_seed(42)
device = configs.set_device(2)

There are 8 GPU(s) available.
We will use the GPU: NVIDIA A100-SXM4-80GB


In [4]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33msemenov-andrei-v[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
import datasets

metric = datasets.load_metric("accuracy")

data

In [6]:
from tensorflow.keras.datasets import cifar100

(train_images, train_labels), (test_images, test_labels) = cifar100.load_data()


image_paths = []
labels = []

save_dir = 'cifar100_images'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

for i in range(len(train_images)):
    image_path = os.path.join(save_dir, f"train_image_{i}.jpg")
    tf.keras.preprocessing.image.save_img(image_path, train_images[i])
    image_paths.append(image_path)
    labels.append(train_labels[i][0])


for i in range(len(test_images)):
    image_path = os.path.join(save_dir, f"test_image_{i}.jpg")
    tf.keras.preprocessing.image.save_img(image_path, test_images[i])
    image_paths.append(image_path)
    labels.append(test_labels[i][0])

print(len(train_images), "\n")
print(len(test_images), "\n")
print(len(image_paths))

50000 

10000 

60000


In [7]:
import dataset_utils

dataset = dataset_utils.CLIPDataset(list_image_path=image_paths, list_txt=labels)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 3000, 7000])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=dataset_utils.collate_fn, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=dataset_utils.collate_fn, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=dataset_utils.collate_fn, pin_memory=True)

assert len(dataset) == len(image_paths)
print("Dataset size: {}".format(len(dataset)), "\n")

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 3000, 7000])
print("Train set: {}".format(len(train_dataset)), "\n")
print("Validation set: {}".format(len(val_dataset)), "\n")
print("Test set: {}".format(len(test_dataset)))

Dataset size: 60000 

Train set: 50000 

Validation set: 3000 

Test set: 7000


In [8]:
def remove_prefixes(strings):
    prefixes = ['a', 'an', 'the']
    result = []

    for string in strings:
        words = string.split()
        if words[0].lower() in prefixes:
            result.append(' '.join(words[1:]))
        else:
            result.append(string)

    return result

with open("conceptnet_cifar10_filtered_new.txt", "r") as f:
    concepts = f.read().lower().split("\n")
    concepts = remove_prefixes(concepts)

In [8]:
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer

train_loader_preprocessed = dataset_utils.preprocess_loader(train_loader, concepts)
val_loader_preprocessed = dataset_utils.preprocess_loader(val_loader, concepts)
test_loader_preprocessed = dataset_utils.preprocess_loader(test_loader, concepts)

100%|███████████████████████████████████████| 1563/1563 [04:49<00:00,  5.40it/s]
100%|███████████████████████████████████████████| 94/94 [00:17<00:00,  5.47it/s]
100%|█████████████████████████████████████████| 219/219 [00:39<00:00,  5.51it/s]


model

In [12]:
class TuningCLIPhead(nn.Module):
    def __init__(self, concepts: list, model_name: str="openai/clip-vit-base-patch32"):
        super().__init__()
        self.clip = transformers.CLIPModel.from_pretrained(model_name)
        self.processor = transformers.CLIPProcessor.from_pretrained(model_name)
        for param in self.clip.parameters():
            param.requires_grad=False
        self.head = nn.Linear(len(concepts), 100, bias=False)

    def forward(self, **batch):
        out = self.clip(**batch).logits_per_image
        return self.head(out)

training

In [13]:
torch.cuda.empty_cache()

In [14]:
model = TuningCLIPhead(concepts=concepts)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 12000


In [13]:
@torch.no_grad()
@torch.cuda.amp.autocast()
def val_loss_accuracy(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    val_losses = []
    for batch in tqdm(loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        logits = model(**inputs).squeeze(0)
        targets = torch.tensor(labels, dtype=torch.long)
        loss = criterion(logits, targets.to(device))
        val_losses.append(loss.item())
        preds = torch.argmax(logits, dim=-1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels)

    val_accuracy = metric.compute(predictions=all_preds, references=all_labels)
    avg_val_loss = sum(val_losses) / len(val_losses)
    
    return val_accuracy, avg_val_loss

@torch.no_grad()
@torch.cuda.amp.autocast()
def get_test_accuracy(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    for batch in tqdm(loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        logits = model(**inputs).squeeze(0)
        targets = torch.tensor(labels, dtype=torch.long)
        preds = torch.argmax(logits, dim=-1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels)
        
    test_accuracy = metric.compute(predictions=all_preds, references=all_labels)
    
    return test_accuracy

In [16]:
model.to(device)
run = wandb.init(project='cifar100-head', entity='semenov-andrei-v')

num_epochs = 60
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, train_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/60, Loss: 5.463306427001953, Validation Accuracy: {'accuracy': 0.07133333333333333}


Epoch 2/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/60, Loss: 4.903232574462891, Validation Accuracy: {'accuracy': 0.12666666666666668}


Epoch 3/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/60, Loss: 4.532017230987549, Validation Accuracy: {'accuracy': 0.16566666666666666}


Epoch 4/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/60, Loss: 4.275141716003418, Validation Accuracy: {'accuracy': 0.19866666666666666}


Epoch 5/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5/60, Loss: 4.102478504180908, Validation Accuracy: {'accuracy': 0.23066666666666666}


Epoch 6/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6/60, Loss: 3.9665961265563965, Validation Accuracy: {'accuracy': 0.25333333333333335}


Epoch 7/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 7/60, Loss: 3.8958752155303955, Validation Accuracy: {'accuracy': 0.273}


Epoch 8/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 8/60, Loss: 3.866090774536133, Validation Accuracy: {'accuracy': 0.29033333333333333}


Epoch 9/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 9/60, Loss: 3.8298683166503906, Validation Accuracy: {'accuracy': 0.29733333333333334}


Epoch 10/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 10/60, Loss: 3.7827839851379395, Validation Accuracy: {'accuracy': 0.31066666666666665}


Epoch 11/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 11/60, Loss: 3.7235217094421387, Validation Accuracy: {'accuracy': 0.321}


Epoch 12/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 12/60, Loss: 3.651564121246338, Validation Accuracy: {'accuracy': 0.33266666666666667}


Epoch 13/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 13/60, Loss: 3.5711171627044678, Validation Accuracy: {'accuracy': 0.33766666666666667}


Epoch 14/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 14/60, Loss: 3.487610101699829, Validation Accuracy: {'accuracy': 0.3446666666666667}


Epoch 15/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 15/60, Loss: 3.4069440364837646, Validation Accuracy: {'accuracy': 0.35133333333333333}


Epoch 16/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 16/60, Loss: 3.33453369140625, Validation Accuracy: {'accuracy': 0.359}


Epoch 17/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 17/60, Loss: 3.2750816345214844, Validation Accuracy: {'accuracy': 0.36466666666666664}


Epoch 18/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 18/60, Loss: 3.2300734519958496, Validation Accuracy: {'accuracy': 0.368}


Epoch 19/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 19/60, Loss: 3.1985647678375244, Validation Accuracy: {'accuracy': 0.37466666666666665}


Epoch 20/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 20/60, Loss: 3.1886956691741943, Validation Accuracy: {'accuracy': 0.37666666666666665}


Epoch 21/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 21/60, Loss: 3.1685791015625, Validation Accuracy: {'accuracy': 0.379}


Epoch 22/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 22/60, Loss: 3.145977020263672, Validation Accuracy: {'accuracy': 0.38466666666666666}


Epoch 23/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 23/60, Loss: 3.127363681793213, Validation Accuracy: {'accuracy': 0.38766666666666666}


Epoch 24/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 24/60, Loss: 3.115309476852417, Validation Accuracy: {'accuracy': 0.392}


Epoch 25/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 25/60, Loss: 3.109464168548584, Validation Accuracy: {'accuracy': 0.394}


Epoch 26/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 26/60, Loss: 3.1078193187713623, Validation Accuracy: {'accuracy': 0.3963333333333333}


Epoch 27/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 27/60, Loss: 3.1089394092559814, Validation Accuracy: {'accuracy': 0.3983333333333333}


Epoch 28/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 28/60, Loss: 3.111908435821533, Validation Accuracy: {'accuracy': 0.4}


Epoch 29/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 29/60, Loss: 3.115781307220459, Validation Accuracy: {'accuracy': 0.401}


Epoch 30/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 30/60, Loss: 3.1197266578674316, Validation Accuracy: {'accuracy': 0.403}


Epoch 31/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 31/60, Loss: 3.123175621032715, Validation Accuracy: {'accuracy': 0.4046666666666667}


Epoch 32/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 32/60, Loss: 3.1257710456848145, Validation Accuracy: {'accuracy': 0.4053333333333333}


Epoch 33/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 33/60, Loss: 3.1272764205932617, Validation Accuracy: {'accuracy': 0.4033333333333333}


Epoch 34/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 34/60, Loss: 3.1274819374084473, Validation Accuracy: {'accuracy': 0.4046666666666667}


Epoch 35/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 35/60, Loss: 3.126197338104248, Validation Accuracy: {'accuracy': 0.4063333333333333}


Epoch 36/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 36/60, Loss: 3.1233034133911133, Validation Accuracy: {'accuracy': 0.4063333333333333}


Epoch 37/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 37/60, Loss: 3.1187357902526855, Validation Accuracy: {'accuracy': 0.4076666666666667}


Epoch 38/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 38/60, Loss: 3.112337827682495, Validation Accuracy: {'accuracy': 0.4096666666666667}


Epoch 39/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 39/60, Loss: 3.10394549369812, Validation Accuracy: {'accuracy': 0.41}


Epoch 40/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 40/60, Loss: 3.093994379043579, Validation Accuracy: {'accuracy': 0.4093333333333333}


Epoch 41/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 41/60, Loss: 3.0836637020111084, Validation Accuracy: {'accuracy': 0.409}


Epoch 42/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 42/60, Loss: 3.0742428302764893, Validation Accuracy: {'accuracy': 0.41}


Epoch 43/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 43/60, Loss: 3.0660905838012695, Validation Accuracy: {'accuracy': 0.4116666666666667}


Epoch 44/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 44/60, Loss: 3.0588932037353516, Validation Accuracy: {'accuracy': 0.41333333333333333}


Epoch 45/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 45/60, Loss: 3.052229642868042, Validation Accuracy: {'accuracy': 0.4156666666666667}


Epoch 46/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 46/60, Loss: 3.0457992553710938, Validation Accuracy: {'accuracy': 0.417}


Epoch 47/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 47/60, Loss: 3.0394349098205566, Validation Accuracy: {'accuracy': 0.4156666666666667}


Epoch 48/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 48/60, Loss: 3.033052682876587, Validation Accuracy: {'accuracy': 0.4156666666666667}


Epoch 49/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 49/60, Loss: 3.0266506671905518, Validation Accuracy: {'accuracy': 0.4166666666666667}


Epoch 50/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 50/60, Loss: 3.020282745361328, Validation Accuracy: {'accuracy': 0.41833333333333333}


Epoch 51/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 51/60, Loss: 3.0140442848205566, Validation Accuracy: {'accuracy': 0.4186666666666667}


Epoch 52/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 52/60, Loss: 3.008023500442505, Validation Accuracy: {'accuracy': 0.42133333333333334}


Epoch 53/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 53/60, Loss: 3.00228214263916, Validation Accuracy: {'accuracy': 0.422}


Epoch 54/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 54/60, Loss: 2.996842384338379, Validation Accuracy: {'accuracy': 0.422}


Epoch 55/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 55/60, Loss: 2.9917078018188477, Validation Accuracy: {'accuracy': 0.42333333333333334}


Epoch 56/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 56/60, Loss: 2.9868929386138916, Validation Accuracy: {'accuracy': 0.4246666666666667}


Epoch 57/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 57/60, Loss: 2.9823927879333496, Validation Accuracy: {'accuracy': 0.426}


Epoch 58/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 58/60, Loss: 2.978238582611084, Validation Accuracy: {'accuracy': 0.42733333333333334}


Epoch 59/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 59/60, Loss: 2.9744746685028076, Validation Accuracy: {'accuracy': 0.427}


Epoch 60/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 60/60, Loss: 2.9711647033691406, Validation Accuracy: {'accuracy': 0.429}


  0%|          | 0/1563 [00:00<?, ?it/s]

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy/Test,▁
Accuracy/Validation,▁▂▃▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇████████████████████
Loss/Train,█▆▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/Validation,█▆▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Accuracy/Test,0.45096
Accuracy/Validation,0.429
Loss/Train,2.97116
Loss/Validation,2.53374


## with cifar100 concepts only

In [9]:
with open("conceptnet_cifar100_filtered_new.txt", "r") as f:
    concepts = f.read().lower().split("\n")
    concepts = remove_prefixes(concepts)

In [10]:
len(concepts)

944

In [11]:
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer

train_loader_preprocessed = dataset_utils.preprocess_loader(train_loader, concepts)
val_loader_preprocessed = dataset_utils.preprocess_loader(val_loader, concepts)
test_loader_preprocessed = dataset_utils.preprocess_loader(test_loader, concepts)

100%|███████████████████████████████████████| 1563/1563 [05:13<00:00,  4.99it/s]
100%|███████████████████████████████████████████| 94/94 [00:19<00:00,  4.92it/s]
100%|█████████████████████████████████████████| 219/219 [00:44<00:00,  4.89it/s]


In [19]:
torch.cuda.empty_cache()

In [20]:
model = TuningCLIPhead(concepts=concepts)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 94400


In [21]:
model.to(device)
run = wandb.init(project='cifar100-head', entity='semenov-andrei-v')

num_epochs = 60
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, test_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/60, Loss: 26.086469650268555, Validation Accuracy: {'accuracy': 0.11066666666666666}


Epoch 2/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/60, Loss: 12.031149864196777, Validation Accuracy: {'accuracy': 0.16833333333333333}


Epoch 3/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/60, Loss: 12.257671356201172, Validation Accuracy: {'accuracy': 0.21066666666666667}


Epoch 4/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/60, Loss: 15.282087326049805, Validation Accuracy: {'accuracy': 0.184}


Epoch 5/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5/60, Loss: 15.90857219696045, Validation Accuracy: {'accuracy': 0.221}


Epoch 6/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6/60, Loss: 19.34119987487793, Validation Accuracy: {'accuracy': 0.21633333333333332}


Epoch 7/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 7/60, Loss: 22.17243194580078, Validation Accuracy: {'accuracy': 0.27466666666666667}


Epoch 8/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 8/60, Loss: 16.483795166015625, Validation Accuracy: {'accuracy': 0.24966666666666668}


Epoch 9/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 9/60, Loss: 17.13865852355957, Validation Accuracy: {'accuracy': 0.25766666666666665}


Epoch 10/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 10/60, Loss: 18.489322662353516, Validation Accuracy: {'accuracy': 0.29533333333333334}


Epoch 11/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 11/60, Loss: 16.34471893310547, Validation Accuracy: {'accuracy': 0.2633333333333333}


Epoch 12/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 12/60, Loss: 20.72008514404297, Validation Accuracy: {'accuracy': 0.30833333333333335}


Epoch 13/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 13/60, Loss: 14.050576210021973, Validation Accuracy: {'accuracy': 0.298}


Epoch 14/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 14/60, Loss: 14.777558326721191, Validation Accuracy: {'accuracy': 0.3383333333333333}


Epoch 15/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 15/60, Loss: 23.919715881347656, Validation Accuracy: {'accuracy': 0.33566666666666667}


Epoch 16/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 16/60, Loss: 22.782711029052734, Validation Accuracy: {'accuracy': 0.2976666666666667}


Epoch 17/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 17/60, Loss: 22.52137565612793, Validation Accuracy: {'accuracy': 0.326}


Epoch 18/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 18/60, Loss: 11.632498741149902, Validation Accuracy: {'accuracy': 0.3546666666666667}


Epoch 19/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 19/60, Loss: 16.93497657775879, Validation Accuracy: {'accuracy': 0.332}


Epoch 20/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 20/60, Loss: 15.75847053527832, Validation Accuracy: {'accuracy': 0.35333333333333333}


Epoch 21/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 21/60, Loss: 16.80501937866211, Validation Accuracy: {'accuracy': 0.3433333333333333}


Epoch 22/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 22/60, Loss: 24.949922561645508, Validation Accuracy: {'accuracy': 0.3536666666666667}


Epoch 23/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 23/60, Loss: 15.148921966552734, Validation Accuracy: {'accuracy': 0.34933333333333333}


Epoch 24/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 24/60, Loss: 22.779537200927734, Validation Accuracy: {'accuracy': 0.3263333333333333}


Epoch 25/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 25/60, Loss: 15.233345031738281, Validation Accuracy: {'accuracy': 0.36433333333333334}


Epoch 26/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 26/60, Loss: 19.45241928100586, Validation Accuracy: {'accuracy': 0.38866666666666666}


Epoch 27/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 27/60, Loss: 13.310646057128906, Validation Accuracy: {'accuracy': 0.3436666666666667}


Epoch 28/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 28/60, Loss: 12.36600399017334, Validation Accuracy: {'accuracy': 0.36433333333333334}


Epoch 29/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 29/60, Loss: 16.72597312927246, Validation Accuracy: {'accuracy': 0.37166666666666665}


Epoch 30/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 30/60, Loss: 15.903167724609375, Validation Accuracy: {'accuracy': 0.37}


Epoch 31/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 31/60, Loss: 14.029983520507812, Validation Accuracy: {'accuracy': 0.3463333333333333}


Epoch 32/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 32/60, Loss: 21.321746826171875, Validation Accuracy: {'accuracy': 0.357}


Epoch 33/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 33/60, Loss: 16.57905387878418, Validation Accuracy: {'accuracy': 0.4033333333333333}


Epoch 34/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 34/60, Loss: 16.83977508544922, Validation Accuracy: {'accuracy': 0.36}


Epoch 35/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 35/60, Loss: 12.50976848602295, Validation Accuracy: {'accuracy': 0.37066666666666664}


Epoch 36/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 36/60, Loss: 10.182174682617188, Validation Accuracy: {'accuracy': 0.39066666666666666}


Epoch 37/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 37/60, Loss: 14.254904747009277, Validation Accuracy: {'accuracy': 0.3933333333333333}


Epoch 38/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 38/60, Loss: 14.070018768310547, Validation Accuracy: {'accuracy': 0.38166666666666665}


Epoch 39/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 39/60, Loss: 18.82832908630371, Validation Accuracy: {'accuracy': 0.3963333333333333}


Epoch 40/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 40/60, Loss: 16.76792335510254, Validation Accuracy: {'accuracy': 0.374}


Epoch 41/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 41/60, Loss: 11.992900848388672, Validation Accuracy: {'accuracy': 0.38066666666666665}


Epoch 42/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 42/60, Loss: 12.539794921875, Validation Accuracy: {'accuracy': 0.374}


Epoch 43/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 43/60, Loss: 19.90580177307129, Validation Accuracy: {'accuracy': 0.3496666666666667}


Epoch 44/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 44/60, Loss: 16.561410903930664, Validation Accuracy: {'accuracy': 0.373}


Epoch 45/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 45/60, Loss: 24.137067794799805, Validation Accuracy: {'accuracy': 0.339}


Epoch 46/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 46/60, Loss: 15.249022483825684, Validation Accuracy: {'accuracy': 0.37}


Epoch 47/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 47/60, Loss: 16.377527236938477, Validation Accuracy: {'accuracy': 0.40166666666666667}


Epoch 48/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 48/60, Loss: 21.770893096923828, Validation Accuracy: {'accuracy': 0.4063333333333333}


Epoch 49/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 49/60, Loss: 18.278390884399414, Validation Accuracy: {'accuracy': 0.38766666666666666}


Epoch 50/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 50/60, Loss: 17.770038604736328, Validation Accuracy: {'accuracy': 0.372}


Epoch 51/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 51/60, Loss: 14.175053596496582, Validation Accuracy: {'accuracy': 0.425}


Epoch 52/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 52/60, Loss: 14.829466819763184, Validation Accuracy: {'accuracy': 0.4036666666666667}


Epoch 53/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 53/60, Loss: 14.541029930114746, Validation Accuracy: {'accuracy': 0.42533333333333334}


Epoch 54/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 54/60, Loss: 15.867446899414062, Validation Accuracy: {'accuracy': 0.424}


Epoch 55/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 55/60, Loss: 14.792837142944336, Validation Accuracy: {'accuracy': 0.38466666666666666}


Epoch 56/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 56/60, Loss: 10.047399520874023, Validation Accuracy: {'accuracy': 0.432}


Epoch 57/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 57/60, Loss: 29.417081832885742, Validation Accuracy: {'accuracy': 0.4066666666666667}


Epoch 58/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 58/60, Loss: 20.79485511779785, Validation Accuracy: {'accuracy': 0.366}


Epoch 59/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 59/60, Loss: 11.908674240112305, Validation Accuracy: {'accuracy': 0.3933333333333333}


Epoch 60/ 60:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 60/60, Loss: 15.43360424041748, Validation Accuracy: {'accuracy': 0.376}


  0%|          | 0/1563 [00:00<?, ?it/s]

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy/Test,▁
Accuracy/Validation,▁▂▃▃▅▄▅▄▅▆▅▆▆▆▆▆▇▇▇▇▆▆▆▇▇▇▇▇▆▇▇▇▇▇▇█▇█▇▇
Loss/Train,█▂▃▄▆▄▅▄▃▃▇▆▄▃█▃▃▅▂▄▃▆▄▂▃▃▄▂▅▄▃▄▅▄▃▃▃▁▆▃
Loss/Validation,▆▆▆█▄▆▄▃▄▁▄▂▃▃▂▃▄▁▂▂▃▅▄▅▂▄▅▂█▅▄▁▃▅▄▂▆▁▆▅

0,1
Accuracy/Test,0.3884
Accuracy/Validation,0.376
Loss/Train,15.4336
Loss/Validation,13.64849
