In [1]:
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.datasets import ImageFolder

In [2]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()

        self.embedding_dim = 16

        # Downscaling layers for Q, K, V
        self.w_q = nn.Linear(16, 2)
        self.w_k = nn.Linear(16, 2)
        self.w_v = nn.Linear(16, 2)

        # Upscaling back to embedding dim
        self.latent_upscale = nn.Linear(2, 16)

        # Layer norm
        self.layer_norm = nn.LayerNorm(16)

        # Feedforward block
        self.feed_fwd = nn.Sequential(
            nn.Linear(16, 16),
            nn.Linear(16, 16),
            nn.Linear(16, 16)
        )

        # Final projection
        self.output_proj = nn.Linear(16, 16)

    def forward(self, x):
        Q = self.w_q(x)
        K = self.w_k(x)
        V = self.w_v(x)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embedding_dim ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        context = self.latent_upscale(context)

        # Residual + Norm
        x = self.layer_norm(context + x)

        # Feedforward + Norm
        ff_out = self.feed_fwd(x)
        out = self.layer_norm(ff_out + x)

        # Final linear (optional)
        return self.output_proj(out)

In [22]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

img = Image.open('/home/lcamk2/idiot_programmer/kitchen-countertops-options.jpg')

img = transform(img)

layer = nn.Conv2d(in_channels = 3, out_channels = 16, kernel_size = 16, stride = 16)

img = layer(img)
img = img.unsqueeze(0)       
img = img.permute(0, 2, 3, 1)
img = img.view(1, 196, 16) 

pos_mat = nn.Parameter(torch.randn(1, 196, 16))
x = img + pos_mat

image_embeddings = ImageEncoder()
encoded_img = image_embeddings(x)
encoded_img.shape

torch.Size([1, 196, 16])

In [9]:
def embedding_gen(sentence):
    global words
    words = sentence.lower().split()
    global word_idx
    word_idx = {word: idx for idx, word in enumerate(words)}
    global embedding_dim
    embedding_dim = 16
    global vocab_size
    global idx_only
    idx_only = [i for i in range(len(word_idx))]
    vocab_size = len(word_idx)
    embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
    input_tensor = torch.LongTensor(idx_only)
    input_embeddings = embeddings(input_tensor)
    return input_embeddings

def positional_encodings(sequence_length, embedding_size):
    pe = torch.zeros(sequence_length, embedding_size)
    pos_encode = 0
    for pos in range(len(word_idx)):
        em_dim = embedding_dim
        for i in range(em_dim):
            if i%2 == 0:
                emma = torch.sin(torch.tensor(pos/(10000**((2*i)/embedding_dim))))
            else:
                emma = torch.cos(torch.tensor(pos/(10000**(((2*i)+1)/embedding_dim))))
            pe[pos][i] = emma
            #pe[pos][i+1] = emma2
    return pe

In [10]:
softmax = nn.Softmax(-1)

In [16]:
class TextEncoder(nn.Module):
    def __init__(self):
        super(TextEncoder, self).__init__()

        self.embedding_dim = 16

        self.w_q = nn.Linear(16, 2)
        self.w_k = nn.Linear(16, 2)
        self.w_v = nn.Linear(16, 2)

        self.latent_upscale = nn.Linear(2, 16)
        self.layer_norm = nn.LayerNorm(16)

        self.feed_fwd = nn.Sequential(
            nn.Linear(16, 16),
            nn.Linear(16, 16),
            nn.Linear(16, 16)
        )

        self.output_proj = nn.Linear(16, 16)
    
    def forward(self, x):
        Q = self.w_q(x)
        K = self.w_k(x)
        V = self.w_k(x)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embedding_dim ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        context = self.latent_upscale(context)

        # Residual + Norm
        x = self.layer_norm(context + x)

        # Feedforward + Norm
        ff_out = self.feed_fwd(x)
        out = self.layer_norm(ff_out + x)

        # Final linear (optional)
        return self.output_proj(out)

In [23]:
input_embeddings = embedding_gen("Joy Maa!")
pe = positional_encodings(vocab_size, embedding_dim)

y = input_embeddings + pe

text_encoder = TextEncoder()
y = y.unsqueeze(0)
encoded_text = text_encoder(y)
encoded_text.shape

torch.Size([1, 2, 16])

In [29]:
'''
class CLIPMini(nn.Module):
    def __init__(self):
        super(CLIPMini, self).__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()

    def forward(self, image_patches, text_tokens):
        img_embs = self.image_encoder(image_patches)  # [B, 196, 16]
        txt_embs = self.text_encoder(text_tokens) # [B, seq_len, 16]

        # Pool
        img_vec = torch.mean(img_embs, dim=1)      # [B, 16]
        txt_vec = txt_embs[:, 0, :]                # [B, 16] - CLS

        # Normalize
        img_vec = F.normalize(img_vec, dim=-1)
        txt_vec = F.normalize(txt_vec, dim=-1)

        return img_vec, txt_vec
'''

In [37]:
class CLIPMini(nn.Module):
    def __init__(self):
        super(CLIPMini, self).__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()

    def forward(self, image_patches, text_tokens):

        # Let's assume: image_patches: [B, 196, 16]
        cls_token_img = nn.Parameter(torch.randn(1, 1, 16)).to(image_patches)
        cls_token_img = cls_token_img.expand(image_patches.size(0), -1, -1)  # [B, 1, 16]

        img_input = torch.cat([cls_token_img, image_patches], dim=1)  # [B, 197, 16]
        img_embs = self.image_encoder(img_input)  # Encoder will attend to [CLS]

        cls_token_txt = nn.Parameter(torch.randn(1, 1, 16)).to(text_tokens.device)
        cls_token_txt = cls_token_txt.expand(text_tokens.size(0), -1, -1)  # [B, 1, 16]

        txt_input = torch.cat([cls_token_txt, text_tokens], dim=1)  # [B, seq_len+1, 16]
        txt_embs = self.text_encoder(txt_input)

        '''
        img_embs = self.image_encoder(image_patches)  # [B, 196, 16]
        txt_embs = self.text_encoder(text_tokens) # [B, seq_len, 16]
        '''

        # Pool
        #img_vec = torch.mean(img_embs, dim=1)      # [B, 16]
        img_vec = img_embs[:, 0, :]                # [B, 16] with CLS
        txt_vec = txt_embs[:, 0, :]                # [B, 16] with CLS

        # Normalize
        img_vec = F.normalize(img_vec, dim=-1)
        txt_vec = F.normalize(txt_vec, dim=-1)

        return img_vec, txt_vec

In [38]:
model = CLIPMini()
i_vec, t_vec = model(encoded_img, encoded_text)

In [44]:
import json
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset

# Define a transform for the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load the JSON
with open('data/easy-VQA/easy_vqa/data/train/questions.json', 'r') as f:
    data = json.load(f)

image_paths = []
questions = []
answers = []

for item in data:
    img_path = f"data/easy-VQA/easy_vqa/data/train/images/{item[2]}.png"
    question = item[0]
    answer = item[1]

    image_paths.append(img_path)
    questions.append(question)
    answers.append(answer)

In [117]:
from torch.utils.data import Dataset

from torch.nn.utils.rnn import pad_sequence

class EasyVQADataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, questions, answers, transform, text_encoder):
        self.image_paths = image_paths
        self.questions = questions
        self.answers = answers
        self.transform = transform
        self.text_encoder = text_encoder

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        image = self.transform(image)

        question = self.questions[idx]
        input_em = embedding_gen(question)
        pe = positional_encodings(vocab_size, embedding_dim)
        question_embedding = input_em + pe

        answer = self.answers[idx]
        input_ans = embedding_gen(answer)
        pe = positional_encodings(vocab_size, embedding_dim)
        answer_embedding = input_ans + pe

        return image, question_embedding, answer_embedding

def collate_fn(batch):
    images, questions, answers = zip(*batch)
    #images = list(images)  # [B, 224, 224]

    # Pad questions to same length
    padded_images = pad_sequence(images, batch_first=True)
    padded_questions = pad_sequence(questions, batch_first=True)  # [B, max_seq_len, 16]
    padded_answers = pad_sequence(answers, batch_first=True)  # [B, max_seq_len, 16]

    return padded_images, padded_questions, padded_answers

In [118]:
dataset = EasyVQADataset(image_paths, questions, answers, transform, text_encoder=None)
batch_size = 8
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [133]:
for imgs, qs, ans in loader:
    print(imgs)  # [B, 3, 224, 224]
    print(qs)          # List of strings
    print(ans)         # List of strings
    break

tensor([[[[0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
          [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
          [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
          ...,
          [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
          [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
          [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765]],

         [[0.9176, 0.9176, 0.9176,  ..., 0.9176, 0.9176, 0.9176],
          [0.9176, 0.9176, 0.9176,  ..., 0.9176, 0.9176, 0.9176],
          [0.9176, 0.9176, 0.9176,  ..., 0.9176, 0.9176, 0.9176],
          ...,
          [0.9176, 0.9176, 0.9176,  ..., 0.9176, 0.9176, 0.9176],
          [0.9176, 0.9176, 0.9176,  ..., 0.9176, 0.9176, 0.9176],
          [0.9176, 0.9176, 0.9176,  ..., 0.9176, 0.9176, 0.9176]],

         [[0.9804, 0.9804, 0.9804,  ..., 0.9804, 0.9804, 0.9804],
          [0.9804, 0.9804, 0.9804,  ..., 0.9804, 0.9804, 0.9804],
          [0.9804, 0.9804, 0.9804,  ..., 0

In [134]:
imgs.shape

torch.Size([8, 3, 224, 224])

In [135]:
img_bed = layer(imgs)
img_bed.shape

torch.Size([8, 16, 14, 14])

In [98]:
len(imgs)

8

In [138]:
layer = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(16, 16), stride=16)
#patches = embed(images).flatten(2).transpose(1, 2)

pos_mat = nn.Parameter(torch.randn(batch_size, 196, 16))
model = CLIPMini()

for epoch in range(100):
    for imgs, qs_em, ans_em in loader:
        img_em = layer(imgs).flatten(2).transpose(1, 2)
        #img_em = img_em.permute(0, 2, 1)
        #img_em = img_em.view(batch_size, 196, 16)
        img_pass = img_em + pos_mat

        text_encoder = TextEncoder()

        image_vec, question_vec = model(img_pass, qs_em)
        _, ans_vec = model(img_pass, ans_em)

        add_vec = image_vec + question_vec
        dot_pro = torch.matmul(add_vec, ans_vec.T)

        print(dot_pro.shape)
        break
    break

torch.Size([8, 8])


In [122]:
img_em.shape

torch.Size([16, 196, 16])

In [None]:
imgs.shape

torch.Size([16, 3, 224, 224])

In [None]:
'''
img_em = layer(imgs)
img_em = img_em.permute(0, 2, 3, 1)
img_em = img_em.view(batch_size, 196, 16)

hoo = ImageEncoder()
img_em = hoo(img_em)
img_em.shape
'''

torch.Size([8, 196, 16])

In [None]:
text_encoder = TextEncoder()
encoded_q = text_encoder(qs)
encoded_q.shape

torch.Size([8, 7, 16])

In [None]:
encoded_a = text_encoder(ans)
encoded_a.shape

TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple

In [None]:
torch.tensor(question_embeddings).shape

ValueError: only one element tensors can be converted to Python scalars

In [None]:
y.shape

torch.Size([6, 16])

In [None]:
input_embeddings = embedding_gen("Joy Maa!")
pe = positional_encodings(vocab_size, embedding_dim)

y = input_embeddings + pe

text_encoder = TextEncoder()
y = y.unsqueeze(0)
encoded_text = text_encoder(y)
encoded_text.shape

In [None]:
type(y)

torch.Tensor