In [None]:
# ViT VisionEncoder

from transformers import ViTModel

class VisionEncoder(nn.Module):
    """Image encoder: frozen pretrained ViT backbone + trainable projection head.

    The first token of the backbone's last hidden state (named ``cls_token``
    below) is projected to ``d_out`` dimensions and scaled to unit L2 norm.

    Args:
        d_out: Dimensionality of the returned embedding.
    """

    def __init__(self, d_out: int) -> None:
        super().__init__()
        self.base = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        d_in = self.base.config.hidden_size
        self.projection = Projection(d_in, d_out)

        # Freeze the backbone: only the projection head receives gradients.
        self.base.requires_grad_(False)
        # A frozen feature extractor should also behave deterministically.
        self.base.eval()

    def train(self, mode: bool = True):
        """Switch the projection head's mode; the frozen backbone stays in eval mode."""
        super().train(mode)
        self.base.eval()
        return self

    def forward(self, x):
        """Return unit-norm embeddings for ``x``.

        Args:
            x: Tensor forwarded to the backbone as ``pixel_values``.

        Returns:
            Tensor whose last dimension has size ``d_out`` and L2 norm 1
            (an all-zero projection stays zero instead of becoming NaN).
        """
        outputs = self.base(pixel_values=x)
        cls_token = outputs.last_hidden_state[:, 0, :]
        projected_vec = self.projection(cls_token)
        # normalize() clamps the norm by a small eps, so a zero vector cannot
        # trigger a division by zero the way `vec / vec.norm()` would.
        return torch.nn.functional.normalize(projected_vec, p=2, dim=-1)

In [None]:
# EfficientNet VisionEncoder

from torchvision.models import efficientnet_b0

class VisionEncoder(nn.Module):
    """Image encoder: frozen pretrained EfficientNet-B0 + trainable projection head.

    The backbone's classifier is replaced by ``nn.Identity`` so the backbone
    returns its pooled features; these are projected to ``d_out`` dimensions
    and scaled to unit L2 norm.

    Args:
        d_out: Dimensionality of the returned embedding.
    """

    def __init__(self, d_out: int) -> None:
        super().__init__()
        # NOTE(review): newer torchvision releases prefer `weights=...` over
        # `pretrained=True`; kept as-is so older installs keep working.
        self.base = efficientnet_b0(pretrained=True)
        # Read the feature width before discarding the classification head.
        d_in = self.base.classifier[1].in_features
        self.base.classifier = nn.Identity()

        self.projection = Projection(d_in, d_out)

        # Freeze the backbone: only the projection head receives gradients.
        self.base.requires_grad_(False)
        # requires_grad=False does not stop BatchNorm from updating its running
        # statistics in training mode, so keep the backbone in eval mode.
        self.base.eval()

    def train(self, mode: bool = True):
        """Switch the projection head's mode; the frozen backbone stays in eval mode."""
        super().train(mode)
        self.base.eval()
        return self

    def forward(self, x):
        """Return unit-norm embeddings for the image batch ``x``.

        Returns:
            Tensor whose last dimension has size ``d_out`` and L2 norm 1
            (an all-zero projection stays zero instead of becoming NaN).
        """
        features = self.base(x)
        projected_vec = self.projection(features)
        # normalize() clamps the norm by a small eps, avoiding division by zero.
        return torch.nn.functional.normalize(projected_vec, p=2, dim=-1)

In [None]:
# GPT2 TextEncoder

from transformers import GPT2Model, GPT2Tokenizer

class TextEncoder(nn.Module):
    """Caption encoder: frozen pretrained GPT-2 + trainable projection head.

    One hidden state per sequence is selected, projected to ``d_out``
    dimensions and scaled to unit L2 norm.

    Args:
        d_out: Dimensionality of the returned embedding.
    """

    def __init__(self, d_out: int) -> None:
        super().__init__()
        self.base = GPT2Model.from_pretrained("gpt2")
        d_in = self.base.config.hidden_size

        self.projection = Projection(d_in, d_out)

        # Freeze the backbone: only the projection head receives gradients.
        self.base.requires_grad_(False)
        # Keep the frozen backbone deterministic (no dropout).
        self.base.eval()

    def train(self, mode: bool = True):
        """Switch the projection head's mode; the frozen backbone stays in eval mode."""
        super().train(mode)
        self.base.eval()
        return self

    def forward(self, x, attention_mask=None):
        """Return unit-norm embeddings for the token-id batch ``x``.

        Args:
            x: Token ids, passed to the backbone as ``input_ids``.
            attention_mask: Optional mask with the same leading two dimensions
                as ``x`` (non-zero = real token). When given, it is forwarded
                to the backbone and the hidden state of the last real token of
                each sequence is used. When omitted, the hidden state at the
                final position is used, which is a padding position for any
                sequence that was padded on the right.

        Returns:
            Tensor whose last dimension has size ``d_out`` and L2 norm 1.
        """
        hidden = self.base(input_ids=x, attention_mask=attention_mask)[0]

        if attention_mask is None:
            pooled = hidden[:, -1, :]
        else:
            # Index of the last position whose mask is non-zero; weighting the
            # mask by position makes this work for left- and right-padding.
            positions = torch.arange(hidden.shape[1], device=hidden.device)
            last_real = (attention_mask.to(hidden.device) * positions).argmax(dim=1)
            rows = torch.arange(hidden.shape[0], device=hidden.device)
            pooled = hidden[rows, last_real]

        projected_vec = self.projection(pooled)
        # normalize() clamps the norm by a small eps, avoiding division by zero.
        return torch.nn.functional.normalize(projected_vec, p=2, dim=-1)
    
class Tokenizer:
    """Thin callable wrapper around a GPT-2 tokenizer that returns PyTorch tensors.

    The wrapped tokenizer is modified in place: its ``pad_token`` is set to its
    ``eos_token`` so that batches can be padded.
    """

    def __init__(self, tokenizer: "GPT2Tokenizer") -> None:
        self.tokenizer = tokenizer
        # Reuse EOS as the padding token so `padding=True` below can pad.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, x: "str | list[str]") -> "BatchEncoding":
        """Tokenize ``x`` with truncation, padding and ``return_tensors="pt"``.

        Args:
            x: Text (or batch of texts) handed to the wrapped tokenizer.

        Returns:
            Whatever the wrapped tokenizer returns for these arguments (its
            encoding of ``x``), not the tokenizer itself.
        """
        return self.tokenizer(
            x, truncation=True, padding=True, return_tensors="pt"
        )

class CustomModel(nn.Module):
    """CLIP-style model pairing a vision encoder with a caption encoder.

    Args:
        lr: Learning rate, stored on the instance as ``self.lr``.
    """

    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        self.vision_encoder = VisionEncoder(EMBED_DIM)
        self.caption_encoder = TextEncoder(EMBED_DIM)
        self.tokenizer = Tokenizer(GPT2Tokenizer.from_pretrained("gpt2"))
        self.lr = lr
        # Kept for callers that read it; the methods below prefer the device
        # the tensors/parameters actually live on.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _module_device(self):
        """Device of the first parameter, or ``self.device`` if there are none."""
        first_param = next(self.parameters(), None)
        if first_param is None:
            return torch.device(self.device)
        return first_param.device

    def forward(self, images, text):
        """Compute the contrastive loss and accuracies for a batch.

        Args:
            images: Image batch passed to the vision encoder as-is.
            text: Captions passed to the tokenizer.

        Returns:
            Tuple ``(loss, img_acc, cap_acc)``.
        """
        text_tokens = self.tokenizer(text)
        # Keep both axes of input_ids: squeezing dim 1 would drop the sequence
        # axis whenever every caption in the batch is a single token long.
        # Captions follow the images' device, since the images are fed to the
        # vision encoder without being moved.
        text_input_ids = text_tokens["input_ids"].to(images.device)
        # NOTE(review): only input_ids reach the caption encoder, so the
        # tokenizer's attention mask for padded captions is not used - confirm
        # this is intended.

        image_embed = self.vision_encoder(images)
        caption_embed = self.caption_encoder(text_input_ids)

        # Rows index captions, columns index images.
        similarity = caption_embed @ image_embed.T

        loss = self.CLIP_loss(similarity)
        img_acc, cap_acc = metrics(similarity)

        return loss, img_acc, cap_acc

    def CLIP_loss(self, logits: torch.Tensor) -> torch.Tensor:
        """Symmetric cross-entropy over a square similarity matrix.

        The matching pair for row/column ``i`` is index ``i``; the loss is the
        mean of the row-wise and column-wise cross-entropies.
        """
        n = logits.shape[1]      # number of samples
        # Build the targets directly on the logits' device.
        labels = torch.arange(n, device=logits.device)
        loss_i = F.cross_entropy(logits.transpose(0, 1), labels, reduction="mean")
        loss_t = F.cross_entropy(logits, labels, reduction="mean")
        return (loss_i + loss_t) / 2

    def top_image(self, images, text):
        """Score each image against a caption.

        Args:
            images: Iterable of image tensors, scored one at a time.
            text: Caption to compare against; each score is read with
                ``.item()``, so one caption and one image per step is expected.

        Returns:
            List of cosine similarities (floats), one per image, in input order.
        """
        device = self._module_device()
        similarities = []

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            text = self.tokenizer(text).to(device)
            caption_embed = self.caption_encoder(text["input_ids"])

            for image in images:
                image = image.to(device)
                if image.dim() == 3:
                    # Presumably an unbatched (C, H, W) image, e.g. when
                    # iterating over a batch tensor; add the batch axis.
                    image = image.unsqueeze(0)
                image_embed = self.vision_encoder(image)
                similarities.append(
                    F.cosine_similarity(image_embed, caption_embed, dim=1).item()
                )

        return similarities

# Module-level GPT-2 tokenizer used by `tokens` below. GPT-2 ships without a
# padding token, so its own EOS token is reused for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokens(texts):
    """Encode ``texts`` with the module-level ``tokenizer``, padding and truncating."""
    encoded = tokenizer(texts, padding=True, truncation=True)
    return encoded