## Load the dataset


In [51]:
from VQA_Datasetv2 import VQA_Dataset
import clip
import torch
from torch.utils.data import DataLoader
from torch.utils.data import random_split

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
model, preprocess = clip.load("ViT-B/32", device=device)

dataset = VQA_Dataset()
dataset.load_all(preprocess, device, length=500)
train_val_size = int(len(dataset)*0.8)
train_size = int(len(dataset)*0.8*0.8)
val_size = int(len(dataset)*0.8)-train_size
test_size = int(len(dataset))-train_val_size
print("Train size: ", train_size)
print("Test size: ", test_size)
print("Val size: ", val_size)
train_dataset, test_dataset, val_dataset = random_split(dataset, [train_size, test_size, val_size])

batch_size=1
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

Using cuda device


Preprocessing Images: 100%|██████████| 500/500 [00:06<00:00, 71.44it/s]

Train size:  320
Test size:  100
Val size:  80





## Simple model architectue


In [52]:
import tqdm

class VQA_Model(torch.nn.Module):
    """uses the model from clip like in evaluation method"""
    def __init__(self, model, device):
        super().__init__()
        self.model = model
        self.device = device

    def forward(self, image, question_tokens, answer_tokens):
        """returns the logits for the answers"""
        image_features = model.encode_image(image)
        question_features = model.encode_text(question_tokens)
        answer_features = model.encode_text(answer_tokens)
        
        answer_features /= answer_features.norm(dim=-1, keepdim=True)
        
        combined_features = image_features * question_features
        combined_features /= combined_features.norm(dim=-1, keepdim=True)

        similarity = (100.0 * combined_features @ answer_features.T).softmax(dim=-1)
        return similarity

    


In [53]:
class VQA_Model2(torch.nn.Module):
    """architecture that uses clip and a small NN to combine question and image features to get a embedding of the same size as the answer embedding
        Only train the NN, not the clip model
    """
    def __init__(self, model, device):
        super().__init__()
        self.model = model
        self.device = device
        self.fc1 = torch.nn.Linear(1024, 512).to(self.device)
        self.initialize_parameters()

    def initialize_parameters(self):
        # Apply Xavier/Glorot initialization to the linear layer
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.zeros_(self.fc1.bias.data)

    def forward(self, image, question_tokens, answer_tokens):
        """returns the logits for the answers"""
        image_features = model.encode_image(image)
        question_features = model.encode_text(question_tokens)
        answer_features = model.encode_text(answer_tokens)
        
        answer_features /= answer_features.norm(dim=-1, keepdim=True)

        # concatenate the features
        combined_features = torch.cat((image_features, question_features), dim=1).to(self.device)  
        combined_features /= combined_features.norm(dim=-1, keepdim=True) 
        combined_features = combined_features.to(torch.float32)
        combined_features = self.fc1(combined_features)
        
        combined_features = combined_features.to(torch.float16)

        similarity =(combined_features @ answer_features.T)# .softmax(dim=-1) no softmax because of cross entropy loss; without multiplying by 100
        
        return similarity

## Evaluate

In [64]:
import tqdm

def evaluate(model, dataloader, device, show_progress=False):
    model.eval()
    correct = 0
    if show_progress:
        pbar = tqdm.tqdm(dataloader)
    else:
        pbar = dataloader
    for i, data in enumerate(pbar):
        image = data[0].to(device)
        answer_tokens = data[1].squeeze(0).to(device)
        question_tokens = data[2].squeeze(0).to(device)
        if dataloader.batch_size == 1:
            correct_answer = int(data[3])
        else:
            correct_answer = [int(x) for x in data[3]]

        with torch.no_grad():
            similarity = model(image, question_tokens, answer_tokens)
            pred = similarity.argmax(dim=-1).item()

            if pred == correct_answer:
                correct += 1
    return correct/len(dataloader)


In [65]:
combined_model = VQA_Model2(model, device)
evaluate(combined_model, test_dataloader, device, show_progress=True)

100%|██████████| 100/100 [00:02<00:00, 37.69it/s]


0.03

## Training

In [66]:
def train(model, train_dataloader, val_dataloader, device, epochs=10):
    model.train()
    print(model.parameters())
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        for i, data in enumerate(train_dataloader):

            image = data[0].to(device)
            answer_tokens = data[1].squeeze(0).to(device)
            question_tokens = data[2].squeeze(0).to(device)
            if train_dataloader.batch_size == 1:
                correct_answer = int(data[3])
            else:
                correct_answer = [int(x) for x in data[3]]
            
            optimizer.zero_grad()
            similarity = model(image, question_tokens, answer_tokens)

            # transform asnwer to tensor of the same shape as similarity before only correct index
            one_hot_encoding = torch.zeros(similarity.squeeze(0).shape[0])
            one_hot_encoding[correct_answer] = 1
            one_hot_encoding = one_hot_encoding.to(device)
            

            loss = loss_fn(similarity.squeeze(0), one_hot_encoding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0, error_if_nonfinite=True)
            optimizer.step()
            #print("--------------------------------------------------")
        acc = evaluate(model, val_dataloader, device)
        pbar.set_description(f"Epoch {epoch} loss: {loss.item()}, acc: {acc}")

In [67]:
trained_model = VQA_Model2(model, device)
# freeze the clip model
for param in trained_model.model.parameters():
    param.requires_grad = False

train(trained_model, train_dataloader, val_dataloader, device, epochs=30)

# evaluate the model
evaluate(trained_model, test_dataloader, device)

<generator object Module.parameters at 0x0000017C3118E570>


Epoch 21 loss: 1.3447265625, acc: 0.25:  73%|███████▎  | 22/30 [03:28<01:15,  9.49s/it]  