In [11]:
import json
import torch
from PIL import Image
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader, TensorDataset

In [12]:
# Building the datset

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

with open('data/easy-VQA/easy_vqa/data/train/questions.json', 'r') as f:
    data = json.load(f)

image_paths = []
questions = []
answers = []

for item in data:
    img_path = f"data/easy-VQA/easy_vqa/data/train/images/{item[2]}.png"
    question = item[0]
    answer = item[1]

    image_paths.append(img_path)
    questions.append(question)
    answers.append(answer)

In [13]:
class EasyVQADataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, questions, answers, transform, text_encoder):
        self.image_paths = image_paths
        self.questions = questions
        self.answers = answers
        self.transform = transform
        self.text_encoder = text_encoder

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        image = self.transform(image)

        question = self.questions[idx]
        input_em = embedding_gen(question)
        pe = positional_encodings(vocab_size, embedding_dim)
        question_embedding = input_em + pe

        answer = self.answers[idx]
        input_ans = embedding_gen(answer)
        pe = positional_encodings(vocab_size, embedding_dim)
        answer_embedding = input_ans + pe

        return image, question_embedding, answer_embedding

def collate_fn(batch):
    images, questions, answers = zip(*batch)

    # Pad questions to same length
    padded_images = pad_sequence(images, batch_first=True)
    padded_questions = pad_sequence(questions, batch_first=True)
    padded_answers = pad_sequence(answers, batch_first=True)

    return padded_images, padded_questions, padded_answers

In [28]:
# creating the dataloader
batch_size = 1
train_dataset = EasyVQADataset(image_paths, questions, answers, transform, text_encoder=None)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [29]:
# creating the image encoder:
class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()

        self.embedding_dim = 16

        # Downscaling layers for Q, K, V
        self.w_q = nn.Linear(16, 2)
        self.w_k = nn.Linear(16, 2)
        self.w_v = nn.Linear(16, 2)

        # Upscaling back to embedding dim
        self.latent_upscale = nn.Linear(2, 16)

        # Layer norm
        self.layer_norm = nn.LayerNorm(16)

        # Feedforward block
        self.feed_fwd = nn.Sequential(
            nn.Linear(16, 16),
            nn.Linear(16, 16),
            nn.Linear(16, 16)
        )

        # Final projection
        self.output_proj = nn.Linear(16, 16)

    def forward(self, x):
        Q = self.w_q(x)
        K = self.w_k(x)
        V = self.w_v(x)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embedding_dim ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        context = self.latent_upscale(context)

        # Residual + Norm
        x = self.layer_norm(context + x)

        # Feedforward + Norm
        ff_out = self.feed_fwd(x)
        out = self.layer_norm(ff_out + x)

        # Final linear (optional)
        return self.output_proj(out)

In [30]:
# defining the text and positional embedding generators

def embedding_gen(sentence):
    global words
    words = sentence.lower().split()
    global word_idx
    word_idx = {word: idx for idx, word in enumerate(words)}
    global embedding_dim
    embedding_dim = 16
    global vocab_size
    global idx_only
    idx_only = [i for i in range(len(word_idx))]
    vocab_size = len(word_idx)
    embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
    input_tensor = torch.LongTensor(idx_only)
    input_embeddings = embeddings(input_tensor)
    return input_embeddings

def positional_encodings(sequence_length, embedding_size):
    pe = torch.zeros(sequence_length, embedding_size)
    pos_encode = 0
    for pos in range(len(word_idx)):
        em_dim = embedding_dim
        for i in range(em_dim):
            if i%2 == 0:
                emma = torch.sin(torch.tensor(pos/(10000**((2*i)/embedding_dim))))
            else:
                emma = torch.cos(torch.tensor(pos/(10000**(((2*i)+1)/embedding_dim))))
            pe[pos][i] = emma
            #pe[pos][i+1] = emma2
    return pe

In [31]:
# creating the text encoder:
class TextEncoder(nn.Module):
    def __init__(self):
        super(TextEncoder, self).__init__()

        self.embedding_dim = 16

        self.w_q = nn.Linear(16, 2)
        self.w_k = nn.Linear(16, 2)
        self.w_v = nn.Linear(16, 2)

        self.latent_upscale = nn.Linear(2, 16)
        self.layer_norm = nn.LayerNorm(16)

        self.feed_fwd = nn.Sequential(
            nn.Linear(16, 16),
            nn.Linear(16, 16),
            nn.Linear(16, 16)
        )

        self.output_proj = nn.Linear(16, 16)
    
    def forward(self, x):
        Q = self.w_q(x)
        K = self.w_k(x)
        V = self.w_k(x)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embedding_dim ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        context = self.latent_upscale(context)

        # Residual + Norm
        x = self.layer_norm(context + x)

        # Feedforward + Norm
        ff_out = self.feed_fwd(x)
        out = self.layer_norm(ff_out + x)

        # Final linear (optional)
        return self.output_proj(out)

In [32]:
# defining the model:
class CLIPMini(nn.Module):
    def __init__(self):
        super(CLIPMini, self).__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()

    def forward(self, image_patches, text_tokens):

        # Let's assume: image_patches: [B, 196, 16]
        cls_token_img = nn.Parameter(torch.randn(1, 1, 16)).to(image_patches)
        cls_token_img = cls_token_img.expand(image_patches.size(0), -1, -1)  # [B, 1, 16]

        img_input = torch.cat([cls_token_img, image_patches], dim=1)  # [B, 197, 16]
        img_embs = self.image_encoder(img_input)  # Encoder will attend to [CLS]

        cls_token_txt = nn.Parameter(torch.randn(1, 1, 16)).to(text_tokens.device)
        cls_token_txt = cls_token_txt.expand(text_tokens.size(0), -1, -1)  # [B, 1, 16]

        txt_input = torch.cat([cls_token_txt, text_tokens], dim=1)  # [B, seq_len+1, 16]
        txt_embs = self.text_encoder(txt_input)

        '''
        img_embs = self.image_encoder(image_patches)  # [B, 196, 16]
        txt_embs = self.text_encoder(text_tokens) # [B, seq_len, 16]
        '''

        # Pool
        #img_vec = torch.mean(img_embs, dim=1)      # [B, 16]
        img_vec = img_embs[:, 0, :]                # [B, 16] with CLS
        txt_vec = txt_embs[:, 0, :]                # [B, 16] with CLS

        # Normalize
        img_vec = F.normalize(img_vec, dim=-1)
        txt_vec = F.normalize(txt_vec, dim=-1)

        return img_vec, txt_vec

In [33]:
model = CLIPMini()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
pos_mat = nn.Parameter(torch.randn(batch_size, 196, 16))
layer = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(16, 16), stride=16)
layer = layer.to(device)
model = model.to(device)
criterion = nn.CrossEntropyLoss()

In [34]:
num_epochs = 100
for epoch in range(num_epochs):
    total_loss = 0.0
    correct = 0
    total = 0

    model.train()

    for imgs, qs_em, ans_em in train_loader:
        imgs = imgs.to(device)
        qs_em = qs_em.to(device)
        ans_em = ans_em.to(device)
        pos_mat = pos_mat.to(device)

        optimizer.zero_grad()

        # 1. Image encoding
        img_em = layer(imgs).flatten(2).transpose(1, 2)  # [B, 196, 16]
        img_pass = img_em + pos_mat[:imgs.size(0)]       # Ensure correct batch slice
        #img_pass = img_em + pos_mat[:32]

        # 2. Text encoding
        image_vec, question_vec = model(img_pass, qs_em)
        _, ans_vec = model(img_pass, ans_em)

        # 3. Similarity + logits
        add_vec = image_vec + question_vec               # [B, D]
        logits = torch.matmul(add_vec, ans_vec.T)        # [B, B]

        # 4. Labels = diagonal elements (correct matches)
        labels = torch.arange(logits.size(0)).to(device)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # 5. Stats tracking
        total_loss += loss.item()
        _, preds = torch.max(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = total_loss / len(train_loader)
    epoch_acc = correct / total

    print(f"Epoch [{epoch+1}/{num_epochs}] — Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc*100:.2f}%")

Epoch [1/100] — Loss: 0.0000, Accuracy: 100.00%
Epoch [2/100] — Loss: 0.0000, Accuracy: 100.00%
Epoch [3/100] — Loss: 0.0000, Accuracy: 100.00%


KeyboardInterrupt: 

In [35]:
# Save the model's state_dict
torch.save(model.state_dict(), "clip_mini2.pth")
print("Model saved as clip_mini.pth")

Model saved as clip_mini.pth


In [43]:
img = Image.open('./data/easy-VQA/easy_vqa/data/test/images/1.png')
img = transform(img).unsqueeze(0)
layer = nn.Conv2d(in_channels = 3, out_channels = 16, kernel_size = 16, stride = 16)
img = layer(img).flatten(2).transpose(1, 2)  # [B, 196, 16]
pos_mat = nn.Parameter(torch.randn(1, 196, 16))
x = img + pos_mat
x = x.to(device)
x.shape

torch.Size([1, 196, 16])

In [37]:
f = open("./data/easy-VQA/easy_vqa/data/answers.txt")
answer_ems = []
answer_ems2 = []
for word in f.readlines():
    answer_ems2.append(word)
    emb = embedding_gen(word)
    pe = positional_encodings(vocab_size, embedding_dim)
    encoded = (emb + pe).to(device)
    encoded = encoded.unsqueeze(0).to(device)
    _, a_vec = model(x, encoded)
    answer_ems.append(a_vec)

answer_vec = torch.cat(answer_ems, dim=0)

len(answer_vec)

13

In [56]:
question = input("Enter question:")
question_embeddings = embedding_gen(question)
pe = positional_encodings(vocab_size, embedding_dim)
y = question_embeddings + pe
y = y.unsqueeze(0)

In [57]:
x = x.to(device)
y = y.to(device)

In [58]:
img_vec, ques_vec = model(x, y)
combined_vec = img_vec + ques_vec  # [B, dim]
logits = torch.matmul(combined_vec, answer_vec.T)  # [B, 13]
predicted = torch.argmax(logits, dim=1)
#predicted_answers = [idx_to_answer[idx.item()] for idx in predicted]

In [59]:
print(answer_ems2[predicted])

blue

