In [1]:
import os
import sys
import torch
import random
import configs
import numpy as np
import transformers
import torch.nn as nn
from PIL import Image
import tensorflow as tf
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

2023-12-05 22:37:47.902566: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-05 22:37:47.902647: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-05 22:37:47.902667: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-05 22:37:47.910422: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
configs.set_seed(42)

In [3]:
device = configs.set_device(2)

There are 8 GPU(s) available.
We will use the GPU: NVIDIA A100-SXM4-80GB


## Data

In [5]:
from tensorflow.keras.datasets import cifar10

(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()


image_paths = []
labels = []

save_dir = 'cifar10_images'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

for i in range(len(train_images)):
    image_path = os.path.join(save_dir, f"train_image_{i}.jpg")
    tf.keras.preprocessing.image.save_img(image_path, train_images[i])
    image_paths.append(image_path)
    labels.append(train_labels[i][0])


for i in range(len(test_images)):
    image_path = os.path.join(save_dir, f"test_image_{i}.jpg")
    tf.keras.preprocessing.image.save_img(image_path, test_images[i])
    image_paths.append(image_path)
    labels.append(test_labels[i][0])

print(len(train_images), "\n")
print(len(test_images), "\n")
print(len(image_paths))

50000 

10000 

60000


In [6]:
classes = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}

In [7]:
titles = []
for im in tqdm(range(len(image_paths))):
    titles.append(classes[labels[im]])

  0%|          | 0/60000 [00:00<?, ?it/s]

In [8]:
assert len(titles) == len(labels) == len(image_paths)
assert titles[59000] == 'ship'

In [9]:
def remove_prefixes(strings):
    prefixes = ['a', 'an', 'the']
    result = []

    for string in strings:
        words = string.split()
        if words[0].lower() in prefixes:
            result.append(' '.join(words[1:]))
        else:
            result.append(string)

    return result

with open("conceptnet_cifar10_filtered_new.txt", "r") as f:
    concepts = f.read().lower().split("\n")
    concepts = remove_prefixes(concepts)

In [10]:
class CLIPDataset():
    def __init__(self, list_image_path, list_txt):
        self.image_path = list_image_path
        self.title  = list_txt

    def __len__(self):
        return len(self.title)

    def __getitem__(self, idx):
        image = Image.open(self.image_path[idx])
        title = self.title[idx]
        return image, title


def collate_fn(batch):
    return {
        'image': [x[0] for x in batch],
        'title': [x[1] for x in batch]
    }

dataset = CLIPDataset(list_image_path=image_paths, list_txt=labels) # but it can be with <<titles>> to get textual annotations for class labels
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 3000, 7000])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn, pin_memory=True)

assert len(dataset) == len(image_paths)
print("Dataset size: {}".format(len(dataset)), "\n")

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [50000, 3000, 7000])
print("Train set: {}".format(len(train_dataset)), "\n")
print("Validation set: {}".format(len(val_dataset)), "\n")
print("Test set: {}".format(len(test_dataset)))

Dataset size: 60000 

Train set: 50000 

Validation set: 3000 

Test set: 7000


In [11]:
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer

def preprocess_loader(loader, concepts: list):
    preprocessed_batches = []
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    for batch in tqdm(loader):
        preprocessed_batch = preprocess_batch(batch, processor, concepts)
        preprocessed_batches.append(preprocessed_batch)
    return preprocessed_batches

def preprocess_batch(batch, processor, concepts: list):
    return processor(text=concepts, images=batch['image'], return_tensors="pt", padding=True), batch['title']

In [12]:
train_loader_preprocessed = preprocess_loader(train_loader, concepts)
val_loader_preprocessed = preprocess_loader(val_loader, concepts)
test_loader_preprocessed = preprocess_loader(test_loader, concepts)

  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

  0%|          | 0/219 [00:00<?, ?it/s]

In [13]:
import transformers

class TuningCLIPhead(nn.Module):
    def __init__(self, model_name: str="openai/clip-vit-base-patch32", concepts: list=concepts, classes: dict=classes):
        super().__init__()
        self.clip = transformers.CLIPModel.from_pretrained(model_name)
        self.processor = transformers.CLIPProcessor.from_pretrained(model_name)
        for param in self.clip.parameters():
            param.requires_grad=False
        self.head = nn.Linear(len(concepts), len(classes), bias=False)

    def forward(self, **batch):
        out = self.clip(**batch).logits_per_image
        return self.head(out)

## Training simple head

In [160]:
model = TuningCLIPhead()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

Number of trainable parameters is: 1200


In [161]:
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

In [14]:
import datasets

metric = datasets.load_metric("accuracy")

  metric = datasets.load_metric("accuracy")


In [15]:
@torch.no_grad()
@torch.cuda.amp.autocast()
def val_loss_accuracy(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    val_losses = []
    for batch in tqdm(loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        logits = model(**inputs).squeeze(0)
        similarity = model.clip(**inputs).logits_per_image.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        loss = criterion(logits, targets.to(device)) + clip_loss(similarity)
        val_losses.append(loss.item())
        preds = torch.argmax(logits, dim=-1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels)

    val_accuracy = metric.compute(predictions=all_preds, references=all_labels)
    avg_val_loss = sum(val_losses) / len(val_losses)
    
    return val_accuracy, avg_val_loss

In [16]:
@torch.no_grad()
@torch.cuda.amp.autocast()
def get_test_accuracy(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    for batch in tqdm(loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        logits = model(**inputs).squeeze(0)
        targets = torch.tensor(labels, dtype=torch.long)
        preds = torch.argmax(logits, dim=-1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels)
        
    test_accuracy = metric.compute(predictions=all_preds, references=all_labels)
    
    return test_accuracy

In [17]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33msemenov-andrei-v[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [162]:
model.to(device)
run = wandb.init(project='cifar10-head', entity='semenov-andrei-v')

num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, train_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/15, Loss: 1.3807625770568848, Validation Accuracy: {'accuracy': 0.6523333333333333}


Epoch 2/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/15, Loss: 1.3001844882965088, Validation Accuracy: {'accuracy': 0.698}


Epoch 3/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/15, Loss: 1.3183488845825195, Validation Accuracy: {'accuracy': 0.707}


Epoch 4/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/15, Loss: 1.322094202041626, Validation Accuracy: {'accuracy': 0.7076666666666667}


Epoch 5/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5/15, Loss: 1.3170406818389893, Validation Accuracy: {'accuracy': 0.7036666666666667}


Epoch 6/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6/15, Loss: 1.3097145557403564, Validation Accuracy: {'accuracy': 0.7046666666666667}


Epoch 7/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 7/15, Loss: 1.3022809028625488, Validation Accuracy: {'accuracy': 0.7066666666666667}


Epoch 8/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 8/15, Loss: 1.2960307598114014, Validation Accuracy: {'accuracy': 0.7146666666666667}


Epoch 9/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 9/15, Loss: 1.291098713874817, Validation Accuracy: {'accuracy': 0.724}


Epoch 10/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 10/15, Loss: 1.2872544527053833, Validation Accuracy: {'accuracy': 0.7356666666666667}


Epoch 11/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 11/15, Loss: 1.284134030342102, Validation Accuracy: {'accuracy': 0.7446666666666667}


Epoch 12/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 12/15, Loss: 1.2816400527954102, Validation Accuracy: {'accuracy': 0.7506666666666667}


Epoch 13/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 13/15, Loss: 1.2797330617904663, Validation Accuracy: {'accuracy': 0.7553333333333333}


Epoch 14/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 14/15, Loss: 1.2782061100006104, Validation Accuracy: {'accuracy': 0.7613333333333333}


Epoch 15/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 15/15, Loss: 1.276811957359314, Validation Accuracy: {'accuracy': 0.765}


  0%|          | 0/1563 [00:00<?, ?it/s]

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Accuracy/Test,▁
Accuracy/Validation,▁▄▄▄▄▄▄▅▅▆▇▇▇██
Loss/Train,█▃▄▄▄▃▃▂▂▂▁▁▁▁▁
Loss/Validation,█▄▄▄▅▅▅▄▄▃▃▂▂▁▁

0,1
Accuracy/Test,0.76038
Accuracy/Validation,0.765
Loss/Train,1.27681
Loss/Validation,0.70618


In [25]:
#during debugging
torch.cuda.empty_cache()

## Training smarter head

In [11]:
import math

class FastGeLU(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

In [12]:
class TuningCLIPsmarterHead(nn.Module):
    def __init__(self, model_name: str="openai/clip-vit-base-patch32", concepts: list=concepts, classes: dict=classes):
        super().__init__()
        self.clip = transformers.CLIPModel.from_pretrained(model_name)
        self.processor = transformers.CLIPProcessor.from_pretrained(model_name)
        for param in self.clip.parameters():
            param.requires_grad=False
        self.lin1 = nn.Linear(len(concepts), 60, bias=False)
        self.lin2 = nn.Linear(60, 30, bias=False)
        self.lin3 = nn.Linear(30, len(classes), bias=False)
        self.gelu = FastGeLU()

    def forward(self, **batch):
        out = self.clip(**batch).logits_per_image
        x = self.gelu(self.lin1(out))
        x = self.gelu(self.lin2(x))
        return self.lin3(x)

In [176]:
model = TuningCLIPsmarterHead()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

Number of trainable parameters is: 9300


In [177]:
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

In [None]:
model.to(device)
run = wandb.init(project='cifar10-smarter-head', entity='semenov-andrei-v')

num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, train_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/15, Loss: 1.2982006072998047, Validation Accuracy: {'accuracy': 0.594}


Epoch 2/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/15, Loss: 1.3054379224777222, Validation Accuracy: {'accuracy': 0.6633333333333333}


Epoch 3/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/15, Loss: 1.2713488340377808, Validation Accuracy: {'accuracy': 0.7136666666666667}


Epoch 4/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/15, Loss: 1.3260315656661987, Validation Accuracy: {'accuracy': 0.722}


Epoch 5/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5/15, Loss: 1.3058655261993408, Validation Accuracy: {'accuracy': 0.727}


Epoch 6/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6/15, Loss: 1.2985113859176636, Validation Accuracy: {'accuracy': 0.727}


Epoch 7/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

# Another concepts set

In [13]:
with open("all_concepts.txt", "r") as f:
    all_concepts = f.read().lower().split("\n")
    all_concepts = remove_prefixes(all_concepts)

In [None]:
train_loader_preprocessed = preprocess_loader(train_loader, all_concepts)
val_loader_preprocessed = preprocess_loader(val_loader, all_concepts)
test_loader_preprocessed = preprocess_loader(test_loader, all_concepts)

## Traing simple head

In [26]:
model = TuningCLIPhead(concepts=all_concepts)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 50510


In [None]:
model.to(device)
run = wandb.init(project='cifar10-head', entity='semenov-andrei-v')

num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, train_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

## Training smarter head

In [23]:
model = TuningCLIPsmarterHead(concepts=all_concepts)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 305160


In [24]:
model.to(device)
run = wandb.init(project='cifar10-smarter-head', entity='semenov-andrei-v')

num_epochs = 15
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, train_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/15, Loss: 2.3025853633880615, Validation Accuracy: {'accuracy': 0.09266666666666666}


Epoch 2/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/15, Loss: 2.3025853633880615, Validation Accuracy: {'accuracy': 0.09266666666666666}


Epoch 3/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/15, Loss: 2.3025853633880615, Validation Accuracy: {'accuracy': 0.09266666666666666}


Epoch 4/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/15, Loss: 2.3025853633880615, Validation Accuracy: {'accuracy': 0.09266666666666666}


Epoch 5/ 15:   0%|          | 0/1563 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Custom concepts set

In [15]:
custom_concepts = ['plane', 'airliner', 'propeller', 'monoplane', 'fuselage', 'jet', 'car', 'vehicle', 'passenger', 'internal combustion engine', 'minivan', 'sedan',
                  'parrot', 'wing', 'passerine', 'fowl', 'albatross', 'geese', 'kiwi', 'syrinx', 'gull', 'lion', 'tiger', 'leopard', 'pet', 'jaguar', 'felis', 'rat', 'rabbit',
                  'elk', 'moose', 'reindeer', 'antelope', 'ruminant', 'antler', 'pig', 'sheep', 'alps', 'eurasian elk', 'foxes', 'puppy', 'cur', 'wolf', 'tail', 'great dane', 'poodle', 'hound',
                  'canid', 'corgi', 'pawl', 'toad', 'amphibian', 'egg', 'gill', 'lizard', 'tongue', 'carnivore', 'pony', 'foal', 'thoroughbred', 'hack', 'cartilage', 'donkey', 'bridle', 'boat',
                  'ferry', 'submarine', 'vessel', 'cargo', 'sail', 'sea', 'barque', 'schooner', 'travel', 'galley', 'water', 'watercraft', 'lorry', 'van', 'wagon', 'suv', 'jeep', 'forklift', 
                  'bumper', 'trailer', 'driver', 'boxcar', 'flatcar', 'motorbikr', 'oldsmobile']

## Training simple head

In [17]:
train_loader_preprocessed = preprocess_loader(train_loader, custom_concepts)
val_loader_preprocessed = preprocess_loader(val_loader, custom_concepts)
test_loader_preprocessed = preprocess_loader(test_loader, custom_concepts)

  0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

  0%|          | 0/219 [00:00<?, ?it/s]

In [33]:
model = TuningCLIPhead(concepts=custom_concepts)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 900


In [34]:
model.to(device)
run = wandb.init(project='cifar10-head', entity='semenov-andrei-v')

num_epochs = 40
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, train_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/40, Loss: 0.8347995281219482, Validation Accuracy: {'accuracy': 0.689}


Epoch 2/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/40, Loss: 0.6357260942459106, Validation Accuracy: {'accuracy': 0.733}


Epoch 3/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/40, Loss: 0.5521941184997559, Validation Accuracy: {'accuracy': 0.748}


Epoch 4/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/40, Loss: 0.5106992721557617, Validation Accuracy: {'accuracy': 0.7523333333333333}


Epoch 5/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5/40, Loss: 0.48985129594802856, Validation Accuracy: {'accuracy': 0.7593333333333333}


Epoch 6/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6/40, Loss: 0.4793414771556854, Validation Accuracy: {'accuracy': 0.7646666666666667}


Epoch 7/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 7/40, Loss: 0.47395530343055725, Validation Accuracy: {'accuracy': 0.77}


Epoch 8/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 8/40, Loss: 0.47115492820739746, Validation Accuracy: {'accuracy': 0.7733333333333333}


Epoch 9/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 9/40, Loss: 0.4697237014770508, Validation Accuracy: {'accuracy': 0.7803333333333333}


Epoch 10/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 10/40, Loss: 0.4690609276294708, Validation Accuracy: {'accuracy': 0.7843333333333333}


Epoch 11/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 11/40, Loss: 0.46886470913887024, Validation Accuracy: {'accuracy': 0.7853333333333333}


Epoch 12/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 12/40, Loss: 0.46897709369659424, Validation Accuracy: {'accuracy': 0.7883333333333333}


Epoch 13/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 13/40, Loss: 0.46931132674217224, Validation Accuracy: {'accuracy': 0.7913333333333333}


Epoch 14/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 14/40, Loss: 0.46981263160705566, Validation Accuracy: {'accuracy': 0.7926666666666666}


Epoch 15/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 15/40, Loss: 0.4704422652721405, Validation Accuracy: {'accuracy': 0.7946666666666666}


Epoch 16/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 16/40, Loss: 0.47117435932159424, Validation Accuracy: {'accuracy': 0.7933333333333333}


Epoch 17/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 17/40, Loss: 0.4719836413860321, Validation Accuracy: {'accuracy': 0.7946666666666666}


Epoch 18/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 18/40, Loss: 0.4728516638278961, Validation Accuracy: {'accuracy': 0.7966666666666666}


Epoch 19/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 19/40, Loss: 0.47376224398612976, Validation Accuracy: {'accuracy': 0.7976666666666666}


Epoch 20/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 20/40, Loss: 0.47470447421073914, Validation Accuracy: {'accuracy': 0.797}


Epoch 21/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 21/40, Loss: 0.4756675660610199, Validation Accuracy: {'accuracy': 0.7986666666666666}


Epoch 22/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 22/40, Loss: 0.4766400158405304, Validation Accuracy: {'accuracy': 0.7993333333333333}


Epoch 23/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 23/40, Loss: 0.4776172637939453, Validation Accuracy: {'accuracy': 0.7993333333333333}


Epoch 24/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 24/40, Loss: 0.4785911440849304, Validation Accuracy: {'accuracy': 0.8003333333333333}


Epoch 25/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 25/40, Loss: 0.4795585572719574, Validation Accuracy: {'accuracy': 0.801}


Epoch 26/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 26/40, Loss: 0.4805152416229248, Validation Accuracy: {'accuracy': 0.8026666666666666}


Epoch 27/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 27/40, Loss: 0.4814539849758148, Validation Accuracy: {'accuracy': 0.8036666666666666}


Epoch 28/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 28/40, Loss: 0.482377827167511, Validation Accuracy: {'accuracy': 0.8046666666666666}


Epoch 29/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 29/40, Loss: 0.4832814931869507, Validation Accuracy: {'accuracy': 0.8053333333333333}


Epoch 30/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 30/40, Loss: 0.48416125774383545, Validation Accuracy: {'accuracy': 0.8053333333333333}


Epoch 31/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 31/40, Loss: 0.4850197434425354, Validation Accuracy: {'accuracy': 0.8056666666666666}


Epoch 32/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 32/40, Loss: 0.48585203289985657, Validation Accuracy: {'accuracy': 0.806}


Epoch 33/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 33/40, Loss: 0.48666080832481384, Validation Accuracy: {'accuracy': 0.8076666666666666}


Epoch 34/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 34/40, Loss: 0.4874415099620819, Validation Accuracy: {'accuracy': 0.8076666666666666}


Epoch 35/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 35/40, Loss: 0.4881984293460846, Validation Accuracy: {'accuracy': 0.8083333333333333}


Epoch 36/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 36/40, Loss: 0.4889273941516876, Validation Accuracy: {'accuracy': 0.808}


Epoch 37/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 37/40, Loss: 0.4896291494369507, Validation Accuracy: {'accuracy': 0.809}


Epoch 38/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 38/40, Loss: 0.4903052747249603, Validation Accuracy: {'accuracy': 0.8083333333333333}


Epoch 39/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 39/40, Loss: 0.4909560978412628, Validation Accuracy: {'accuracy': 0.809}


Epoch 40/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 40/40, Loss: 0.49157682061195374, Validation Accuracy: {'accuracy': 0.81}


  0%|          | 0/1563 [00:00<?, ?it/s]

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy/Test,▁
Accuracy/Validation,▁▄▄▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████
Loss/Train,█▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/Validation,█▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Accuracy/Test,0.817
Accuracy/Validation,0.81
Loss/Train,0.49158
Loss/Validation,0.55636


## Training smarter head

In [36]:
model = TuningCLIPsmarterHead(concepts=custom_concepts)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 7500


In [None]:
model.to(device)
run = wandb.init(project='cifar10-smarter-head', entity='semenov-andrei-v')

num_epochs = 40
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        
        loss = criterion(logits, targets.to(device))
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, train_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/40, Loss: 0.7203641533851624, Validation Accuracy: {'accuracy': 0.6846666666666666}


Epoch 2/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/40, Loss: 0.5950812101364136, Validation Accuracy: {'accuracy': 0.7083333333333334}


Epoch 3/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/40, Loss: 0.6003626585006714, Validation Accuracy: {'accuracy': 0.7226666666666667}


Epoch 4/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/40, Loss: 0.5956344604492188, Validation Accuracy: {'accuracy': 0.74}


Epoch 5/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5/40, Loss: 0.5947118401527405, Validation Accuracy: {'accuracy': 0.76}


Epoch 6/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6/40, Loss: 0.5945807099342346, Validation Accuracy: {'accuracy': 0.7756666666666666}


Epoch 7/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 7/40, Loss: 0.5907315015792847, Validation Accuracy: {'accuracy': 0.783}


Epoch 8/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 8/40, Loss: 0.5827414989471436, Validation Accuracy: {'accuracy': 0.7866666666666666}


Epoch 9/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 9/40, Loss: 0.5767456293106079, Validation Accuracy: {'accuracy': 0.7903333333333333}


Epoch 10/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 10/40, Loss: 0.5689659118652344, Validation Accuracy: {'accuracy': 0.7906666666666666}


Epoch 11/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 11/40, Loss: 0.5637226104736328, Validation Accuracy: {'accuracy': 0.791}


Epoch 12/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 12/40, Loss: 0.543840229511261, Validation Accuracy: {'accuracy': 0.791}


Epoch 13/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 13/40, Loss: 0.545541524887085, Validation Accuracy: {'accuracy': 0.7916666666666666}


Epoch 14/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 14/40, Loss: 0.5420042872428894, Validation Accuracy: {'accuracy': 0.7863333333333333}


Epoch 15/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 15/40, Loss: 0.5485018491744995, Validation Accuracy: {'accuracy': 0.7846666666666666}


Epoch 16/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 16/40, Loss: 0.5482114553451538, Validation Accuracy: {'accuracy': 0.7863333333333333}


Epoch 17/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 17/40, Loss: 0.5437273383140564, Validation Accuracy: {'accuracy': 0.7886666666666666}


Epoch 18/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 18/40, Loss: 0.537867546081543, Validation Accuracy: {'accuracy': 0.789}


Epoch 19/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 19/40, Loss: 0.550273060798645, Validation Accuracy: {'accuracy': 0.7853333333333333}


Epoch 20/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 20/40, Loss: 0.5396540760993958, Validation Accuracy: {'accuracy': 0.788}


Epoch 21/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 21/40, Loss: 0.5421332120895386, Validation Accuracy: {'accuracy': 0.7906666666666666}


Epoch 22/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 22/40, Loss: 0.5485175848007202, Validation Accuracy: {'accuracy': 0.7903333333333333}


Epoch 23/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 23/40, Loss: 0.5629677176475525, Validation Accuracy: {'accuracy': 0.7896666666666666}


Epoch 24/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 24/40, Loss: 0.5205003023147583, Validation Accuracy: {'accuracy': 0.7943333333333333}


Epoch 25/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 25/40, Loss: 0.5383443832397461, Validation Accuracy: {'accuracy': 0.791}


Epoch 26/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 26/40, Loss: 0.5389748215675354, Validation Accuracy: {'accuracy': 0.791}


Epoch 27/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 27/40, Loss: 0.5502119064331055, Validation Accuracy: {'accuracy': 0.798}


Epoch 28/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 28/40, Loss: 0.5465656518936157, Validation Accuracy: {'accuracy': 0.7896666666666666}


Epoch 29/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 29/40, Loss: 0.5391364693641663, Validation Accuracy: {'accuracy': 0.7936666666666666}


Epoch 30/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 30/40, Loss: 0.5412341952323914, Validation Accuracy: {'accuracy': 0.7966666666666666}


Epoch 31/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 31/40, Loss: 0.5223079323768616, Validation Accuracy: {'accuracy': 0.792}


Epoch 32/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 32/40, Loss: 0.5282832384109497, Validation Accuracy: {'accuracy': 0.7973333333333333}


Epoch 33/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 33/40, Loss: 0.5277262926101685, Validation Accuracy: {'accuracy': 0.8033333333333333}


Epoch 34/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 34/40, Loss: 0.5140519142150879, Validation Accuracy: {'accuracy': 0.805}


Epoch 35/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 35/40, Loss: 0.5540097951889038, Validation Accuracy: {'accuracy': 0.7926666666666666}


Epoch 36/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 36/40, Loss: 0.5384458303451538, Validation Accuracy: {'accuracy': 0.7933333333333333}


Epoch 37/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 37/40, Loss: 0.5375583171844482, Validation Accuracy: {'accuracy': 0.7893333333333333}


Epoch 38/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 38/40, Loss: 0.5673350095748901, Validation Accuracy: {'accuracy': 0.7956666666666666}


Epoch 39/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 39/40, Loss: 0.521165668964386, Validation Accuracy: {'accuracy': 0.8033333333333333}


Epoch 40/ 40:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 40/40, Loss: 0.5174022912979126, Validation Accuracy: {'accuracy': 0.8016666666666666}


  0%|          | 0/1563 [00:00<?, ?it/s]

# Simple head with two loss functions

In [35]:
torch.cuda.empty_cache()

In [36]:
model = TuningCLIPhead()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Number of trainable parameters is: {}".format(sum(p.numel() for p in model.parameters() if p.requires_grad == True)))

for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

Number of trainable parameters is: 1200


new loss function

In [34]:
def contrastive_loss(logits, dim: int):
    neg_ce = torch.diag(F.log_softmax(logits, dim=dim))
    return -neg_ce.mean() # set back to -

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """
    Args:
        similarity: is equal to logits_per_image
    """
    caption_loss = contrastive_loss(similarity, dim=0)
    image_loss = contrastive_loss(similarity, dim=1)
    return (caption_loss + image_loss) / 2.0

In [37]:
model.to(device)
run = wandb.init(project='cifar10-head', entity='semenov-andrei-v')

num_epochs = 150
for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_loader_preprocessed, desc=f"Epoch {epoch + 1}/ {num_epochs}"):
        optimizer.zero_grad()
        
        inputs, labels = batch
        inputs = inputs.to(device)
        targets = torch.tensor(labels, dtype=torch.long)
        logits = model(**inputs).squeeze(0)
        similarity = model.clip(**inputs).logits_per_image.to(device)
        loss = criterion(logits, targets.to(device)) + clip_loss(similarity)
        loss.backward()
        optimizer.step()
    
    model.eval()
    val_accuracy, avg_val_loss = val_loss_accuracy(model, val_loader_preprocessed)

    wandb.log({"Loss/Train": loss.item(), 
               "Loss/Validation": avg_val_loss, 
               "Accuracy/Validation": val_accuracy['accuracy']}
             )

    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {loss.item()}, Validation Loss: {avg_val_loss}, Validation Accuracy: {val_accuracy}")

model.eval()
test_accuracy = get_test_accuracy(model, test_loader_preprocessed)
wandb.log({"Accuracy/Test": test_accuracy['accuracy']})

run.finish()

Epoch 1/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1/150, Training Loss: 6.287978172302246, Validation Loss: 6.629819722885781, Validation Accuracy: {'accuracy': 0.5966666666666667}


Epoch 2/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2/150, Training Loss: 6.169147491455078, Validation Loss: 6.458192480371354, Validation Accuracy: {'accuracy': 0.67}


Epoch 3/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3/150, Training Loss: 6.164817810058594, Validation Loss: 6.431430882595955, Validation Accuracy: {'accuracy': 0.68}


Epoch 4/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4/150, Training Loss: 6.168425559997559, Validation Loss: 6.431413366439495, Validation Accuracy: {'accuracy': 0.6833333333333333}


Epoch 5/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5/150, Training Loss: 6.171731948852539, Validation Loss: 6.432331962788359, Validation Accuracy: {'accuracy': 0.69}


Epoch 6/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6/150, Training Loss: 6.173544883728027, Validation Loss: 6.427800599564898, Validation Accuracy: {'accuracy': 0.697}


Epoch 7/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 7/150, Training Loss: 6.174279689788818, Validation Loss: 6.413553643733897, Validation Accuracy: {'accuracy': 0.7013333333333334}


Epoch 8/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 8/150, Training Loss: 6.174312114715576, Validation Loss: 6.39849332545666, Validation Accuracy: {'accuracy': 0.7113333333333334}


Epoch 9/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 9/150, Training Loss: 6.173724174499512, Validation Loss: 6.3803260072748715, Validation Accuracy: {'accuracy': 0.72}


Epoch 10/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 10/150, Training Loss: 6.172296047210693, Validation Loss: 6.362395565560523, Validation Accuracy: {'accuracy': 0.7283333333333334}


Epoch 11/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 11/150, Training Loss: 6.169850826263428, Validation Loss: 6.3454364665011145, Validation Accuracy: {'accuracy': 0.734}


Epoch 12/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 12/150, Training Loss: 6.166464805603027, Validation Loss: 6.327026970843051, Validation Accuracy: {'accuracy': 0.7426666666666667}


Epoch 13/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 13/150, Training Loss: 6.162545204162598, Validation Loss: 6.313199611420327, Validation Accuracy: {'accuracy': 0.7446666666666667}


Epoch 14/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 14/150, Training Loss: 6.158557891845703, Validation Loss: 6.296669087511428, Validation Accuracy: {'accuracy': 0.7526666666666667}


Epoch 15/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 15/150, Training Loss: 6.154724597930908, Validation Loss: 6.2835266285754265, Validation Accuracy: {'accuracy': 0.759}


Epoch 16/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 16/150, Training Loss: 6.151095390319824, Validation Loss: 6.272848134345197, Validation Accuracy: {'accuracy': 0.765}


Epoch 17/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 17/150, Training Loss: 6.147675037384033, Validation Loss: 6.2620829673523595, Validation Accuracy: {'accuracy': 0.7693333333333333}


Epoch 18/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 18/150, Training Loss: 6.144469261169434, Validation Loss: 6.253180569790779, Validation Accuracy: {'accuracy': 0.771}


Epoch 19/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 19/150, Training Loss: 6.141480445861816, Validation Loss: 6.245834710750174, Validation Accuracy: {'accuracy': 0.7733333333333333}


Epoch 20/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 20/150, Training Loss: 6.138702392578125, Validation Loss: 6.239683034572195, Validation Accuracy: {'accuracy': 0.7763333333333333}


Epoch 21/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 21/150, Training Loss: 6.136127471923828, Validation Loss: 6.233277087515973, Validation Accuracy: {'accuracy': 0.779}


Epoch 22/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 22/150, Training Loss: 6.133737564086914, Validation Loss: 6.228047868038746, Validation Accuracy: {'accuracy': 0.7833333333333333}


Epoch 23/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 23/150, Training Loss: 6.1315202713012695, Validation Loss: 6.223702466234248, Validation Accuracy: {'accuracy': 0.7856666666666666}


Epoch 24/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 24/150, Training Loss: 6.129459381103516, Validation Loss: 6.2194253028707305, Validation Accuracy: {'accuracy': 0.7883333333333333}


Epoch 25/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 25/150, Training Loss: 6.127543926239014, Validation Loss: 6.215545294132639, Validation Accuracy: {'accuracy': 0.7893333333333333}


Epoch 26/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 26/150, Training Loss: 6.125760078430176, Validation Loss: 6.212597730311941, Validation Accuracy: {'accuracy': 0.7893333333333333}


Epoch 27/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 27/150, Training Loss: 6.124096870422363, Validation Loss: 6.208639251424911, Validation Accuracy: {'accuracy': 0.7916666666666666}


Epoch 28/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 28/150, Training Loss: 6.1225481033325195, Validation Loss: 6.205456099611648, Validation Accuracy: {'accuracy': 0.7926666666666666}


Epoch 29/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 29/150, Training Loss: 6.121098518371582, Validation Loss: 6.203381320263477, Validation Accuracy: {'accuracy': 0.7936666666666666}


Epoch 30/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 30/150, Training Loss: 6.119743824005127, Validation Loss: 6.200816164625452, Validation Accuracy: {'accuracy': 0.795}


Epoch 31/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 31/150, Training Loss: 6.118476390838623, Validation Loss: 6.1988503273497235, Validation Accuracy: {'accuracy': 0.7956666666666666}


Epoch 32/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 32/150, Training Loss: 6.117284297943115, Validation Loss: 6.196540107118323, Validation Accuracy: {'accuracy': 0.796}


Epoch 33/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 33/150, Training Loss: 6.116167068481445, Validation Loss: 6.193807754110782, Validation Accuracy: {'accuracy': 0.7966666666666666}


Epoch 34/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 34/150, Training Loss: 6.1151123046875, Validation Loss: 6.191469288886862, Validation Accuracy: {'accuracy': 0.7976666666666666}


Epoch 35/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 35/150, Training Loss: 6.1141133308410645, Validation Loss: 6.189499023112845, Validation Accuracy: {'accuracy': 0.7976666666666666}


Epoch 36/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 36/150, Training Loss: 6.113170146942139, Validation Loss: 6.1881983330909245, Validation Accuracy: {'accuracy': 0.7986666666666666}


Epoch 37/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 37/150, Training Loss: 6.112269401550293, Validation Loss: 6.18601771111184, Validation Accuracy: {'accuracy': 0.7996666666666666}


Epoch 38/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 38/150, Training Loss: 6.111414432525635, Validation Loss: 6.183973677614902, Validation Accuracy: {'accuracy': 0.7996666666666666}


Epoch 39/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 39/150, Training Loss: 6.110592842102051, Validation Loss: 6.182193583630501, Validation Accuracy: {'accuracy': 0.8}


Epoch 40/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 40/150, Training Loss: 6.109802722930908, Validation Loss: 6.181458331168966, Validation Accuracy: {'accuracy': 0.8}


Epoch 41/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 41/150, Training Loss: 6.109039306640625, Validation Loss: 6.1791959671264, Validation Accuracy: {'accuracy': 0.7993333333333333}


Epoch 42/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 42/150, Training Loss: 6.108302116394043, Validation Loss: 6.177256675476723, Validation Accuracy: {'accuracy': 0.8003333333333333}


Epoch 43/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 43/150, Training Loss: 6.107580184936523, Validation Loss: 6.17544842273631, Validation Accuracy: {'accuracy': 0.8016666666666666}


Epoch 44/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 44/150, Training Loss: 6.106870174407959, Validation Loss: 6.175246654672826, Validation Accuracy: {'accuracy': 0.801}


Epoch 45/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 45/150, Training Loss: 6.106175422668457, Validation Loss: 6.173736785320526, Validation Accuracy: {'accuracy': 0.8016666666666666}


Epoch 46/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 46/150, Training Loss: 6.105489730834961, Validation Loss: 6.1723438425267, Validation Accuracy: {'accuracy': 0.8026666666666666}


Epoch 47/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 47/150, Training Loss: 6.104801654815674, Validation Loss: 6.171290209952821, Validation Accuracy: {'accuracy': 0.8036666666666666}


Epoch 48/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 48/150, Training Loss: 6.104116916656494, Validation Loss: 6.1696518025499705, Validation Accuracy: {'accuracy': 0.8043333333333333}


Epoch 49/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 49/150, Training Loss: 6.1034345626831055, Validation Loss: 6.168808353708146, Validation Accuracy: {'accuracy': 0.804}


Epoch 50/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 50/150, Training Loss: 6.102749347686768, Validation Loss: 6.167710623842605, Validation Accuracy: {'accuracy': 0.8043333333333333}


Epoch 51/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 51/150, Training Loss: 6.102059364318848, Validation Loss: 6.167354015593833, Validation Accuracy: {'accuracy': 0.803}


Epoch 52/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 52/150, Training Loss: 6.1013593673706055, Validation Loss: 6.165634353110131, Validation Accuracy: {'accuracy': 0.8046666666666666}


Epoch 53/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 53/150, Training Loss: 6.100650787353516, Validation Loss: 6.164855916449365, Validation Accuracy: {'accuracy': 0.805}


Epoch 54/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 54/150, Training Loss: 6.099936008453369, Validation Loss: 6.163882925155315, Validation Accuracy: {'accuracy': 0.805}


Epoch 55/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 55/150, Training Loss: 6.099207401275635, Validation Loss: 6.1633491921932135, Validation Accuracy: {'accuracy': 0.8066666666666666}


Epoch 56/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 56/150, Training Loss: 6.098466873168945, Validation Loss: 6.161704575761836, Validation Accuracy: {'accuracy': 0.8073333333333333}


Epoch 57/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 57/150, Training Loss: 6.097711563110352, Validation Loss: 6.160367433060991, Validation Accuracy: {'accuracy': 0.8076666666666666}


Epoch 58/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 58/150, Training Loss: 6.096948146820068, Validation Loss: 6.160299970748577, Validation Accuracy: {'accuracy': 0.807}


Epoch 59/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 59/150, Training Loss: 6.096166610717773, Validation Loss: 6.158820294319315, Validation Accuracy: {'accuracy': 0.8066666666666666}


Epoch 60/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 60/150, Training Loss: 6.095372200012207, Validation Loss: 6.158727615437609, Validation Accuracy: {'accuracy': 0.8083333333333333}


Epoch 61/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 61/150, Training Loss: 6.094562530517578, Validation Loss: 6.156923288994647, Validation Accuracy: {'accuracy': 0.8083333333333333}


Epoch 62/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 62/150, Training Loss: 6.093735694885254, Validation Loss: 6.156703669974145, Validation Accuracy: {'accuracy': 0.809}


Epoch 63/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 63/150, Training Loss: 6.092894077301025, Validation Loss: 6.155585568001929, Validation Accuracy: {'accuracy': 0.809}


Epoch 64/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 64/150, Training Loss: 6.092036724090576, Validation Loss: 6.154849067647406, Validation Accuracy: {'accuracy': 0.8083333333333333}


Epoch 65/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 65/150, Training Loss: 6.091165065765381, Validation Loss: 6.153898071735464, Validation Accuracy: {'accuracy': 0.8093333333333333}


Epoch 66/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 66/150, Training Loss: 6.090278625488281, Validation Loss: 6.154345410935422, Validation Accuracy: {'accuracy': 0.8083333333333333}


Epoch 67/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 67/150, Training Loss: 6.08937931060791, Validation Loss: 6.1525781357542, Validation Accuracy: {'accuracy': 0.81}


Epoch 68/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 68/150, Training Loss: 6.08846378326416, Validation Loss: 6.151742022088233, Validation Accuracy: {'accuracy': 0.8096666666666666}


Epoch 69/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 69/150, Training Loss: 6.087536334991455, Validation Loss: 6.151483916221781, Validation Accuracy: {'accuracy': 0.8093333333333333}


Epoch 70/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 70/150, Training Loss: 6.086591720581055, Validation Loss: 6.1508679491408325, Validation Accuracy: {'accuracy': 0.8093333333333333}


Epoch 71/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 71/150, Training Loss: 6.085639476776123, Validation Loss: 6.148933177298688, Validation Accuracy: {'accuracy': 0.81}


Epoch 72/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 72/150, Training Loss: 6.084670066833496, Validation Loss: 6.149104869112056, Validation Accuracy: {'accuracy': 0.8103333333333333}


Epoch 73/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 73/150, Training Loss: 6.0836920738220215, Validation Loss: 6.148615456641989, Validation Accuracy: {'accuracy': 0.8096666666666666}


Epoch 74/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 74/150, Training Loss: 6.082701206207275, Validation Loss: 6.148113783369673, Validation Accuracy: {'accuracy': 0.8096666666666666}


Epoch 75/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 75/150, Training Loss: 6.081701278686523, Validation Loss: 6.147463844177571, Validation Accuracy: {'accuracy': 0.8096666666666666}


Epoch 76/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 76/150, Training Loss: 6.080689907073975, Validation Loss: 6.14606911578077, Validation Accuracy: {'accuracy': 0.8103333333333333}


Epoch 77/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 77/150, Training Loss: 6.0796685218811035, Validation Loss: 6.145422976067725, Validation Accuracy: {'accuracy': 0.811}


Epoch 78/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 78/150, Training Loss: 6.078640937805176, Validation Loss: 6.145167497878379, Validation Accuracy: {'accuracy': 0.8106666666666666}


Epoch 79/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 79/150, Training Loss: 6.077601432800293, Validation Loss: 6.145192019482876, Validation Accuracy: {'accuracy': 0.81}


Epoch 80/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 80/150, Training Loss: 6.076559066772461, Validation Loss: 6.143984809834906, Validation Accuracy: {'accuracy': 0.81}


Epoch 81/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 81/150, Training Loss: 6.075505256652832, Validation Loss: 6.143409936986071, Validation Accuracy: {'accuracy': 0.81}


Epoch 82/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 82/150, Training Loss: 6.07444953918457, Validation Loss: 6.143176423742416, Validation Accuracy: {'accuracy': 0.8103333333333333}


Epoch 83/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 83/150, Training Loss: 6.073383331298828, Validation Loss: 6.1424979757755365, Validation Accuracy: {'accuracy': 0.8103333333333333}


Epoch 84/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 84/150, Training Loss: 6.072314739227295, Validation Loss: 6.1411884439752455, Validation Accuracy: {'accuracy': 0.8113333333333334}


Epoch 85/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 85/150, Training Loss: 6.07124137878418, Validation Loss: 6.141430525069541, Validation Accuracy: {'accuracy': 0.81}


Epoch 86/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 86/150, Training Loss: 6.070161819458008, Validation Loss: 6.14092713721255, Validation Accuracy: {'accuracy': 0.81}


Epoch 87/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 87/150, Training Loss: 6.0690813064575195, Validation Loss: 6.140837293990115, Validation Accuracy: {'accuracy': 0.811}


Epoch 88/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 88/150, Training Loss: 6.067997455596924, Validation Loss: 6.140187928017149, Validation Accuracy: {'accuracy': 0.81}


Epoch 89/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 89/150, Training Loss: 6.0669074058532715, Validation Loss: 6.139496047446069, Validation Accuracy: {'accuracy': 0.8113333333333334}


Epoch 90/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 90/150, Training Loss: 6.0658135414123535, Validation Loss: 6.138613766812264, Validation Accuracy: {'accuracy': 0.8106666666666666}


Epoch 91/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 91/150, Training Loss: 6.064720630645752, Validation Loss: 6.137947741975176, Validation Accuracy: {'accuracy': 0.8123333333333334}


Epoch 92/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 92/150, Training Loss: 6.063629150390625, Validation Loss: 6.13880136165213, Validation Accuracy: {'accuracy': 0.8113333333333334}


Epoch 93/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 93/150, Training Loss: 6.062533378601074, Validation Loss: 6.138370042151593, Validation Accuracy: {'accuracy': 0.8116666666666666}


Epoch 94/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 94/150, Training Loss: 6.061437129974365, Validation Loss: 6.137084986301178, Validation Accuracy: {'accuracy': 0.8123333333333334}


Epoch 95/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 95/150, Training Loss: 6.060340881347656, Validation Loss: 6.136653981310256, Validation Accuracy: {'accuracy': 0.8113333333333334}


Epoch 96/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 96/150, Training Loss: 6.059246063232422, Validation Loss: 6.135875950468347, Validation Accuracy: {'accuracy': 0.8126666666666666}


Epoch 97/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 97/150, Training Loss: 6.0581512451171875, Validation Loss: 6.1359346521661635, Validation Accuracy: {'accuracy': 0.8123333333333334}


Epoch 98/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 98/150, Training Loss: 6.05705451965332, Validation Loss: 6.135363710687516, Validation Accuracy: {'accuracy': 0.813}


Epoch 99/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 99/150, Training Loss: 6.0559611320495605, Validation Loss: 6.135419647744361, Validation Accuracy: {'accuracy': 0.8123333333333334}


Epoch 100/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 100/150, Training Loss: 6.054863929748535, Validation Loss: 6.134247850864492, Validation Accuracy: {'accuracy': 0.8123333333333334}


Epoch 101/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 101/150, Training Loss: 6.053781032562256, Validation Loss: 6.1332788264497795, Validation Accuracy: {'accuracy': 0.8123333333333334}


Epoch 102/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 102/150, Training Loss: 6.052691459655762, Validation Loss: 6.134771484009763, Validation Accuracy: {'accuracy': 0.813}


Epoch 103/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 103/150, Training Loss: 6.051606178283691, Validation Loss: 6.133568266604809, Validation Accuracy: {'accuracy': 0.813}


Epoch 104/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 104/150, Training Loss: 6.050518989562988, Validation Loss: 6.133278425703657, Validation Accuracy: {'accuracy': 0.8123333333333334}


Epoch 105/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 105/150, Training Loss: 6.049442768096924, Validation Loss: 6.132481717048807, Validation Accuracy: {'accuracy': 0.8136666666666666}


Epoch 106/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 106/150, Training Loss: 6.048366546630859, Validation Loss: 6.133066948423994, Validation Accuracy: {'accuracy': 0.8136666666666666}


Epoch 107/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 107/150, Training Loss: 6.047289848327637, Validation Loss: 6.132592683142804, Validation Accuracy: {'accuracy': 0.8133333333333334}


Epoch 108/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 108/150, Training Loss: 6.046220779418945, Validation Loss: 6.132059350926825, Validation Accuracy: {'accuracy': 0.814}


Epoch 109/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 109/150, Training Loss: 6.045151710510254, Validation Loss: 6.13172690919105, Validation Accuracy: {'accuracy': 0.814}


Epoch 110/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 110/150, Training Loss: 6.044090747833252, Validation Loss: 6.131072749482825, Validation Accuracy: {'accuracy': 0.8143333333333334}


Epoch 111/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 111/150, Training Loss: 6.04302453994751, Validation Loss: 6.130338937678236, Validation Accuracy: {'accuracy': 0.814}


Epoch 112/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 112/150, Training Loss: 6.041974067687988, Validation Loss: 6.129663350734305, Validation Accuracy: {'accuracy': 0.814}


Epoch 113/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 113/150, Training Loss: 6.040921211242676, Validation Loss: 6.130112161027625, Validation Accuracy: {'accuracy': 0.814}


Epoch 114/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 114/150, Training Loss: 6.039876937866211, Validation Loss: 6.129662599969418, Validation Accuracy: {'accuracy': 0.8153333333333334}


Epoch 115/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 115/150, Training Loss: 6.0388336181640625, Validation Loss: 6.129495285926981, Validation Accuracy: {'accuracy': 0.8153333333333334}


Epoch 116/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 116/150, Training Loss: 6.0377960205078125, Validation Loss: 6.128529320371912, Validation Accuracy: {'accuracy': 0.8156666666666667}


Epoch 117/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 117/150, Training Loss: 6.036762237548828, Validation Loss: 6.128768286806472, Validation Accuracy: {'accuracy': 0.8156666666666667}


Epoch 118/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 118/150, Training Loss: 6.035732746124268, Validation Loss: 6.128724788097625, Validation Accuracy: {'accuracy': 0.8146666666666667}


Epoch 119/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 119/150, Training Loss: 6.034713268280029, Validation Loss: 6.128567604308433, Validation Accuracy: {'accuracy': 0.8143333333333334}


Epoch 120/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 120/150, Training Loss: 6.033694267272949, Validation Loss: 6.128755356403107, Validation Accuracy: {'accuracy': 0.8153333333333334}


Epoch 121/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 121/150, Training Loss: 6.032680988311768, Validation Loss: 6.127453631543099, Validation Accuracy: {'accuracy': 0.8156666666666667}


Epoch 122/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 122/150, Training Loss: 6.031673431396484, Validation Loss: 6.1270443936611745, Validation Accuracy: {'accuracy': 0.815}


Epoch 123/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 123/150, Training Loss: 6.030673027038574, Validation Loss: 6.126750986626807, Validation Accuracy: {'accuracy': 0.8153333333333334}


Epoch 124/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 124/150, Training Loss: 6.029676914215088, Validation Loss: 6.1274028889676355, Validation Accuracy: {'accuracy': 0.8153333333333334}


Epoch 125/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 125/150, Training Loss: 6.028688907623291, Validation Loss: 6.12631837357866, Validation Accuracy: {'accuracy': 0.8153333333333334}


Epoch 126/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 126/150, Training Loss: 6.027701377868652, Validation Loss: 6.126682819204127, Validation Accuracy: {'accuracy': 0.8156666666666667}


Epoch 127/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 127/150, Training Loss: 6.0267229080200195, Validation Loss: 6.126298534109237, Validation Accuracy: {'accuracy': 0.815}


Epoch 128/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 128/150, Training Loss: 6.025749683380127, Validation Loss: 6.126504116869987, Validation Accuracy: {'accuracy': 0.8156666666666667}


Epoch 129/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 129/150, Training Loss: 6.024780750274658, Validation Loss: 6.125303821360811, Validation Accuracy: {'accuracy': 0.8156666666666667}


Epoch 130/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 130/150, Training Loss: 6.02381706237793, Validation Loss: 6.125590131637898, Validation Accuracy: {'accuracy': 0.816}


Epoch 131/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 131/150, Training Loss: 6.022860050201416, Validation Loss: 6.125463759645503, Validation Accuracy: {'accuracy': 0.8166666666666667}


Epoch 132/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 132/150, Training Loss: 6.02191162109375, Validation Loss: 6.1249594384051385, Validation Accuracy: {'accuracy': 0.817}


Epoch 133/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 133/150, Training Loss: 6.020966529846191, Validation Loss: 6.124204458074367, Validation Accuracy: {'accuracy': 0.817}


Epoch 134/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 134/150, Training Loss: 6.020026683807373, Validation Loss: 6.125196482272858, Validation Accuracy: {'accuracy': 0.8173333333333334}


Epoch 135/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 135/150, Training Loss: 6.019099235534668, Validation Loss: 6.1249006707617575, Validation Accuracy: {'accuracy': 0.817}


Epoch 136/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 136/150, Training Loss: 6.0181660652160645, Validation Loss: 6.124265604830803, Validation Accuracy: {'accuracy': 0.817}


Epoch 137/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 137/150, Training Loss: 6.017245769500732, Validation Loss: 6.1242533338830825, Validation Accuracy: {'accuracy': 0.818}


Epoch 138/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 138/150, Training Loss: 6.016334533691406, Validation Loss: 6.122938866310931, Validation Accuracy: {'accuracy': 0.817}


Epoch 139/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 139/150, Training Loss: 6.015423774719238, Validation Loss: 6.123682093113027, Validation Accuracy: {'accuracy': 0.818}


Epoch 140/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 140/150, Training Loss: 6.014520645141602, Validation Loss: 6.123543536409419, Validation Accuracy: {'accuracy': 0.8186666666666667}


Epoch 141/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 141/150, Training Loss: 6.0136260986328125, Validation Loss: 6.122840739311056, Validation Accuracy: {'accuracy': 0.818}


Epoch 142/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 142/150, Training Loss: 6.0127339363098145, Validation Loss: 6.123146858621151, Validation Accuracy: {'accuracy': 0.8186666666666667}


Epoch 143/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 143/150, Training Loss: 6.011851787567139, Validation Loss: 6.122749151067531, Validation Accuracy: {'accuracy': 0.8186666666666667}


Epoch 144/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 148/150, Training Loss: 6.0075178146362305, Validation Loss: 6.12186088460557, Validation Accuracy: {'accuracy': 0.8173333333333334}


Epoch 149/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 149/150, Training Loss: 6.006669998168945, Validation Loss: 6.1210163248346205, Validation Accuracy: {'accuracy': 0.819}


Epoch 150/ 150:   0%|          | 0/1563 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

