In [None]:
import torch
from torch.nn import functional as nnf
from torch.utils.data import  DataLoader
from transformers import  AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import os
import sys
import argparse
import pickle
import json
from typing import  Union
from dataset import ClipGPTFlickr8kDataset
from models import ClipCaptionModel, ClipCaptionPrefix, MappingType
from args import DemoArgs
from bleu import belu_score


# ---------------------------------------------------------------------------
# Training run: the model is fed batches from `dataset` (configured by
# DemoArgs) and its caption logits are scored against the token ids of the
# matching batch from `dataset2` (the Arabic embeddings file).
# ---------------------------------------------------------------------------
args = DemoArgs()
lr = 2e-5
warmup_steps = 5000
output_dir = args.out_dir
output_prefix = args.output_prefix
start_epoch = 0
# Use the first CUDA device when present, otherwise fall back to the CPU
# (same policy as the evaluation cell further down the notebook).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = args.bs  # NOTE(review): unused below; the loaders use fixed sizes 5 and 3.
epochs = args.epochs
# exist_ok avoids the check-then-create race of `if not os.path.exists(...)`.
os.makedirs(output_dir, exist_ok=True)

model = ClipCaptionPrefix(args.prefix_length, lang=args.lang, clip_length=args.prefix_length_clip, prefix_size=args.prefix_dim, num_layers=args.num_layers, mapping_type=args.mapping_type)
model = model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=lr)

# Neither loader shuffles, so batch i of one lines up with batch i of the other.
# NOTE(review): sizes 5 and 3 presumably match the number of captions per image
# in each dataset - confirm against the data files.
dataset = ClipGPTFlickr8kDataset(args.data, args.prefix_length, lang=args.lang, normalize_prefix=args.normalize_prefix)
train_dataloader = DataLoader(dataset, batch_size=5, shuffle=False, drop_last=True)
dataset2 = ClipGPTFlickr8kDataset('./data/embeddings/arabic_CLIP-ViT-B-32_embeddings.pkl', args.prefix_length, lang='arabic', normalize_prefix=args.normalize_prefix)
train_dataloader2 = DataLoader(dataset2, batch_size=3, shuffle=False, drop_last=True)

# The loaders are consumed in lock-step, so an epoch is as long as the shorter
# one. Sizing the schedule (and the progress bar) from that number keeps them
# correct even if the two datasets do not yield the same number of batches.
number_of_batches = min(len(train_dataloader), len(train_dataloader2))
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * number_of_batches)

for epoch in range(start_epoch, epochs + start_epoch):
    print(f">>> Training epoch {epoch} out of {epochs+start_epoch}")
    sys.stdout.flush()
    # The context manager closes the bar even if a step raises.
    with tqdm(total=number_of_batches, desc=output_prefix) as progress:
        # zip stops at the shorter loader instead of raising StopIteration.
        for (tokens, mask, prefix), (tokens2, _, _) in zip(train_dataloader, train_dataloader2):
            # Keep the first three rows of the five-row batch so it has as
            # many rows as the three-row batch coming from dataset2.
            tokens = tokens[0:3]
            mask = mask[0:3]
            prefix = prefix[0:3]

            optimizer.zero_grad()
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
            tokens2 = tokens2.to(device)

            outputs = model(tokens, prefix, mask)
            # Drop the prefix positions and the final step so that position t
            # is scored against target token t.
            logits = outputs.logits[:, dataset.prefix_length - 1: -1]
            # Targets are the dataset2 token ids, not the input tokens; this
            # requires both token tensors to have the same sequence length.
            # Index 0 is ignored (presumably padding - confirm).
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens2.flatten(), ignore_index=0)
            loss.backward()
            optimizer.step()
            scheduler.step()

            progress.set_postfix({"loss": loss.item()})
            progress.update()

    # Checkpoint every `save_every` epochs and always after the final epoch.
    if epoch % args.save_every == 0 or epoch == epochs - 1 + start_epoch:
        model_path = os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt")
        args_path = model_path.replace('.pt', '_args.pkl')
        torch.save(model.state_dict(), model_path)
        # The args are stored next to the weights so the model can be rebuilt.
        with open(args_path, 'wb') as f:
            pickle.dump(args, f)



In [1]:
from torch.nn import functional as nnf
from torch.utils.data import  DataLoader
from transformers import  AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import os
import sys
import argparse
import pickle
import json
from typing import  Union
from dataset import ClipGPTFlickr8kDataset
from models import ClipCaptionModel, ClipCaptionPrefix, MappingType
from args import DemoArgs
from bleu import belu_score
import torch

# Prefer the first CUDA device, but fall back to the CPU when CUDA is not
# available so the cell does not hand out a device that later `.to(device)`
# calls cannot use.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def load_model(model_path):
    '''Rebuild a ClipCaptionPrefix from a checkpoint and its pickled args.

    The checkpoint is expected to be named "<prefix>-<epoch>.<ext>" and to
    have a sibling "<prefix>-<epoch>_args.pkl" holding the training args.

    Args:
        model_path: path to the saved state dict.

    Returns:
        (model, args, epoch_number): the model with weights loaded (tensors
        mapped to the CPU), the unpickled args object, and the epoch parsed
        from the file name plus one.

    Raises:
        ValueError: if the text after the last '-' in the file name does not
            start with an integer.
        OSError: if the checkpoint or its args file cannot be opened.
    '''
    # Parse the epoch from the file name only, so a '-' or '.' in a parent
    # directory cannot confuse the split.
    file_name = os.path.basename(model_path)
    epoch_number = int(file_name.split('-')[-1].split('.')[0]) + 1

    # Swap only the real extension; str.replace('.pt', ...) would rewrite
    # every '.pt' occurrence in the path and mangle e.g. a '.pth' suffix.
    args_path = os.path.splitext(model_path)[0] + '_args.pkl'
    # NOTE(security): pickle.load can execute arbitrary code; only load args
    # files produced by your own training runs.
    with open(args_path, 'rb') as f:
        args = pickle.load(f)

    # Build with the prefix size the run was trained with; 512 stays as the
    # fallback for args objects that carry no `prefix_dim`.
    prefix_size = getattr(args, 'prefix_dim', 512)
    model = ClipCaptionPrefix(args.prefix_length, args.lang, clip_length=args.prefix_length_clip, prefix_size=prefix_size, num_layers=args.num_layers, mapping_type=args.mapping_type)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    return model, args, epoch_number

In [2]:
# Checkpoint to restore; load_model also reads the sibling "*_args.pkl" file.
model_path = './checkpoints/english_exp_2-000.pt'
# load_model returns (model, args, epoch_number); the epoch number (parsed from
# the file name, plus one) is bound to start_epoch here.
model, args, start_epoch = load_model(model_path)
# The weights were loaded with map_location='cpu'; move the model onto `device`.
model = model.to(device)

In [3]:
# Dataset described by the restored args, served in file order (no shuffling)
# in batches of five; an incomplete final batch is dropped.
dataset = ClipGPTFlickr8kDataset(
    args.data,
    args.prefix_length,
    lang=args.lang,
    normalize_prefix=args.normalize_prefix,
)
train_dataloader = DataLoader(
    dataset,
    batch_size=5,
    shuffle=False,
    drop_last=True,
)

# The Arabic dataset is built here only for its tokenizer, which is used to
# decode predictions later in the notebook.
dataset2 = ClipGPTFlickr8kDataset(
    './data/embeddings/arabic_CLIP-ViT-B-32_embeddings.pkl',
    args.prefix_length,
    lang='arabic',
    normalize_prefix=args.normalize_prefix,
)
tokenizer = dataset2.tokenizer


Data size is 40455
Data size is 24273


In [None]:
# Teacher-forced preview: run the first (up to) ten batches through the model
# and print the arg-max token sequence for every row, decoded with `tokenizer`
# (taken from dataset2 in the previous cell).
model.eval()  # inference mode: switch off dropout and other train-only layers
ietr_obj = iter(train_dataloader)
with torch.no_grad():
    # Pairing with range(10) caps the preview at ten batches and ends cleanly
    # (no StopIteration) when the loader holds fewer than ten.
    for _, (tokens, mask, prefix) in zip(range(10), ietr_obj):
        tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
        outputs = model(tokens, prefix, mask)
        # Drop the prefix positions and the final step, as in training.
        logits = outputs.logits[:, dataset.prefix_length - 1: -1]
        preds = torch.argmax(logits, dim=-1).tolist()

        for pred in preds:
            print(tokenizer.decode(pred))
            print('-----------------')

In [None]:
import clip
import torch
import pickle 
import random
import numpy as np
import pandas as pd
import transformers
from tqdm import tqdm
from bleu import generate_caption
from inference_gpt import load_model
from multilingual_clip import pt_multilingual_clip
from transformers import AutoTokenizer, GPT2Tokenizer

# Compare ground-truth and generated captions for k randomly chosen images by
# their scaled CLIP similarity to the stored image embedding.

# Fixed seed so the same images are sampled on every run.
random.seed(42)

# model_path = './checkpoints/english_exp_1-029.pt'
model_path = './checkpoints/arabic_exp_2-045.pt'

k = 10  # number of images to sample

# The caption language is inferred from the checkpoint name. 'arabic' takes
# precedence when both words occur; a name with neither is rejected up front
# instead of failing later with a NameError on `lang`.
if 'arabic' in model_path:
    lang = 'arabic'
elif 'english' in model_path:
    lang = 'english'
else:
    raise ValueError(f"cannot infer caption language from checkpoint path: {model_path!r}")

data_path = f'./data/embeddings/{lang}_CLIP-ViT-B-32_embeddings.pkl'
# NOTE(security): pickle.load can execute arbitrary code; only open embedding
# files you produced yourself.
with open(data_path, 'rb') as f:
    data = pickle.load(f)

# np.unique(..., return_index=True)[1] gives the index of the first caption of
# every distinct image, so sampling from it picks k different images.
all_image_ids = [caption['image_id'] for caption in data['captions']]
unique_image_ids = np.unique(all_image_ids, return_index=True)[1]
n = random.sample(list(unique_image_ids), k)

device = 'cuda' if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
logit_scale = clip_model.logit_scale.exp().float().to('cpu')

multilingual_clip_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-32')
multilingual_tokenizer = transformers.AutoTokenizer.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-32')
multilingual_tokenizer.pad_token = multilingual_tokenizer.eos_token

# `load_model` here is the one imported from inference_gpt at the top of this
# cell; it is unpacked as (model, prefix_length).
pretrained_model, prefix_length = load_model(model_path)
pretrained_model = pretrained_model.to(device)
if lang == 'arabic':
    pretrained_tokenizer = AutoTokenizer.from_pretrained("akhooli/gpt2-small-arabic")
else:
    pretrained_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
pretrained_tokenizer.pad_token = pretrained_tokenizer.eos_token

image_embeddings = data['clip_embedding'][n].float()
sample_captions = [data['captions'][i]['caption'] for i in n]
image_ids = [data['captions'][i]['image_id'] for i in n]

# Generate one caption per sampled image with the captioning model.
predictions = []
for image_id in tqdm(image_ids):
    image_path = f'./data/images/{image_id}'
    prediction = generate_caption(image_path, pretrained_model, preprocess, clip_model, pretrained_tokenizer, prefix_length, device)
    predictions.append(prediction)

with torch.no_grad():
    # Both caption sets go through the M-CLIP text encoder and therefore must
    # be tokenized with the M-CLIP tokenizer. The GPT-2 tokenizer is only for
    # the captioning model; feeding its ids to the encoder gives meaningless
    # embeddings.
    text_embeddings = multilingual_clip_model.forward(sample_captions, multilingual_tokenizer).float()
    predicted_embeddings = multilingual_clip_model.forward(predictions, multilingual_tokenizer).float()

    # L2-normalise so the dot products below are cosine similarities.
    image_features = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
    text_features = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
    predicted_text_features = predicted_embeddings / predicted_embeddings.norm(dim=-1, keepdim=True)

    # Scaled similarity of each image to its reference and generated caption.
    true_similarities = [
        (logit_scale * (image_features[i] * text_features[i]).sum()).item()
        for i in range(len(image_ids))
    ]
    pred_similarities = [
        (logit_scale * (image_features[i] * predicted_text_features[i]).sum()).item()
        for i in range(len(image_ids))
    ]
    true_captions = list(sample_captions)
    pred_captions = list(predictions)

    df = pd.DataFrame({
        'image_id': image_ids,
        'true_similarity': true_similarities,
        'predicted_similarity': pred_similarities,
        'true_caption': true_captions,
        'predicted_caption': pred_captions
        })
df

In [None]:
# Run inference on sample images through the project's inference_gpt module.
# NOTE(review): inference_gpt.main is defined outside this notebook; what it
# prints or returns cannot be seen from here.
import inference_gpt
# Checkpoint handed to inference_gpt.main.
ckpt_path = './checkpoints/arabic_exp_2-045.pt'
inference_gpt.main(ckpt_path)

In [None]:
from bleu import belu_score

# Score two checkpoints with the project's belu_score helper (imported from
# the bleu module; presumably a BLEU evaluation - confirm in bleu.py).
# NOTE(review): the notebook displays only the value of the final expression,
# so the first call's return value is visible only if belu_score prints it.
model_path = './checkpoints/arabic_exp_2-045.pt'
belu_score(model_path)
# Same evaluation for the English checkpoint; its return value is the cell output.
model_path = './checkpoints/english_exp_1-029.pt'
belu_score(model_path)