## Load the dataset


In [1]:
from VQA_Datasetv2 import VQA_Dataset
import clip
import torch
from torch.utils.data import DataLoader
from torch.utils.data import random_split

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
model, preprocess = clip.load("ViT-B/32", device=device)

dataset = VQA_Dataset()
dataset.load_all(preprocess, device, length=500)
train_val_size = int(len(dataset)*0.8)
train_size = int(len(dataset)*0.8*0.8)
val_size = int(len(dataset)*0.8)-train_size
test_size = int(len(dataset))-train_val_size
print("Train size: ", train_size)
print("Test size: ", test_size)
print("Val size: ", val_size)
train_dataset, test_dataset, val_dataset = random_split(dataset, [train_size, test_size, val_size])

batch_size=2
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

  from .autonotebook import tqdm as notebook_tqdm


Using cpu device


Preprocessing Images: 100%|██████████| 500/500 [00:08<00:00, 61.08it/s]


Train size:  320
Test size:  100
Val size:  80


## Simple model architectue


In [2]:
from models import VQA_Model2



## Evaluate

In [7]:
import tqdm

def evaluate(model, dataloader, device, show_progress=False):
    model.eval()
    correct = 0
    if show_progress:
        pbar = tqdm.tqdm(dataloader)
    else:
        pbar = dataloader
    for i, data in enumerate(pbar):
        image = data[0].to(device)
        answer_tokens = data[1].squeeze(0).to(device)
        question_tokens = data[2].squeeze(1).to(device)
        if dataloader.batch_size == 1:
            correct_answer = [int(data[3])]
        else:
            correct_answer = [int(x) for x in data[3]]

        with torch.no_grad():
            similarity = model(image, question_tokens, answer_tokens)
            pred = similarity.argmax(dim=-1)

            if pred == correct_answer:
                correct += 1
    return correct/len(dataloader)


In [8]:
combined_model = VQA_Model2(model, device)
evaluate(combined_model, test_dataloader, device, show_progress=True)

  2%|▏         | 1/50 [00:01<01:29,  1.82s/it]

torch.Size([2, 18])
tensor([3, 5])


  4%|▍         | 2/50 [00:03<01:27,  1.82s/it]

torch.Size([2, 18])
tensor([1, 7])


  6%|▌         | 3/50 [00:05<01:28,  1.87s/it]

torch.Size([2, 18])
tensor([ 2, 11])


  8%|▊         | 4/50 [00:07<01:25,  1.85s/it]

torch.Size([2, 18])
tensor([3, 2])


 10%|█         | 5/50 [00:09<01:21,  1.82s/it]

torch.Size([2, 18])
tensor([10, 16])


 12%|█▏        | 6/50 [00:10<01:19,  1.80s/it]

torch.Size([2, 18])
tensor([0, 7])


 14%|█▍        | 7/50 [00:12<01:17,  1.79s/it]

torch.Size([2, 18])
tensor([ 4, 15])


 16%|█▌        | 8/50 [00:14<01:16,  1.83s/it]

torch.Size([2, 18])
tensor([10,  5])


 18%|█▊        | 9/50 [00:16<01:15,  1.84s/it]

torch.Size([2, 18])
tensor([ 2, 15])


 20%|██        | 10/50 [00:18<01:14,  1.87s/it]

torch.Size([2, 18])
tensor([17,  6])


 22%|██▏       | 11/50 [00:20<01:14,  1.91s/it]

torch.Size([2, 18])
tensor([10, 14])


 24%|██▍       | 12/50 [00:22<01:12,  1.90s/it]

torch.Size([2, 18])
tensor([12,  3])


 26%|██▌       | 13/50 [00:23<01:08,  1.84s/it]

torch.Size([2, 18])
tensor([3, 1])


 28%|██▊       | 14/50 [00:25<01:04,  1.78s/it]

torch.Size([2, 18])
tensor([16,  0])


 30%|███       | 15/50 [00:27<01:00,  1.74s/it]

torch.Size([2, 18])
tensor([10, 16])


 32%|███▏      | 16/50 [00:28<00:58,  1.72s/it]

torch.Size([2, 18])
tensor([12, 15])


 34%|███▍      | 17/50 [00:30<00:56,  1.73s/it]

torch.Size([2, 18])
tensor([4, 7])


 36%|███▌      | 18/50 [00:32<00:55,  1.73s/it]

torch.Size([2, 18])
tensor([11,  1])


 38%|███▊      | 19/50 [00:34<00:53,  1.72s/it]

torch.Size([2, 18])
tensor([ 1, 16])


 40%|████      | 20/50 [00:35<00:52,  1.75s/it]

torch.Size([2, 18])
tensor([5, 9])


 42%|████▏     | 21/50 [00:37<00:52,  1.82s/it]

torch.Size([2, 18])
tensor([17,  8])


 44%|████▍     | 22/50 [00:40<00:56,  2.03s/it]

torch.Size([2, 18])
tensor([3, 7])


 46%|████▌     | 23/50 [00:42<00:58,  2.16s/it]

torch.Size([2, 18])
tensor([12, 16])


 48%|████▊     | 24/50 [00:45<00:56,  2.19s/it]

torch.Size([2, 18])
tensor([13,  6])


 50%|█████     | 25/50 [00:47<00:57,  2.28s/it]

torch.Size([2, 18])
tensor([1, 7])


 52%|█████▏    | 26/50 [00:50<00:46,  1.93s/it]

torch.Size([2, 18])
tensor([9, 0])





KeyboardInterrupt: 

## Training

In [None]:
def train(model, train_dataloader, val_dataloader, device, epochs=10, patience=3):
    model.train()
    print(model.parameters())
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()
    pbar = tqdm.tqdm(range(epochs))
    prev_acc = 0.0
    count = 0
    for epoch in pbar:
        for i, data in enumerate(train_dataloader):

            image = data[0].to(device)
            answer_tokens = data[1].squeeze(0).to(device)
            question_tokens = data[2].squeeze(0).to(device)
            if train_dataloader.batch_size == 1:
                correct_answer = int(data[3])
            else:
                correct_answer = [int(x) for x in data[3]]
            
            optimizer.zero_grad()
            similarity = model(image, question_tokens, answer_tokens)

            # transform asnwer to tensor of the same shape as similarity before only correct index
            one_hot_encoding = torch.zeros(similarity.squeeze(0).shape[0])
            one_hot_encoding[correct_answer] = 1
            one_hot_encoding = one_hot_encoding.to(device)
            

            loss = loss_fn(similarity.squeeze(0), one_hot_encoding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=4.0, error_if_nonfinite=True)
            optimizer.step()
            #print("--------------------------------------------------")
        acc = evaluate(model, val_dataloader, device)
        if acc < prev_acc:
            count += 1
        prev_acc = acc
        pbar.set_description(f"Epoch {epoch} loss: {loss.item()}, acc: {acc}")
        if count == patience:
            print("early stopping")
            break 

In [None]:
trained_model = VQA_Model2(model, device)
# freeze the clip model
for param in trained_model.model.parameters():
    param.requires_grad = False

train(trained_model, train_dataloader, val_dataloader, device, epochs=1)

# evaluate the model
evaluate(trained_model, test_dataloader, device)

<generator object Module.parameters at 0x000001EAAC2C97E0>


Epoch 0 loss: 2.5968213081359863, acc: 0.175: 100%|██████████| 1/1 [07:21<00:00, 441.73s/it]


0.23

In [None]:

from PIL import Image

for model_name in clip.available_models():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device and model: {model_name}")
    model, preprocess = clip.load(model_name, device=device)
    
    text_tokens = clip.tokenize(["question_text"]).to(device)
    image = preprocess(Image.open("CLIP.jpg")).unsqueeze(0).to(device)
    image_features = model.encode_image(image)
    text_features = model.encode_text(text_tokens)
    print(f"text_t {text_tokens.dtype}, text_f {text_features.dtype}, image_t {image.dtype}, image_f {image_features.dtype}")
    



Using cpu device and model: RN50
text_t torch.int32, text_f torch.float32, image_t torch.float32, image_f torch.float32
Using cpu device and model: RN101


100%|███████████████████████████████████████| 278M/278M [01:33<00:00, 3.13MiB/s]


text_t torch.int32, text_f torch.float32, image_t torch.float32, image_f torch.float32
Using cpu device and model: RN50x4


100%|███████████████████████████████████████| 402M/402M [02:19<00:00, 3.02MiB/s]


text_t torch.int32, text_f torch.float32, image_t torch.float32, image_f torch.float32
Using cpu device and model: RN50x16


100%|███████████████████████████████████████| 630M/630M [04:41<00:00, 2.35MiB/s]


text_t torch.int32, text_f torch.float32, image_t torch.float32, image_f torch.float32
Using cpu device and model: RN50x64


100%|█████████████████████████████████████| 1.26G/1.26G [04:49<00:00, 4.68MiB/s]


text_t torch.int32, text_f torch.float32, image_t torch.float32, image_f torch.float32
Using cpu device and model: ViT-B/32
text_t torch.int32, text_f torch.float32, image_t torch.float32, image_f torch.float32
Using cpu device and model: ViT-B/16


100%|███████████████████████████████████████| 335M/335M [00:28<00:00, 12.4MiB/s]


text_t torch.int32, text_f torch.float32, image_t torch.float32, image_f torch.float32
Using cpu device and model: ViT-L/14
text_t torch.int32, text_f torch.float32, image_t torch.float32, image_f torch.float32
Using cpu device and model: ViT-L/14@336px


100%|███████████████████████████████████████| 891M/891M [01:24<00:00, 11.0MiB/s]


text_t torch.int32, text_f torch.float32, image_t torch.float32, image_f torch.float32
