In [10]:
!pip install pytorch_transformers
!pip install open_clip_torch
!pip install einops

Collecting pytorch_transformers
  Downloading pytorch_transformers-1.2.0-py3-none-any.whl (176 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.4/176.4 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
Collecting boto3 (from pytorch_transformers)
  Obtaining dependency information for boto3 from https://files.pythonhosted.org/packages/f2/23/c5545cb57abfc3a9782287f2845a26286f6f9f7bcec36f13569567f950fe/boto3-1.29.5-py3-none-any.whl.metadata
  Downloading boto3-1.29.5-py3-none-any.whl.metadata (6.7 kB)
Collecting sacremoses (from pytorch_transformers)
  Obtaining dependency information for sacremoses from https://files.pythonhosted.org/packages/0b/f0/89ee2bc9da434bd78464f288fdb346bc2932f2ee80a90b2a4bbbac262c74/sacremoses-0.1.1-py3-none-any.whl.metadata
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Collecting botocore<1.33.0,>=1.32.5 (from boto3->pytorch_transformers)
  Obtaining dependency information for botocore<1.33.0,>=1.32.5 from https://files.py

In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import os
from PIL import Image
from PIL import ImageFile
# Lift Pillow's decompression-bomb pixel limit so very large images load.
# NOTE(review): this removes a safety check; only appropriate for trusted data.
Image.MAX_IMAGE_PIXELS = None
# Let Pillow load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
import open_clip
from torch.utils.data import Dataset, DataLoader
import einops
import numpy as np

from eval import *
from data import ImageTextDataset
from models import *

[nltk_data] Downloading package wordnet to /Users/yzh/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
# Reproducibility: seed torch (fuser initialisation, DataLoader shuffling) and numpy.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CLIP serves only as a feature extractor: the optimiser below receives the
# fuser's parameters only, so CLIP's weights are never updated. Keep it on the
# same device as the image batches the training loop moves to `device`, run it
# in eval mode, and stop autograd from accumulating (never-zeroed) gradients
# into its weights.
clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
clip = clip.to(device).eval().requires_grad_(False)
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# choose the modality fuser here
model = TransformerFuser().to(device)

LEARNING_RATE = 3e-6
BATCH_SIZE = 10
NUM_WORKERS = 4

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Only the training split is shuffled; valid/test keep a fixed order.
train_data = DataLoader(ImageTextDataset('train', preprocess), batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valid_data = DataLoader(ImageTextDataset('valid', preprocess), batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
test_data = DataLoader(ImageTextDataset('test', preprocess), batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

num_epochs = 5
LOG_EVERY = 200  # iterations between training-metric printouts

for epoch in range(num_epochs):
    # Put the fuser in training mode *before* the first forward pass of every
    # epoch. The end-of-epoch evaluate() calls below may leave it in another
    # mode (NOTE(review): presumably eval — confirm in eval.py).
    model.train()

    for i, (prompt, retrieved_images, candidate_images, gold_index) in enumerate(train_data):
        logits = model(prompt, retrieved_images.to(device), candidate_images.to(device), clip, tokenizer, device)

        # Class-index targets yield the same CrossEntropyLoss as one-hot
        # probability targets, without a per-sample Python loop or a
        # hard-coded number of candidates.
        targets = torch.as_tensor(gold_index, dtype=torch.long, device=device)

        optimizer.zero_grad()
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        if i % LOG_EVERY == 0:
            cpu_logits = logits.detach().cpu()
            mrr = MRR(cpu_logits, gold_index)
            hit_1 = hit_rate(cpu_logits, gold_index)
            print('[Epoch %d/%d] [Iter %d] [loss : %f] [hit@1 : %f] [mrr : %f]' % (epoch + 1, num_epochs, i, loss.item(), hit_1, mrr))

    # evaluate(data, clip, tokenizer, device, model) — same call shape as in
    # the final evaluation cell of this notebook.
    print("Evaluating on validation set...")
    evaluate(valid_data, clip, tokenizer, device, model)
    # NOTE(review): picking a checkpoint from these per-epoch test scores leaks
    # the test set into model selection; select on the validation numbers.
    print("Evaluating on test set...")
    evaluate(test_data, clip, tokenizer, device, model)
    torch.save(model.state_dict(), f"transformer_{epoch + 1}.pth")

In [None]:
# Load models and evaluate on the test set

# CLIP baseline with augmented context (no fuser)
evaluate(test_data, clip, tokenizer, device, model=None)

# average fuser
ave_model = AverageFuser()
evaluate(test_data, clip, tokenizer, device, ave_model)

# transformer fuser, restored from a checkpoint written by the training loop
# as transformer_<epoch>.pth
CHECKPOINT_EPOCH = 4
transformer_model = TransformerFuser()
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only
# machine (and vice versa).
state_dict = torch.load(f"transformer_{CHECKPOINT_EPOCH}.pth", map_location=device)
transformer_model.load_state_dict(state_dict)
# Inference only: move to the evaluation device and disable training-mode layers.
transformer_model = transformer_model.to(device).eval()
evaluate(test_data, clip, tokenizer, device, transformer_model)