In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from transformers import AutoModel, AutoTokenizer, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import pandas as pd
import dataset as wsd
import numpy as np

BATCH_SIZE = 32
EMBED_DIM = 512
TRANSFORMER_EMBED = 768
IMAGE_SIZE = 255

class Projection(nn.Module):
    def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embed1 = self.linear1(x)
        embed2 = self.drop(self.linear2(F.gelu(embed1)))
        embeds = self.layer_norm(embed1 + embed2)
        return embeds

In [37]:
class VisionEncoder(nn.Module):
    def __init__(self, d_out: int) -> None:
        super().__init__()
        base = models.resnet34(pretrained=True)
        d_in = base.fc.in_features
        base.fc = nn.Identity()
        self.base = base
        self.projection = Projection(d_in, d_out)
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        projected_vec = self.projection(self.base(x))
        projection_len = torch.norm(projected_vec, dim=-1, keepdim=True)
        return projected_vec / projection_len

In [38]:
class TextEncoder(nn.Module):
    def __init__(self, d_out: int) -> None:
        super().__init__()
        self.base = AutoModel.from_pretrained("distilbert-base-multilingual-cased")
        self.projection = Projection(TRANSFORMER_EMBED, d_out)
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x):
        out = self.base(x)[0]
        out = out[:, 0, :]  # get CLS token output
        projected_vec = self.projection(out)
        projection_len = torch.norm(projected_vec, dim=-1, keepdim=True)
        return projected_vec / projection_len

In [39]:
class Tokenizer:
    def __init__(self, tokenizer: BertTokenizer) -> None:
        self.tokenizer = tokenizer

    def __call__(self, x: str) -> AutoTokenizer:
        return self.tokenizer(
            x, truncation=True, padding=True, return_tensors="pt"
        )

In [40]:
def metrics(similarity: torch.Tensor):
    y = torch.arange(len(similarity)).to(similarity.device)
    img2cap_match_idx = similarity.argmax(dim=1)
    cap2img_match_idx = similarity.argmax(dim=0)

    img_acc = (img2cap_match_idx == y).float().mean()
    cap_acc = (cap2img_match_idx == y).float().mean()

    return img_acc, cap_acc

In [41]:
class CustomModel(nn.Module):
    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()
        self.vision_encoder = VisionEncoder(EMBED_DIM)
        self.caption_encoder = TextEncoder(EMBED_DIM)
        self.tokenizer = Tokenizer(AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased"))
        self.lr = lr
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def forward(self, images, text):
        text = self.tokenizer(text).to(self.device)

        image_embed = self.vision_encoder(images)
        caption_embed = self.caption_encoder(text["input_ids"])

        similarity = caption_embed @ image_embed.T

        loss = self.CLIP_loss(similarity)
        img_acc, cap_acc = metrics(similarity)

        return loss, img_acc, cap_acc
    
    def CLIP_loss(self, logits: torch.Tensor) -> torch.Tensor:
        n = logits.shape[1]      # number of samples
        labels = torch.arange(n).to(self.device) # Create labels tensor
        # Calculate cross entropy losses along axis 0 and 1
        loss_i = F.cross_entropy(logits.transpose(0, 1), labels, reduction="mean")
        loss_t = F.cross_entropy(logits, labels, reduction="mean")
        # Calculate the final loss
        loss = (loss_i + loss_t) / 2

        return loss
    
    def top_image(self, images, text):
        text = self.tokenizer(text).to(self.device)
        caption_embed = self.caption_encoder(text["input_ids"])

        similarities = []

        for image in images:
            image_embed = self.vision_encoder(image)
            similarities.append(F.cosine_similarity(image_embed, caption_embed, dim=1).item())

        top_image = np.argsort(similarities)[-1:][::-1]

        return top_image

In [43]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CustomModel().to(device)

In [44]:
model.device

'cuda'

In [45]:
optimizer = torch.optim.Adam([
    {'params': model.vision_encoder.parameters()},
    {'params': model.caption_encoder.parameters()}
], lr=model.lr)

In [47]:
from torchvision import transforms as tt

scale = tt.Resize((IMAGE_SIZE, IMAGE_SIZE))
tensor = tt.ToTensor()
image_composed = tt.transforms.Compose([scale, tensor])

In [48]:
train_set = wsd.VisualWSDDataset(mode="train", image_transform=image_composed)
test_set = wsd.VisualWSDDataset(mode="test", image_transform=image_composed, test_lang='en')

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

In [49]:
start_epoch = 0
num_epochs = 1

batch_zero = True
for epoch in range(start_epoch, num_epochs):
    model.train()
    for batch in test_loader:
        image = batch["correct_img"].to(device)
        text = batch["label_context"]
        # images, text = batch
        loss, img_acc, cap_acc = model(image, text)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_zero:
          print(f"Epoch [{0}/{num_epochs}], Batch Loss: {loss.item()}")
          batch_zero = False


    # Print training statistics
    print(f"Epoch [{epoch+1}/{num_epochs}], Batch Loss: {loss.item()}")

print("Training complete.")

  return F.conv2d(input, weight, bias, self.stride,


Epoch [0/1], Batch Loss: 3.466111898422241
Epoch [1/1], Batch Loss: 2.694244384765625
Training complete.


In [33]:
test_loader = DataLoader(test_set, batch_size=1, shuffle=True)

for batch in test_loader:
  images = batch["imgs"]
  text = batch["label_context"]
  idx = model.top_image(images, text)
  print(idx[0], batch["correct_idx"].item())

7 3
7 0
1 7
9 1
0 2
4 6
5 2
1 4
2 5
2 5
1 4
2 0
9 2
5 9
0 9
5 4
1 6
0 2
1 3
4 6
7 6
8 9
7 8
0 1
4 0
7 2
8 5
6 9
1 0
1 1
2 2
5 2
3 3
3 9
9 1
7 8
7 6
1 6
1 0
9 9
6 7
0 9
9 9
5 9
4 3
8 9
8 9
3 1
4 8
1 6
4 9
5 1
8 2
3 3
7 6
4 2
5 5
0 6
6 0
8 1
2 3
9 8
8 2
7 2
4 0
1 3
0 4
8 8
4 3
5 5
2 4
8 0
1 6
3 4
1 2
4 5
5 2
7 4
8 7
7 6
1 0
3 3
2 5
3 0
3 7
0 6
0 0
5 2
3 4
4 5
6 1
2 2
3 9
6 7
1 0
1 5
1 0
6 4
6 8
3 8
3 0
4 4
3 1
0 1
1 9
4 7
6 9
6 5
0 7
4 3
1 2
8 0
4 8
6 3
3 3
7 2
2 0
3 0
3 2
3 7
4 3
9 9
0 4
0 4
5 8
4 6
1 2
3 2
5 8
9 0
7 6
6 9
3 4
8 8
1 6
6 8
9 7
9 5
0 6
6 3
6 3
1 4
9 3
9 9
3 1
9 0
6 4
3 3
5 4
4 6
5 0
7 3
9 3
4 6
1 3
6 9
4 2
6 2
4 2


KeyboardInterrupt: 

In [None]:
# TODO
# Trainer ausprobieren
# DataLoader anpassen mit Validation Dataset