In [3]:
import pandas as pd

# Load the meme table. `header=0` makes pandas consume the file's first row as
# the header and `names` replaces it with our own column names, so the header
# row never ends up in the data and does not have to be sliced off afterwards
# (which also left the index starting at 1).
df = pd.read_csv(
    'full_meme_data.csv',
    header=0,
    names=['image', 'description', 'full_description', 'prois', 'meaning'],
    sep=',',
)

In [4]:
# Pull the two columns used below out of the frame as plain Python lists;
# both lists keep the row order of `df`.
images, captions = (df[column].tolist() for column in ('image', 'description'))

# Initialize models and download data

In [5]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

# Tokenizer and causal language model are loaded from the same checkpoint.
LANGUAGE_MODEL_CHECKPOINT = "facebook/opt-125m"

tokenizer = AutoTokenizer.from_pretrained(LANGUAGE_MODEL_CHECKPOINT)
language_model = AutoModelForCausalLM.from_pretrained(LANGUAGE_MODEL_CHECKPOINT)

In [6]:
# Load model directly
from transformers import AutoProcessor, CLIPModel

# Image processor and CLIP model are loaded from the same checkpoint.
CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"

processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT)
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


In [7]:
# Register the retrieval marker as an extra special token, then grow the
# language model's embedding matrix to the new vocabulary size.
extra_special_tokens = {"additional_special_tokens": ["<RET>"]}
tokenizer.add_special_tokens(extra_special_tokens)
language_model.resize_token_embeddings(len(tokenizer))

# Vocabulary id of the marker: first id produced when encoding it on its own,
# without the tokenizer's automatic special tokens.
ret_token_encoding = tokenizer.encode("<RET>", add_special_tokens=False)
ret_token_id = ret_token_encoding[0]

In [8]:
# Both pretrained backbones are used as fixed feature extractors: switch them
# to evaluation mode and turn off gradient tracking for every parameter.
for backbone in (language_model, clip_model):
    backbone.eval()
    backbone.requires_grad_(False)

In [9]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
import urllib3


def download_image(url: str, timeout: float = 30.0) -> "str | urllib3.response.HTTPResponse":
    """Return something that can be opened as an image for `url`.

    Args:
        url: Either an HTTP(S) URL or any other string. Strings that do not
            start with ``'http'`` are returned unchanged (the caller passes
            the result to ``Image.open``, so presumably these are local paths).
        timeout: Seconds to wait for the server before `requests` gives up,
            so one unresponsive host cannot stall a whole download batch.

    Returns:
        The unchanged string for non-HTTP inputs, otherwise the raw
        (file-like) response stream.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    if not url.startswith('http'):
        return url

    response = requests.get(url, stream=True, timeout=timeout)
    # Fail here with a clear HTTP error instead of handing an error page to
    # the image decoder.
    response.raise_for_status()
    # The raw stream bypasses requests' content decoding unless asked to
    # decode gzip/deflate transfer encodings itself.
    response.raw.decode_content = True
    return response.raw


def download_image_batch(image_urls, fetch=None):
    """Fetch a batch of images concurrently, preserving input order.

    The results must stay aligned with `image_urls` because the embeddings
    computed from them are later paired with captions by position. Collecting
    futures with ``as_completed`` returns them in completion order, which
    scrambles that pairing, so ``Executor.map`` is used instead: it yields
    results in the order the inputs were submitted.

    Args:
        image_urls: Iterable of URLs / paths, one per image.
        fetch: Optional callable taking one URL and returning the downloaded
            object. Defaults to `download_image`.

    Returns:
        A list with exactly one entry per input, in input order.

    Raises:
        Whatever `fetch` raises; the first failing download (in input order)
        propagates instead of being silently dropped, since dropping an item
        would shift every later image against its caption.
    """
    if fetch is None:
        fetch = download_image

    with ThreadPoolExecutor() as executor:
        return list(executor.map(fetch, image_urls))

In [10]:
import torch

@torch.no_grad()
def return_image_embeddings(images):
    """Embed a batch of images with the notebook-level CLIP model.

    The images are run through the global `processor`, the resulting inputs
    are moved to the device of the global `clip_model`, and the output of
    `clip_model.get_image_features` is returned. Gradient tracking is
    disabled for the whole call.
    """
    clip_inputs = processor(images=images, return_tensors="pt", padding=True)
    clip_inputs = clip_inputs.to(clip_model.device)
    return clip_model.get_image_features(**clip_inputs)

In [34]:
from tqdm import tqdm
import numpy as np
from PIL import Image

# Number of images downloaded and embedded per step.
EMBED_BATCH_SIZE = 128

total_embeds = []

# Walk the URL list in fixed-size slices. Unlike
# `np.array_split(images, len(images) // 128)` this also works when there are
# fewer than 128 images (where the section count would be 0) and really
# yields batches of at most EMBED_BATCH_SIZE items.
for batch_start in tqdm(range(0, len(images), EMBED_BATCH_SIZE)):
    image_batch = images[batch_start:batch_start + EMBED_BATCH_SIZE]
    raw_images = download_image_batch(image_batch)

    opened_images = []
    for raw_image in raw_images:
        # Decode inside a context manager so the image's file handle is
        # released once the pixels have been read; `convert` returns an
        # independent, fully loaded RGB copy.
        with Image.open(raw_image) as image:
            opened_images.append(image.convert("RGB"))

    embeds = return_image_embeddings(opened_images)
    total_embeds.append(embeds)

100% 16/16 [01:18<00:00,  4.90s/it]


In [35]:
# Stack the per-batch embeddings into one tensor, rows in batch order.
# The guard keeps this cell re-runnable: after the first run `total_embeds`
# is already a tensor, and passing a tensor to torch.cat again would raise.
if isinstance(total_embeds, (list, tuple)):
    total_embeds = torch.cat(total_embeds, dim=0)

In [36]:
import os

# Create the output directory if needed so the save below does not fail on a
# fresh checkout (replaces the manual `!mkdir fromage_weights` step).
os.makedirs('fromage_weights', exist_ok=True)
torch.save(total_embeds, 'fromage_weights/images_embeds.pt')

# To skip the download/embedding cells on a later run, load the cached tensor:
# total_embeds = torch.load('fromage_weights/images_embeds.pt')

In [38]:
# Embeddings are paired with captions by row position further down, so there
# must be exactly one embedding row per image URL.
assert total_embeds.shape[0] == len(images), (
    f"expected {len(images)} image embeddings, got {total_embeds.shape[0]}"
)
total_embeds.shape

torch.Size([2124, 768])

# Dataset

In [206]:
from torch.utils.data import Dataset, DataLoader

class MultiModalDataset(Dataset):
    """Pairs tokenized captions with precomputed image embeddings by position.

    Args:
        input_ids: Indexable collection of token-id tensors, one per example.
        image_embeddings: Indexable collection of image embeddings; entry `i`
            belongs to `input_ids[i]`.
        attention_mask: Optional mask aligned with `input_ids` (1 = real
            token, 0 = padding). When given, each item also carries its
            `attention_mask`, and label positions where the mask is 0 are set
            to -100 so the language-modelling loss skips padding. When
            omitted, items look exactly as before: labels are a plain copy of
            the input ids.

    Raises:
        ValueError: If the collections do not have the same length.
    """

    # Label value that the cross-entropy loss skips (PyTorch's default
    # `ignore_index`).
    IGNORE_INDEX = -100

    def __init__(self, input_ids, image_embeddings, attention_mask=None):
        if len(input_ids) != len(image_embeddings):
            raise ValueError(
                f"input_ids has {len(input_ids)} rows but image_embeddings has "
                f"{len(image_embeddings)}; they must be aligned one-to-one"
            )
        if attention_mask is not None and len(attention_mask) != len(input_ids):
            raise ValueError(
                f"attention_mask has {len(attention_mask)} rows but input_ids has "
                f"{len(input_ids)}"
            )
        self.input_ids = input_ids
        self.image_embeddings = image_embeddings
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, index):
        # Clone so that editing the labels never touches the stored input ids.
        labels = torch.clone(self.input_ids[index])
        item = {'input_ids': self.input_ids[index],
                'image_embeddings': self.image_embeddings[index],
                'labels': labels,
               }
        if self.attention_mask is not None:
            mask = self.attention_mask[index]
            item['attention_mask'] = mask
            item['labels'] = labels.masked_fill(mask == 0, self.IGNORE_INDEX)
        return item

In [207]:
# Every caption gets the "<RET>" marker appended as its final piece of text.
input_text = []
for caption in captions:
    input_text.append(caption + "<RET>")

In [208]:
input_ids = tokenizer(input_text, return_tensors='pt', padding=True)['input_ids']

# Captions and embeddings are paired by row, so they must have equal length.
assert len(input_ids) == len(total_embeds), (
    f"{len(input_ids)} captions vs {len(total_embeds)} image embeddings"
)

# Hold out the last 10% of the rows for validation.
threshold = int(0.1 * len(input_ids))

print(threshold)

# Split on an explicit index. Slicing with `[:-threshold]` / `[-threshold:]`
# breaks when threshold is 0 (fewer than 10 rows): `[:-0]` is empty and
# `[-0:]` is everything, which would put all data into validation.
split_index = len(input_ids) - threshold

train_dataset = MultiModalDataset(input_ids[:split_index], total_embeds[:split_index])
val_dataset = MultiModalDataset(input_ids[split_index:], total_embeds[split_index:])

len(train_dataset), len(val_dataset)

212


(1912, 212)

# Fromage retrieval

In [209]:
import torch

class FromageRetrieval(torch.nn.Module):
    def __init__(self, language_model, clip_model, ret_token_id: int):
        super().__init__()
        self.language_model = language_model
        self.clip_model = clip_model

        self.adapter_dim = 512
        self.token_emb_dim = language_model.config.word_embed_proj_dim
        self.visual_emb_dim = clip_model.config.projection_dim

        self.image_projection = torch.nn.Linear(self.visual_emb_dim, self.adapter_dim)
        self.text_projection = torch.nn.Linear(self.token_emb_dim, self.adapter_dim)

        self.ret_token_id = ret_token_id

    def forward(
        self, 
        input_ids: torch.tensor, 
        image_embeddings: torch.Tensor, 
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ):
        lm_out = self.language_model(
            input_ids=input_ids, 
            labels=labels, 
            attention_mask=attention_mask, 
            output_hidden_states=True
        )

        ret_token_indicies = torch.nonzero(input_ids == self.ret_token_id)[:, 1]

        hidden_state = lm_out.hidden_states[-1]

        ret_embs = torch.gather(hidden_state, 1, ret_token_indicies.unsqueeze(1).unsqueeze(2).repeat(1, 1, hidden_state.shape[-1]))

        projected_ret_embs = self.text_projection(ret_embs)
        projected_image_embs = self.image_projection(image_embeddings)

        return lm_out, projected_ret_embs, projected_image_embs

In [210]:
# fromage_retrieval_model = FromageRetrieval(language_model=language_model, clip_model=clip_model, ret_token_id=ret_token_id)

# string = ["Hello world! <RET>", 
#           "my name is <RET>", 
#           "i <RET>"]

# input_ids = tokenizer(string, return_tensors='pt', padding=True)['input_ids']

# image_embeddings = torch.rand(3, 768)

# lm_out, projected_ret_embs, projected_image_embs = fromage_retrieval_model(input_ids=input_ids, image_embeddings=image_embeddings)

# Custom Trainer

In [224]:
from transformers import Trainer
from info_nce_loss import InfoNCE

class CustomFromageTrainer(Trainer):
    """Trainer whose loss is the LM cross-entropy plus an InfoNCE term."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined loss for one batch.

        Args:
            model: Module returning `(lm_out, projected_ret_embs,
                projected_image_embs)` when called with `**inputs`.
            inputs: Keyword arguments for the model's forward pass.
            return_outputs: When true, also return the language model output.
            num_items_in_batch: Accepted because newer `transformers` releases
                pass this keyword to `compute_loss`; it is not used here.

        Returns:
            The scalar loss, or `(loss, lm_out)` when `return_outputs` is set.
        """
        lm_out, projected_ret_embs, projected_image_embs = model(**inputs)
        cross_entropy_loss = lm_out.loss

        info_nce_loss_criterion = InfoNCE()

        # Drop the singleton middle dimension so both inputs to the
        # contrastive loss have one row per example.
        projected_ret_embs = projected_ret_embs.squeeze(1)

        # Third argument None: no explicit negative keys are supplied.
        info_nce_loss = info_nce_loss_criterion(projected_ret_embs, projected_image_embs, None)

        loss = cross_entropy_loss + info_nce_loss
        return (loss, lm_out) if return_outputs else loss

In [225]:
# torch.rand(5, 1, 10).squeeze(1).shape

In [226]:
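# !export WANDB_MODE=disable
# NOTE: `!export ...` runs in a throwaway subshell, so it cannot change the
# kernel's environment. To turn off Weights & Biases logging from the notebook
# use `%env WANDB_MODE=disabled` instead.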

In [228]:
from transformers import TrainingArguments

fromage_retrieval_model = FromageRetrieval(language_model=language_model, clip_model=clip_model, ret_token_id=ret_token_id)

training_arguments = TrainingArguments(
    dataloader_pin_memory=False,
    output_dir='.',
    num_train_epochs=5,
    logging_strategy='steps',
    logging_steps=64,
    eval_steps=64,
    do_eval=True,
    evaluation_strategy='steps',
    # Only the loss is needed at evaluation time (no `compute_metrics` is
    # configured). Without this the Trainer collects the model outputs
    # returned by `compute_loss` for every evaluation example, which here
    # include the full logits and all hidden states.
    prediction_loss_only=True,
)

custom_fromage_trainer = CustomFromageTrainer(
    model=fromage_retrieval_model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    args=training_arguments,
    tokenizer=tokenizer,
)

custom_fromage_trainer.train()

Step,Training Loss,Validation Loss
64,11.1656,11.23001
128,11.1306,11.229642


TrainOutput(global_step=150, training_loss=11.144424845377603, metrics={'train_runtime': 77.0552, 'train_samples_per_second': 124.067, 'train_steps_per_second': 1.947, 'total_flos': 0.0, 'train_loss': 11.144424845377603, 'epoch': 5.0})

In [168]:
# Run one evaluation pass over the validation set and display the metrics.
eval_metrics = custom_fromage_trainer.evaluate()
eval_metrics

{}