In [66]:
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import clip
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

**Dataset**

In [67]:
#Dataset with encodings already computed
from VQA_Datasetv2 import VQA_Dataset_preloaded
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

dataset = VQA_Dataset_preloaded()

#Computing
#dataset.compute_store(preprocess, model, device, "full", length=4000)

#Loading h5 file
dataset.load("full", device)

**Test-train split**

In [68]:
from torch.utils.data import random_split
import matplotlib.pyplot as plt
import numpy as np

train_size = int(len(dataset)*0.8)
val_size = int(len(dataset)*0.1)
test_size = int(len(dataset))-train_size-val_size
generator = torch.Generator().manual_seed(1)
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=generator)
print("Train size: ", train_size)
print("Test size: ", test_size)

batch_size=2
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

Train size:  3200
Test size:  400


**Model**

In [69]:
from models import VQA_Model_Precalc
from torch.utils.tensorboard import SummaryWriter
import os 

#Tensorboard
currentModelIteration = "1e-4_relu_hidden"
folder_path = os.path.join("runs", "trainings", currentModelIteration)
if not os.path.exists(folder_path):
    os.makedirs(folder_path)
writer = SummaryWriter(folder_path)

#CLIP
clip_model, preprocess = clip.load('ViT-B/32', device)
vqa_model = VQA_Model_Precalc(clip_model, device)

# Freezing clip model: Without this, gradients scales were different (clip and mlp gradients), generated exploding gradient 
for param in vqa_model.model.parameters():
    param.requires_grad = False

In [70]:
import os
from PIL import Image
from torchvision import transforms

img = Image.open(os.path.join("Images", "abstract_v002_val2015_000000029903.png"))
image_input = preprocess(img).unsqueeze(0).to(device)
print("Image preprocessed: ",image_input.shape)

image_features = clip_model.encode_image(image_input)
print("Image encoded size: ", image_features.shape)

text = clip.tokenize(["a diagram of the dof" , "a dog", "a cat"]).to(device)
print("Text tokenized size: ",text.shape)

text_features = clip_model.encode_text(text)
print("Text encoded size: ",text_features.shape)

Image preprocessed:  torch.Size([1, 3, 224, 224])
Image encoded size:  torch.Size([1, 512])
Text tokenized size:  torch.Size([3, 77])
Text encoded size:  torch.Size([3, 512])


**Training & Optim**

In [76]:
def train(dataloader, vqa_model, loss_function, optimizer, clip_value, epoch):
    size = len(dataloader.dataset)
    vqa_model.train()
    train_cost_acum = 0.0

    for batch, (data) in enumerate(dataloader):    
        images = data[0].squeeze(1)
        question_tokens = data[2].squeeze(1)
        answer_tokens = data[1].squeeze(1)
        similarity_pred = vqa_model(images, question_tokens, answer_tokens)

        similarity_label_arg = data[3].squeeze(1).long().to(device)
        
        loss = loss_function(similarity_pred, similarity_label_arg)

        #Backpropagation
        optimizer.zero_grad()
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(vqa_model.parameters(), max_norm=clip_value, error_if_nonfinite=True)

        optimizer.step()
        
        #Ploting results
        train_cost_acum += loss
        if batch % 50 == 1:
            writer.add_scalar('training loss', float(train_cost_acum)/batch, epoch * size + batch) #len(dataloader) returns total number of batchs in an epoch
            loss, current = loss.item(), batch*len(images)
            print("loss: ", loss, current, size)
   

In [77]:
def eval(dataloader, model, loss_function, epoch):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for data in dataloader:    
            images = data[0].squeeze(1).to(device)
            question_tokens = data[2].squeeze(1).to(device)
            answer_tokens = data[1].squeeze(1).to(device)
            similarity_pred = model(images, question_tokens, answer_tokens)

            similarity_label_arg = torch.tensor(data[3]).long().squeeze(1).to(device)
            
            val_loss += loss_function(similarity_pred, similarity_label_arg)
            correct += (similarity_pred.argmax(1) == similarity_label_arg).type(torch.float).sum().item()
            
    val_loss /= num_batches
    correct /= size

    #Ploting results
    writer.add_scalar('Accuracy/test', correct*100, epoch)
    writer.add_scalar('Loss/test', val_loss, epoch)
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")
    return 100*correct

In [78]:
#Hyperparameters and optim
from torch import nn

loss_fn = nn.CrossEntropyLoss()
clip_value = 1.0
learning_rate = 1e-4
optimizer = torch.optim.Adam(vqa_model.parameters(), lr=learning_rate)
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateu(optimizer, patience=5, verbose=True) #Scheduler

In [79]:
#Early stopping parameters
import os

n_epochs = 50
early_stop_threshhold = 5
best_accuracy = -1
best_epoch = -1

def checkpoint(model, filename):
    folder_path = os.path.join("runs", "best_model", currentModelIteration)
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    torch.save(model.state_dict(), os.path.join(folder_path, filename))
    
def resume(model, filename):
    model.load_state_dict(torch.load(os.path.join("runs", "checkpoint_SolvingCropping", filename)))

In [80]:
for epoch in range(n_epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train(train_dataloader, vqa_model, loss_fn, optimizer, clip_value, epoch)
    acc = eval(val_dataloader, vqa_model, loss_fn, epoch)
    if acc>best_accuracy:
        best_accuracy = acc
        best_epoch = epoch
        checkpoint(vqa_model, "best_model.pth")
    elif (epoch-best_epoch) > early_stop_threshhold:
        print("--Early stopped training--")
        break

Epoch 1
-------------------------------
loss:  15.603597640991211 2 3200
loss:  4.96431827545166 102 3200
loss:  4.798961639404297 202 3200
loss:  9.406116485595703 302 3200
loss:  0.8838605880737305 402 3200
loss:  9.867145538330078 502 3200
loss:  6.355923175811768 602 3200
loss:  1.0071303844451904 702 3200
loss:  8.484251022338867 802 3200
loss:  0.28028714656829834 902 3200
loss:  0.9410175085067749 1002 3200
loss:  0.5865575671195984 1102 3200
loss:  0.7346488833427429 1202 3200
loss:  1.618565320968628 1302 3200
loss:  3.8468565940856934 1402 3200
loss:  1.0159099102020264 1502 3200
loss:  1.8026976585388184 1602 3200
loss:  0.8143121004104614 1702 3200
loss:  4.967896938323975 1802 3200
loss:  2.2322518825531006 1902 3200
loss:  2.0141613483428955 2002 3200
loss:  1.2832977771759033 2102 3200
loss:  2.808457851409912 2202 3200
loss:  2.4589812755584717 2302 3200
loss:  1.4496240615844727 2402 3200
loss:  1.3651671409606934 2502 3200
loss:  4.000920295715332 2602 3200
loss:  1.0

  similarity_label_arg = torch.tensor(data[3]).long().squeeze(1).to(device)


Test Error: 
 Accuracy: 38.2%, Avg loss: 2.083412 

Epoch 2
-------------------------------
loss:  0.11003571003675461 2 3200
loss:  2.526693820953369 102 3200
loss:  0.5331398844718933 202 3200
loss:  0.6288391351699829 302 3200
loss:  1.216438889503479 402 3200
loss:  2.254531145095825 502 3200
loss:  1.4811255931854248 602 3200
loss:  2.754563093185425 702 3200
loss:  2.7307422161102295 802 3200
loss:  3.2887728214263916 902 3200
loss:  0.08047059178352356 1002 3200
loss:  2.499587297439575 1102 3200
loss:  0.1906214952468872 1202 3200
loss:  3.547283172607422 1302 3200
loss:  0.9744843244552612 1402 3200
loss:  1.8764451742172241 1502 3200
loss:  1.4640984535217285 1602 3200
loss:  4.074671745300293 1702 3200
loss:  1.1250826120376587 1802 3200
loss:  0.5008111000061035 1902 3200
loss:  0.0626831203699112 2002 3200
loss:  1.6606991291046143 2102 3200
loss:  1.192213773727417 2202 3200
loss:  0.36414051055908203 2302 3200
loss:  0.10931529104709625 2402 3200
loss:  1.751401424407959

KeyboardInterrupt: 