In [1]:
import os
import sys
import torch
import random
import configs
import numpy as np
import torch.nn as nn
from PIL import Image
import tensorflow as tf
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

2023-11-24 12:25:00.874935: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-24 12:25:00.875066: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-24 12:25:00.875308: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-24 12:25:00.922487: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
configs.set_seed(42)

In [3]:
device = configs.set_device(3)

There are 8 GPU(s) available.
We will use the GPU: NVIDIA A100-PCIE-40GB


In [None]:
from tensorflow.keras.datasets import cifar10

(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()


image_paths = []
labels = []

save_dir = 'cifar10_images'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

for i in range(len(train_images)):
    image_path = os.path.join(save_dir, f"train_image_{i}.jpg")
    tf.keras.preprocessing.image.save_img(image_path, train_images[i])
    image_paths.append(image_path)
    labels.append(train_labels[i][0])


for i in range(len(test_images)):
    image_path = os.path.join(save_dir, f"test_image_{i}.jpg")
    tf.keras.preprocessing.image.save_img(image_path, test_images[i])
    image_paths.append(image_path)
    labels.append(test_labels[i][0])

print(len(train_images), "\n")
print(len(test_images), "\n")
print(len(image_paths))

In [None]:
classes = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}

In [None]:
titles = []
for im in tqdm(range(len(image_paths))):
    titles.append(classes[labels[im]])
    
assert len(titles) == len(labels) == len(image_paths)
assert titles[59000] == 'ship'

In [None]:
def remove_prefixes(strings):
    prefixes = ['a', 'an', 'the']
    result = []

    for string in strings:
        words = string.split()
        if words[0].lower() in prefixes:
            result.append(' '.join(words[1:]))
        else:
            result.append(string)

    return result

with open("conceptnet_cifar10_filtered_new.txt", "r") as f:
    concepts = f.read().lower().split("\n")
    concepts = remove_prefixes(concepts)

In [8]:
class CLIPDataset():
    def __init__(self, list_image_path, list_txt):
        self.image_path = list_image_path
        self.title  = list_txt

    def __len__(self):
        return len(self.title)

    def __getitem__(self, idx):
        image = Image.open(self.image_path[idx])
        title = self.title[idx]
        return image, title


def collate_fn(batch):
    return {
        'image': [x[0] for x in batch],
        'title': [x[1] for x in batch]
    }

dataset = CLIPDataset(list_image_path=image_paths, list_txt=labels) # but it can be with <<titles>> to get textual annotations for class labels
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 3000, 7000])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, pin_memory=True)

assert len(dataset) == len(image_paths)
print("Dataset size: {}".format(len(dataset)), "\n")

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 3000, 7000])
print("Train set: {}".format(len(train_dataset)), "\n")
print("Validation set: {}".format(len(val_dataset)), "\n")
print("Test set: {}".format(len(test_dataset)))

Dataset size: 60000 

Train set: 50000 

Validation set: 3000 

Test set: 7000


In [9]:
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer

def preprocess_loader(loader, concepts: list):
    preprocessed_batches = []
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    for batch in tqdm(loader):
        preprocessed_batch = preprocess_batch(batch, processor, concepts)
        preprocessed_batches.append(preprocessed_batch)
    return preprocessed_batches

def preprocess_batch(batch, processor, concepts: list):
    return processor(text=concepts, images=batch['image'], return_tensors="pt", padding=True), batch['title']

In [10]:
train_loader_preprocessed = preprocess_loader(train_loader, concepts)
val_loader_preprocessed = preprocess_loader(val_loader, concepts)
test_loader_preprocessed = preprocess_loader(test_loader, concepts)

  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

  0%|          | 0/219 [00:00<?, ?it/s]

In [11]:
import transformers

class TuningCLIPViThead(nn.Module):
    def __init__(self, concepts: list, classes: dict, model_name: str="openai/clip-vit-base-patch32"):
        super().__init__()
        self.clip = transformers.CLIPModel.from_pretrained(model_name)
        self.processor = transformers.CLIPProcessor.from_pretrained(model_name)
        for param in self.clip.text_model.parameters():
            param.requires_grad = False
        for name, param in self.clip.vision_model.named_parameters():
            if 'self_attn.v_proj' in name or 'self_attn.q_proj' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        self.head = nn.Linear(len(concepts), len(classes), bias=False)
        
    def forward(self, **batch):
        out = self.clip(**batch).logits_per_image
        return self.head(out)

In [43]:
model = TuningCLIPViThead(concepts, classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

Number of trainable parameters is: 14830769


In [44]:
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

In [12]:
import datasets

metric = datasets.load_metric("accuracy")

  metric = datasets.load_metric("accuracy")


In [13]:
@torch.no_grad()
@torch.cuda.amp.autocast()
def val_loss_accuracy(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    val_losses = []
    for batch in tqdm(loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        logits = model(**inputs).squeeze(0)
        similarity = model.clip(**inputs).logits_per_image.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        loss = criterion(logits, targets.to(device)) + clip_loss(similarity)
        val_losses.append(loss.item())
        preds = torch.argmax(logits, dim=-1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels)

    val_accuracy = metric.compute(predictions=all_preds, references=all_labels)
    avg_val_loss = sum(val_losses) / len(val_losses)
    
    return val_accuracy, avg_val_loss

@torch.no_grad()
@torch.cuda.amp.autocast()
def get_test_accuracy(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    for batch in tqdm(loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        logits = model(**inputs).squeeze(0)
        targets = torch.tensor(labels, dtype=torch.long)
        preds = torch.argmax(logits, dim=-1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels)
        
    test_accuracy = metric.compute(predictions=all_preds, references=all_labels)
    
    return test_accuracy

In [14]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33msemenov-andrei-v[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [13]:
with open("all_concepts.txt", "r") as f:
    all_concepts = f.read().lower().split("\n")
    all_concepts = remove_prefixes(all_concepts)

train_loader_preprocessed = preprocess_loader(train_loader, all_concepts)
val_loader_preprocessed = preprocess_loader(val_loader, all_concepts)
test_loader_preprocessed = preprocess_loader(test_loader, all_concepts)

  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

  0%|          | 0/219 [00:00<?, ?it/s]

## Bigger Concepts set

In [22]:
model = TuningCLIPViThead(all_concepts, classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 14880079


In [23]:
model.to(device)
run = wandb.init(project='cifar10-ViT-head', entity='semenov-andrei-v')

num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, train_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/15, Loss: 2.32181715965271, Validation Accuracy: {'accuracy': 0.10466666666666667}


Epoch 2/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/15, Loss: 2.2297825813293457, Validation Accuracy: {'accuracy': 0.156}


Epoch 3/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/15, Loss: 2.3721611499786377, Validation Accuracy: {'accuracy': 0.148}


Epoch 4/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/15, Loss: 2.2155096530914307, Validation Accuracy: {'accuracy': 0.252}


Epoch 5/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5/15, Loss: 2.334630250930786, Validation Accuracy: {'accuracy': 0.2743333333333333}


Epoch 6/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6/15, Loss: 2.1388494968414307, Validation Accuracy: {'accuracy': 0.27066666666666667}


Epoch 7/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 7/15, Loss: 1.8941729068756104, Validation Accuracy: {'accuracy': 0.33166666666666667}


Epoch 8/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 8/15, Loss: 1.567327618598938, Validation Accuracy: {'accuracy': 0.381}


Epoch 9/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 9/15, Loss: 1.6679322719573975, Validation Accuracy: {'accuracy': 0.466}


Epoch 10/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 10/15, Loss: 1.4267778396606445, Validation Accuracy: {'accuracy': 0.481}


Epoch 11/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 11/15, Loss: 1.18190598487854, Validation Accuracy: {'accuracy': 0.5016666666666667}


Epoch 12/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 12/15, Loss: 1.3190885782241821, Validation Accuracy: {'accuracy': 0.5046666666666667}


Epoch 13/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 13/15, Loss: 1.437595248222351, Validation Accuracy: {'accuracy': 0.5276666666666666}


Epoch 14/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 14/15, Loss: 1.0873725414276123, Validation Accuracy: {'accuracy': 0.5326666666666666}


Epoch 15/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 15/15, Loss: 1.0004063844680786, Validation Accuracy: {'accuracy': 0.5353333333333333}


  0%|          | 0/1563 [00:00<?, ?it/s]

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy/Test,▁
Accuracy/Validation,▁▂▂▃▄▄▅▅▇▇▇████
Loss/Train,█▇█▇█▇▆▄▄▃▂▃▃▁▁
Loss/Validation,▇▅█▄▄▄▃▃▂▁▁▁▁▁▁

0,1
Accuracy/Test,0.5934
Accuracy/Validation,0.53533
Loss/Train,1.00041
Loss/Validation,1.29259


In [42]:
torch.cuda.empty_cache()

test with another batch sizes

In [23]:
class CLIPDataset():
    def __init__(self, list_image_path, list_txt):
        self.image_path = list_image_path
        self.title  = list_txt

    def __len__(self):
        return len(self.title)

    def __getitem__(self, idx):
        image = Image.open(self.image_path[idx])
        title = self.title[idx]
        return image, title


def collate_fn(batch):
    return {
        'image': [x[0] for x in batch],
        'title': [x[1] for x in batch]
    }

dataset = CLIPDataset(list_image_path=image_paths, list_txt=labels) # but it can be with <<titles>> to get textual annotations for class labels
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 3000, 7000])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn, pin_memory=True)

assert len(dataset) == len(image_paths)
print("Dataset size: {}".format(len(dataset)), "\n")

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 3000, 7000])
print("Train set: {}".format(len(train_dataset)), "\n")
print("Validation set: {}".format(len(val_dataset)), "\n")
print("Test set: {}".format(len(test_dataset)))

Dataset size: 60000 

Train set: 50000 

Validation set: 3000 

Test set: 7000


In [24]:
train_loader_preprocessed = preprocess_loader(train_loader, concepts)
val_loader_preprocessed = preprocess_loader(val_loader, concepts)
test_loader_preprocessed = preprocess_loader(test_loader, concepts)

  0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

In [28]:
model = TuningCLIPViThead(concepts, classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 14830769


In [29]:
model.to(device)
run = wandb.init(project='cifar10-ViT-head', entity='semenov-andrei-v')

num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, train_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 1/15, Loss: 1.2912063598632812, Validation Accuracy: {'accuracy': 0.31066666666666665}


Epoch 2/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 2/15, Loss: 0.978803813457489, Validation Accuracy: {'accuracy': 0.402}


Epoch 3/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 3/15, Loss: 0.6990076899528503, Validation Accuracy: {'accuracy': 0.423}


Epoch 4/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 4/15, Loss: 0.8373039364814758, Validation Accuracy: {'accuracy': 0.431}


Epoch 5/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 5/15, Loss: 0.7066726684570312, Validation Accuracy: {'accuracy': 0.44033333333333335}


Epoch 6/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 6/15, Loss: 0.8990509510040283, Validation Accuracy: {'accuracy': 0.451}


Epoch 7/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 7/15, Loss: 1.0511701107025146, Validation Accuracy: {'accuracy': 0.453}


Epoch 8/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 8/15, Loss: 1.11309015750885, Validation Accuracy: {'accuracy': 0.454}


Epoch 9/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 9/15, Loss: 0.7665904760360718, Validation Accuracy: {'accuracy': 0.449}


Epoch 10/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 10/15, Loss: 0.6328774094581604, Validation Accuracy: {'accuracy': 0.45466666666666666}


Epoch 11/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 11/15, Loss: 0.6267906427383423, Validation Accuracy: {'accuracy': 0.44666666666666666}


Epoch 12/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 12/15, Loss: 0.9109774827957153, Validation Accuracy: {'accuracy': 0.45266666666666666}


Epoch 13/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 13/15, Loss: 0.6648542881011963, Validation Accuracy: {'accuracy': 0.448}


Epoch 14/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 14/15, Loss: 0.8962976932525635, Validation Accuracy: {'accuracy': 0.4623333333333333}


Epoch 15/ 15:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Epoch 15/15, Loss: 0.8846495151519775, Validation Accuracy: {'accuracy': 0.459}


  0%|          | 0/25000 [00:00<?, ?it/s]

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy/Test,▁
Accuracy/Validation,▁▅▆▇▇▇██▇█▇█▇██
Loss/Train,█▅▂▃▂▄▅▆▂▁▁▄▁▄▄
Loss/Validation,█▄▃▂▂▂▂▂▁▁▂▁▂▁▁

0,1
Accuracy/Test,0.47488
Accuracy/Validation,0.459
Loss/Train,0.88465
Loss/Validation,1.43594


## Test with two losses

In [18]:
torch.cuda.empty_cache()

In [19]:
model = TuningCLIPViThead(concepts, classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 14830769


In [20]:
def contrastive_loss(logits, dim: int):
    neg_ce = -torch.diag(F.log_softmax(logits, dim=dim))
    return neg_ce.mean() # set back to -

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """
    Args:
        similarity: is equal to logits_per_image
    """
    caption_loss = contrastive_loss(similarity, dim=0)
    image_loss = contrastive_loss(similarity, dim=1)
    return (caption_loss + image_loss) / 2.0

In [21]:
model.to(device)
run = wandb.init(project='cifar10-ViT-head', entity='semenov-andrei-v')

num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        similarity = model.clip(**inputs).logits_per_image.to(device)
        loss = criterion(logits, targets.to(device)) + clip_loss(similarity)
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {loss.item()}, Validation Loss: {avg_val_loss}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, test_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/20, Training Loss: 5.219966888427734, Validation Loss: 5.249035444665463, Validation Accuracy: {'accuracy': 0.309}


Epoch 2/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/20, Training Loss: 5.029921531677246, Validation Loss: 4.92404542070754, Validation Accuracy: {'accuracy': 0.4693333333333333}


Epoch 3/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/20, Training Loss: 4.396029949188232, Validation Loss: 4.826614182046119, Validation Accuracy: {'accuracy': 0.5033333333333333}


Epoch 4/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/20, Training Loss: 4.2886528968811035, Validation Loss: 4.7818102633699455, Validation Accuracy: {'accuracy': 0.528}


Epoch 5/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5/20, Training Loss: 4.263652801513672, Validation Loss: 4.778582912810306, Validation Accuracy: {'accuracy': 0.543}


Epoch 6/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6/20, Training Loss: 4.178359508514404, Validation Loss: 4.786200447285429, Validation Accuracy: {'accuracy': 0.546}


Epoch 7/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 7/20, Training Loss: 4.066055774688721, Validation Loss: 4.7960904506926845, Validation Accuracy: {'accuracy': 0.5453333333333333}


Epoch 8/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 8/20, Training Loss: 3.829378366470337, Validation Loss: 4.792938800568276, Validation Accuracy: {'accuracy': 0.547}


Epoch 9/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 9/20, Training Loss: 3.9054856300354004, Validation Loss: 4.742161329756391, Validation Accuracy: {'accuracy': 0.565}


Epoch 10/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 10/20, Training Loss: 3.94002366065979, Validation Loss: 4.760852240501566, Validation Accuracy: {'accuracy': 0.5666666666666667}


Epoch 11/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 11/20, Training Loss: 3.91104793548584, Validation Loss: 4.761262954549586, Validation Accuracy: {'accuracy': 0.564}


Epoch 12/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 12/20, Training Loss: 3.886260986328125, Validation Loss: 4.801346322323414, Validation Accuracy: {'accuracy': 0.561}


Epoch 13/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 13/20, Training Loss: 3.889329433441162, Validation Loss: 4.737304616481699, Validation Accuracy: {'accuracy': 0.566}


Epoch 14/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 14/20, Training Loss: 3.950256109237671, Validation Loss: 4.749941952685092, Validation Accuracy: {'accuracy': 0.5556666666666666}


Epoch 15/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 18/20, Training Loss: 3.696226119995117, Validation Loss: 4.77922766259376, Validation Accuracy: {'accuracy': 0.5743333333333334}


Epoch 19/ 20:   0%|          | 0/1563 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

