## Load the dataset


In [1]:
from VQA_Datasetv2 import VQA_Dataset
import clip
import torch
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
model, preprocess = clip.load("ViT-B/32", device=device)

dataset = VQA_Dataset()
dataset.load_all(preprocess, device, length=200)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


Using cuda device


Preprocessing Images: 100%|██████████| 200/200 [00:02<00:00, 72.04it/s]


## Simple model architectue


In [2]:
import tqdm

class VQA_Model(torch.nn.Module):
    """uses the model from clip like in evaluation method"""
    def __init__(self, model, device):
        super().__init__()
        self.model = model
        self.device = device

    def forward(self, image, question_tokens, answer_tokens):
        """returns the logits for the answers"""
        image_features = model.encode_image(image)
        question_features = model.encode_text(question_tokens)
        answer_features = model.encode_text(answer_tokens)
        
        answer_features /= answer_features.norm(dim=-1, keepdim=True)
        
        combined_features = image_features * question_features
        combined_features /= combined_features.norm(dim=-1, keepdim=True)

        similarity = (100.0 * combined_features @ answer_features.T).softmax(dim=-1)
        return similarity

    


In [51]:
class VQA_Model2(torch.nn.Module):
    """architecture that uses clip and a small NN to combine question and image features to get a embedding of the same size as the answer embedding
        Only train the NN, not the clip model
    """
    def __init__(self, model, device):
        super().__init__()
        self.model = model
        self.device = device
        self.fc1 = torch.nn.Linear(1024, 512).to(self.device)
        self.initialize_parameters()

    def initialize_parameters(self):
        # Apply Xavier/Glorot initialization to the linear layer
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        torch.nn.init.zeros_(self.fc1.bias.data)

    def forward(self, image, question_tokens, answer_tokens):
        """returns the logits for the answers"""
        image_features = model.encode_image(image)
        question_features = model.encode_text(question_tokens)
        answer_features = model.encode_text(answer_tokens)
        
        answer_features /= answer_features.norm(dim=-1, keepdim=True)
        
        # test if nan in answer features
        if torch.isnan(answer_features).any():
            print("answer features are nan")

        # concatenate the features
        combined_features = torch.cat((image_features, question_features), dim=1).to(self.device)  
        if torch.isnan(combined_features).any():
            print("combined features are nan 1")
        combined_features /= combined_features.norm(dim=-1, keepdim=True) 
        if torch.isnan(combined_features).any():
            print("combined features are nan 2")
        combined_features = combined_features.to(torch.float32)
        combined_features = self.fc1(combined_features)
        
        # test if nan in combined features
        if torch.isnan(combined_features).any():
            print("combined features are nan 3")
        combined_features = combined_features.to(torch.float16)

        similarity =(combined_features @ answer_features.T)# .softmax(dim=-1) no softmax because of cross entropy loss; without multiplying by 100
        
        return similarity

## Evaluate

In [13]:
import tqdm

def evaluate(model, dataloader, device):
    model.eval()
    correct = 0
    for i, data in enumerate(tqdm.tqdm(dataloader)):
        #image = data['image'].to(device)
        #question_tokens = data['question_tokens'].to(device)
        #answer_tokens = data['answer_tokens'].squeeze(0).to(device)
        #image_id = int(data['image_id'])
        #correct_answer = int(data['correct_answer_idx'])
        image = data[0].to(device)
        answer_tokens = data[1].squeeze(0).to(device)
        question_tokens = data[2].squeeze(0).to(device)
        correct_answer = int(data[3])

        with torch.no_grad():
            similarity = model(image, question_tokens, answer_tokens)
            pred = similarity.argmax(dim=-1).item()

            if pred == correct_answer:
                correct += 1
    print(f"Accuracy: {correct / len(dataloader)} correct: {correct} total: {len(dataloader)}")


In [19]:
combined_model = VQA_Model(model, device)
evaluate(combined_model, dataloader, device)

100%|██████████| 200/200 [00:05<00:00, 38.06it/s]

Accuracy: 0.24 correct: 48 total: 200





## Training

In [56]:
def train(model, dataloader, device, epochs=10):
    model.train()
    print(model.parameters())
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()
    for epoch in tqdm.tqdm(range(epochs)):
        for i, data in enumerate(tqdm.tqdm(dataloader)):
            #image = data['image'].to(device)
            #image = image.squeeze(0)
            #question_tokens = data['question_tokens'].squeeze(0).to(device)
            #answer_tokens = data['answer_tokens'].squeeze(0).to(device)
            #image_id = int(data['image_id'])
            #correct_answer = int(data['correct_answer_idx'])
            image = data[0].to(device)
            answer_tokens = data[1].squeeze(0).to(device)
            question_tokens = data[2].squeeze(0).to(device)
            correct_answer = int(data[3])
            optimizer.zero_grad()
            similarity = model(image, question_tokens, answer_tokens)

            # transform asnwer to tensor of the same shape as similarity before only correct index
            one_hot_encoding = torch.zeros(similarity.squeeze(0).shape[0])
            one_hot_encoding[correct_answer] = 1
            one_hot_encoding = one_hot_encoding.to(device)
            

            loss = loss_fn(similarity.squeeze(0), one_hot_encoding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0, error_if_nonfinite=True)
            optimizer.step()
            #print("--------------------------------------------------")
        print(f"Epoch {epoch} loss: {loss.item()}")
        #evaluate(model, dataloader, device)

In [58]:
trained_model = VQA_Model2(model, device)
# freeze the clip model
for param in trained_model.model.parameters():
    param.requires_grad = False

train(trained_model, dataloader, device, epochs=100)

# evaluate the model
evaluate(trained_model, dataloader, device)

<generator object Module.parameters at 0x00000284FB8411C0>


100%|██████████| 200/200 [00:06<00:00, 33.17it/s]


Epoch 0 loss: 2.91015625


100%|██████████| 200/200 [00:05<00:00, 34.87it/s]


Epoch 1 loss: 2.91015625


100%|██████████| 200/200 [00:05<00:00, 34.86it/s]


Epoch 2 loss: 2.9140625


100%|██████████| 200/200 [00:05<00:00, 34.09it/s]


Epoch 3 loss: 2.919921875


100%|██████████| 200/200 [00:05<00:00, 33.88it/s]


Epoch 4 loss: 2.92578125


100%|██████████| 200/200 [00:05<00:00, 34.58it/s]


Epoch 5 loss: 2.931640625


100%|██████████| 200/200 [00:05<00:00, 36.79it/s]


Epoch 6 loss: 2.93359375


100%|██████████| 200/200 [00:05<00:00, 36.33it/s]


Epoch 7 loss: 2.935546875


100%|██████████| 200/200 [00:05<00:00, 36.33it/s]


Epoch 8 loss: 2.931640625


100%|██████████| 200/200 [00:05<00:00, 35.97it/s]


Epoch 9 loss: 2.927734375


100%|██████████| 200/200 [00:05<00:00, 36.99it/s]


Epoch 10 loss: 2.921875


100%|██████████| 200/200 [00:05<00:00, 36.82it/s]


Epoch 11 loss: 2.912109375


100%|██████████| 200/200 [00:05<00:00, 36.91it/s]


Epoch 12 loss: 2.90234375


100%|██████████| 200/200 [00:05<00:00, 36.02it/s]


Epoch 13 loss: 2.890625


100%|██████████| 200/200 [00:05<00:00, 36.03it/s]


Epoch 14 loss: 2.876953125


100%|██████████| 200/200 [00:05<00:00, 35.62it/s]


Epoch 15 loss: 2.86328125


100%|██████████| 200/200 [00:05<00:00, 34.36it/s]


Epoch 16 loss: 2.84765625


100%|██████████| 200/200 [00:05<00:00, 34.14it/s]


Epoch 17 loss: 2.83203125


100%|██████████| 200/200 [00:05<00:00, 35.24it/s]


Epoch 18 loss: 2.81640625


100%|██████████| 200/200 [00:05<00:00, 34.58it/s]


Epoch 19 loss: 2.798828125


100%|██████████| 200/200 [00:05<00:00, 33.96it/s]


Epoch 20 loss: 2.78125


100%|██████████| 200/200 [00:05<00:00, 33.79it/s]


Epoch 21 loss: 2.765625


100%|██████████| 200/200 [00:05<00:00, 36.10it/s]


Epoch 22 loss: 2.748046875


100%|██████████| 200/200 [00:05<00:00, 36.90it/s]


Epoch 23 loss: 2.73046875


100%|██████████| 200/200 [00:05<00:00, 36.23it/s]


Epoch 24 loss: 2.712890625


100%|██████████| 200/200 [00:05<00:00, 34.32it/s]


Epoch 25 loss: 2.693359375


100%|██████████| 200/200 [00:05<00:00, 35.12it/s]


Epoch 26 loss: 2.67578125


100%|██████████| 200/200 [00:05<00:00, 36.93it/s]


Epoch 27 loss: 2.658203125


100%|██████████| 200/200 [00:05<00:00, 36.82it/s]


Epoch 28 loss: 2.640625


100%|██████████| 200/200 [00:05<00:00, 36.37it/s]


Epoch 29 loss: 2.62109375


100%|██████████| 200/200 [00:05<00:00, 36.41it/s]


Epoch 30 loss: 2.603515625


100%|██████████| 200/200 [00:05<00:00, 37.53it/s]


Epoch 31 loss: 2.5859375


100%|██████████| 200/200 [00:05<00:00, 37.22it/s]


Epoch 32 loss: 2.568359375


100%|██████████| 200/200 [00:05<00:00, 36.78it/s]


Epoch 33 loss: 2.55078125


100%|██████████| 200/200 [00:05<00:00, 36.24it/s]


Epoch 34 loss: 2.53125


100%|██████████| 200/200 [00:05<00:00, 36.51it/s]


Epoch 35 loss: 2.513671875


100%|██████████| 200/200 [00:05<00:00, 35.81it/s]


Epoch 36 loss: 2.49609375


100%|██████████| 200/200 [00:05<00:00, 35.02it/s]


Epoch 37 loss: 2.4765625


100%|██████████| 200/200 [00:05<00:00, 34.43it/s]


Epoch 38 loss: 2.458984375


100%|██████████| 200/200 [00:05<00:00, 34.28it/s]


Epoch 39 loss: 2.439453125


100%|██████████| 200/200 [00:05<00:00, 34.29it/s]


Epoch 40 loss: 2.421875


100%|██████████| 200/200 [00:05<00:00, 34.34it/s]


Epoch 41 loss: 2.404296875


100%|██████████| 200/200 [00:05<00:00, 33.94it/s]


Epoch 42 loss: 2.384765625


100%|██████████| 200/200 [00:05<00:00, 36.08it/s]


Epoch 43 loss: 2.365234375


100%|██████████| 200/200 [00:05<00:00, 37.37it/s]


Epoch 44 loss: 2.34765625


100%|██████████| 200/200 [00:05<00:00, 36.90it/s]


Epoch 45 loss: 2.328125


100%|██████████| 200/200 [00:05<00:00, 36.35it/s]


Epoch 46 loss: 2.310546875


 22%|██▏       | 44/200 [00:01<00:04, 36.34it/s]