In [None]:
import os
import sys
import clip
import torch
import pickle
import PIL.Image 
import skimage.io as io
import matplotlib.pyplot as plt
import torch.nn.functional as nnf
from plotting import fix_arabic_text
from models import  ClipCaptionPrefix
from transformers import AutoTokenizer, GPT2Tokenizer




def beam_search(model, tokenizer, embed, entry_length=20, top_p=0.8, temperature=1., stop_token= '.'):
    '''Greedily decode text from a prefix embedding.

    Despite the name, no beams are tracked. At each step the logits of the
    last position are masked with a nucleus (top-p) filter and the single
    highest-scoring token is appended. The top-ranked token is always kept by
    the filter, so the chosen token is the same as with plain greedy decoding.

    Args:
        model: object exposing ``gpt(inputs_embeds=...)`` whose result has a
            ``.logits`` attribute, and ``gpt.transformer.wte`` for looking up
            token embeddings.
        tokenizer: object providing ``encode(str)`` and ``decode(ids)``.
        embed: embeddings passed as ``inputs_embeds``. New token embeddings
            are concatenated along dim 1 and the stop check uses ``.item()``,
            so the batch dimension must be 1.
        entry_length: maximum number of tokens to generate.
        top_p: cumulative-probability threshold of the nucleus filter.
        temperature: logits are divided by this value when it is > 0.
        stop_token: generation stops once this token is produced or the
            decoded text ends with it.

    Returns:
        The decoded text as a string (empty when no token was generated).
    '''
    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    generated = embed
    tokens = None

    with torch.no_grad():
        for _ in range(entry_length):
            # Logits for the next token (last position only).
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

            # Nucleus filter: drop every token after the cumulative
            # probability passes top_p, shifted by one so the token that
            # crosses the threshold is kept, and never drop the best token.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[:, indices_to_remove] = filter_value

            # Take the most likely remaining token; shape (1, 1) for batch 1.
            next_token = torch.argmax(logits, -1).unsqueeze(0)
            next_token_embed = model.gpt.transformer.wte(next_token)

            tokens = next_token if tokens is None else torch.cat((tokens, next_token), dim=1)
            generated = torch.cat((generated, next_token_embed), dim=1)

            # Stop on the stop-token id, or when the decoded text ends with
            # the stop string (covers tokens that merely end in it).
            if stop_token_index == next_token.item():
                break
            if tokenizer.decode(tokens[0].tolist()).endswith(stop_token):
                break

    if tokens is None:
        # entry_length <= 0: nothing was generated.
        return ''

    # Index the batch row instead of squeeze(): squeezing a (1, 1) tensor
    # yields a 0-d value that cannot be turned into a list of ids.
    return tokenizer.decode(tokens[0].tolist())


def generate_caption(image_path, model, preprocess, clip_model, tokenizer ,prefix_length,  lang ,device):
    '''Caption a single image file with the CLIP-prefix GPT model.

    The image is read from disk, encoded with CLIP, projected into
    ``prefix_length`` GPT embeddings via ``model.clip_project`` and decoded
    with ``beam_search``.

    Args:
        image_path: path of the image to caption.
        model: captioning model exposing ``clip_project`` and ``gpt``.
        preprocess: CLIP preprocessing transform applied to the PIL image.
        clip_model: CLIP model providing ``encode_image``.
        tokenizer: tokenizer matching the GPT model.
        prefix_length: number of prefix embeddings produced per image.
        lang: language tag of the model. It is accepted so the call in
            ``main`` keeps working but is not used here.
            NOTE(review): Arabic output may need ``fix_arabic_text`` before
            being displayed - confirm against the plotting helper.
        device: device the image tensor and prefix are moved to.

    Returns:
        The generated caption string (also printed).
    '''
    image = io.imread(image_path)
    pil_image = PIL.Image.fromarray(image)
    image = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
    caption = beam_search(model, tokenizer, prefix_embed)
    print(caption)
    return caption


def main(model_path):
    '''Caption every image in ``./sample_image`` with the given checkpoint.

    The language is inferred from the checkpoint path, which must contain
    ``'english'`` or ``'arabic'``; ``'english'`` wins when both appear.

    Args:
        model_path: path of the ``.pt`` checkpoint to load.

    Raises:
        ValueError: if the path names neither supported language.
    '''
    # Read the language from the model path
    if 'english' in model_path:
        lang = 'english'
    elif 'arabic' in model_path:
        lang = 'arabic'
    else:
        raise ValueError(
            f"Cannot infer language from {model_path!r}: "
            "expected 'arabic' or 'english' in the path"
        )
    print(f'The Lang is {lang}')

    # Load the CLIP model
    device = 'cuda' if torch.cuda.is_available() else "cpu"
    clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

    # Load the GPT model Tokenizer
    if lang == 'arabic':
        tokenizer = AutoTokenizer.from_pretrained("akhooli/gpt2-small-arabic")
    else:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Load the GPT model
    model, prefix_length = load_model(model_path)
    model.eval()
    model = model.to(device)

    sample_images_dir = './sample_image'
    # os.listdir returns entries in arbitrary order; sort for a stable run.
    sample_images_paths = [
        os.path.join(sample_images_dir, image_name)
        for image_name in sorted(os.listdir(sample_images_dir))
    ]
    for image_path in sample_images_paths:
        generate_caption(image_path, model,preprocess, clip_model, tokenizer, prefix_length, lang, device)


if __name__ == '__main__':
    # Read the model path from the command line: the first argument ending
    # in '.pt' is used. Interactive kernels pass unrelated arguments, so when
    # no checkpoint argument is present a `model_path` already defined in the
    # session is used instead.
    _checkpoint_args = [arg for arg in sys.argv[1:] if arg.endswith('.pt')]
    if _checkpoint_args:
        model_path = _checkpoint_args[0]
    elif 'model_path' not in globals():
        raise SystemExit('Pass the checkpoint path (*.pt) as the first command-line argument.')
    main(model_path)


In [None]:
import os
import sys
import clip
import torch
import pickle
import PIL.Image 
import skimage.io as io
import matplotlib.pyplot as plt
import torch.nn.functional as nnf
from plotting import fix_arabic_text
from models import  ClipCaptionPrefix
from transformers import AutoTokenizer, GPT2Tokenizer

def load_model(model_path):
    '''Rebuild a ClipCaptionPrefix from a checkpoint and its pickled args.

    The training arguments are read from the sibling file obtained by
    replacing '.pt' with '_args.pkl' in ``model_path``.

    Returns:
        (model, prefix_length): the model with weights loaded on CPU and the
        prefix length stored in the pickled args.
    '''
    # NOTE: pickle.load executes arbitrary code; only open trusted files.
    with open(model_path.replace('.pt', '_args.pkl'), 'rb') as args_file:
        saved_args = pickle.load(args_file)

    model_kwargs = dict(
        prefix_length=saved_args.prefix_length,
        lang=saved_args.lang,
        clip_length=saved_args.prefix_length_clip,
        prefix_size=512,
        num_layers=saved_args.num_layers,
        mapping_type=saved_args.mapping_type,
    )
    captioner = ClipCaptionPrefix(**model_kwargs)

    weights = torch.load(model_path, map_location='cpu')
    captioner.load_state_dict(weights)
    return captioner, saved_args.prefix_length



model_path = './checkpoints/english_exp_1-029.pt'
# Read the language from the model path ('english' wins when both appear)
if 'english' in model_path:
    lang = 'english'
elif 'arabic' in model_path:
    lang = 'arabic'
else:
    raise ValueError(
        f"Cannot infer language from {model_path!r}: "
        "expected 'arabic' or 'english' in the path"
    )
print(f'The Lang is {lang}')
# Load the CLIP model
device = 'cuda' if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Load the GPT model Tokenizer
if lang == 'arabic':
    tokenizer = AutoTokenizer.from_pretrained("akhooli/gpt2-small-arabic")
else:
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Load the GPT model
model, prefix_length = load_model(model_path)
model.eval()
model = model.to(device)


sample_images_dir = './sample_image'
# os.listdir returns entries in arbitrary order; sort so that the "first"
# image is the same on every run.
sample_images_paths = [
    os.path.join(sample_images_dir, image_name)
    for image_name in sorted(os.listdir(sample_images_dir))
]
if not sample_images_paths:
    raise FileNotFoundError(f'No sample images found in {sample_images_dir}')
image_path = sample_images_paths[0]



In [None]:
# Read the selected image, run it through CLIP and project the CLIP feature
# into `prefix_length` embeddings for the GPT model.
pil_image = PIL.Image.fromarray(io.imread(image_path))
image = preprocess(pil_image)
image = image.unsqueeze(0).to(device)
with torch.no_grad():
    prefix = clip_model.encode_image(image)
    prefix = prefix.to(device, dtype=torch.float32)
    prefix_embed = model.clip_project(prefix)
    prefix_embed = prefix_embed.reshape(1, prefix_length, -1)


In [None]:
# Inlined body of a beam-search routine, run at cell level so that the
# intermediate tensors can be inspected. The commented-out header below is
# kept from the function form:
# def generate_beam(
#     model,
#     tokenizer,
import numpy as np
# "Parameters" of the inlined routine.
beam_size: int = 5
embed= prefix_embed
prompt=None
entry_length=67
temperature=1.0
stop_token: str = "."


model.eval()
stop_token_index = tokenizer.encode(stop_token)[0]
tokens = None   # generated token ids, one row per beam
scores = None   # cumulative log-probability per beam
# Rebinds the notebook-level `device` to the device of the model parameters.
device = next(model.parameters()).device
# Per-beam generated length (float tensor of ones) and finished flags.
seq_lengths = torch.ones(beam_size, device=device)
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
with torch.no_grad():
    # Start from the given embeddings, otherwise embed the text prompt.
    if embed is not None:
        generated = embed
    else:
        if tokens is None:
            tokens = torch.tensor(tokenizer.encode(prompt))
            tokens = tokens.unsqueeze(0).to(device)
            generated = model.gpt.transformer.wte(tokens)
    for i in range(entry_length):
        outputs = model.gpt(inputs_embeds=generated)
        print(f'generated shape is {generated.shape}')
        logits = outputs.logits
        print(f'logits shape is {logits.shape}')
        # Keep the last position only and convert to log-probabilities.
        logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
        logits = logits.softmax(-1).log()
        if scores is None:
            # First step: the top `beam_size` tokens seed the beams and the
            # single input sequence is expanded to one copy per beam.
            scores, next_tokens = logits.topk(beam_size, -1)
            generated = generated.expand(beam_size, *generated.shape[1:])
            next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
            if tokens is None:
                tokens = next_tokens
            else:
                tokens = tokens.expand(beam_size, *tokens.shape[1:])
                tokens = torch.cat((tokens, next_tokens), dim=1)
        else:
            # A finished beam may only continue with token 0 at log-prob 0,
            # so it yields exactly one candidate with an unchanged score.
            logits[is_stopped] = -float(np.inf)
            logits[is_stopped, 0] = 0
            scores_sum = scores[:, None] + logits
            seq_lengths[~is_stopped] += 1
            # Rank candidates by length-averaged log-probability.
            scores_sum_average = scores_sum / seq_lengths[:, None]
            scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                beam_size, -1
            )
            # Flat index -> (source beam, token id).
            next_tokens_source = next_tokens // scores_sum.shape[1]
            seq_lengths = seq_lengths[next_tokens_source]
            next_tokens = next_tokens % scores_sum.shape[1]
            next_tokens = next_tokens.unsqueeze(1)
            # Reorder all per-beam state to follow the surviving beams.
            tokens = tokens[next_tokens_source]
            tokens = torch.cat((tokens, next_tokens), dim=1)
            generated = generated[next_tokens_source]
            # Undo the averaging so `scores` is again a cumulative sum.
            scores = scores_sum_average * seq_lengths
            is_stopped = is_stopped[next_tokens_source]
        next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
            generated.shape[0], 1, -1
        )
        generated = torch.cat((generated, next_token_embed), dim=1)
        # `+` on bool tensors acts as a logical OR here.
        is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
        if is_stopped.all():
            break
# Final ranking uses the length-averaged score.
scores = scores / seq_lengths
output_list = tokens.cpu().numpy()
# Decode each beam up to its own length; ids appended after a beam finished
# are cut off by the slice.
output_texts = [
    tokenizer.decode(output[: int(length)])
    for output, length in zip(output_list, seq_lengths)
]
# Best-scoring caption first.
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]



In [None]:
# Number of decoded candidates (one entry per beam).
len(output_texts)

In [None]:
# Show each decoded candidate followed by how many lines it spans, then the
# total line count over all candidates.
all_lines = 0
for output in output_texts:
    print(output, len(output.splitlines()), sep='\n')
    all_lines = all_lines + len(output.splitlines())
print(all_lines)


In [60]:
# Instrumented, cell-level copy of the greedy/top-p decoding loop. Every
# intermediate tensor stays in the notebook namespace so the cells below can
# inspect it, and the shape of each one is printed per step.
entry_length=20
top_p= 0.8
temperature=1.
stop_token= '.'
embed=prefix_embed

with torch.no_grad():
    model.eval()
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    generated = embed
    tokens = None
    
    # (redundant: gradients are already disabled by the outer context)
    with torch.no_grad():
        for i in range(entry_length):
            
            #  get the logits for the next token
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits
            print(f'logits shape is {logits.shape}')
            # keep the last position only and apply the temperature
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            print(f'logits after temperature shape is {logits.shape}')
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            print(f'sorted_logits shape is {sorted_logits.shape}')
            print(f'sorted_indices shape is {sorted_indices.shape}')
            cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
            print(f'cumulative_probs shape is {cumulative_probs.shape}')
            # mark tokens beyond the top_p cumulative mass ...
            sorted_indices_to_remove = cumulative_probs > top_p
            print(f'sorted_indices_to_remove shape is {sorted_indices_to_remove.shape}')
            # ... shifted right by one so the token crossing the threshold stays
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            print(f'sorted_indices_to_remove after clone shape is {sorted_indices_to_remove.shape}')
            # the highest-ranked token is never removed
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            print(f'indices_to_remove shape is {indices_to_remove.shape}')
            logits[:, indices_to_remove] = filter_value
            
            print(f'number of not removed tokens is {torch.sum(logits != filter_value)}')
            #  take the most likely token and add it to the sequence
            #  (the top token always survives the filter, so this is greedy)
            next_token = torch.argmax(logits, -1).unsqueeze(0)

            
            # transform the token to embedding
            next_token_embed = model.gpt.transformer.wte(next_token)

            # add the token to the sequence
            if tokens is None:
                tokens = next_token
            else:
                tokens = torch.cat((tokens, next_token), dim=1)
            
            print(f'tokens at the end of the loop shape is {tokens.shape}')
            
            # add the embedding to the sequence
            generated = torch.cat((generated, next_token_embed), dim=1)

            # stop if the stop token is reached
            if stop_token_index == next_token.item():
                break
            # also stop when the decoded text ends with the stop character
            if stop_token == tokenizer.decode(tokens.squeeze().cpu().numpy())[-1]:
                break

        # convert the sequence to text (disabled in this debug copy)
        # output_list = list(tokens.squeeze().cpu().numpy())
        # output_text = tokenizer.decode(output_list)
        # generated_list.append(output_text)




logits shape is torch.Size([1, 10, 50257])
logits after temperature shape is torch.Size([1, 50257])
sorted_logits shape is torch.Size([1, 50257])
sorted_indices shape is torch.Size([1, 50257])
cumulative_probs shape is torch.Size([1, 50257])
sorted_indices_to_remove shape is torch.Size([1, 50257])
sorted_indices_to_remove after clone shape is torch.Size([1, 50257])
indices_to_remove shape is torch.Size([50256])
number of not removed tokens is 1
tokens at the end of the loop shape is torch.Size([1, 1])
logits shape is torch.Size([1, 11, 50257])
logits after temperature shape is torch.Size([1, 50257])
sorted_logits shape is torch.Size([1, 50257])
sorted_indices shape is torch.Size([1, 50257])
cumulative_probs shape is torch.Size([1, 50257])
sorted_indices_to_remove shape is torch.Size([1, 50257])
sorted_indices_to_remove after clone shape is torch.Size([1, 50257])
indices_to_remove shape is torch.Size([50255])
number of not removed tokens is 2
tokens at the end of the loop shape is torch

In [None]:
# Inspect the most recent token id picked by the decoding loop.
next_token

In [None]:
# Inspect one entry (row 0, index 32) of the current `logits` tensor.
logits[0][32]

In [None]:
# Inspect the indices returned by the descending sort of the logits.
sorted_indices

In [None]:
# Inspect the ids that the top-p filter set to the filter value.
indices_to_remove

In [None]:
# Inspect the boolean removal mask (in sorted order) of the last step.
sorted_indices_to_remove

In [None]:
# Number of entries flagged for removal in the last step.
sorted_indices_to_remove.sum()

In [None]:
# Inspect the cumulative softmax mass over the sorted logits.
cumulative_probs

In [None]:

# Display pil_image with the generated caption as the figure title.
# NOTE(review): `generated_text_prefix` is not assigned in any cell visible
# in this notebook; it presumably has to come from an earlier kernel session
# or a removed cell - confirm, otherwise this raises NameError on a fresh run.
plt.imshow(pil_image)
plt.axis('off')
print(generated_text_prefix)
plt.title(generated_text_prefix)
plt.show()

In [None]:
import torch
from torch.nn import functional as nnf
from torch.utils.data import  DataLoader
from transformers import  AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import os
import sys
import argparse
import pickle
import json
from typing import  Union
from dataset import ClipGPTFlickr8kDataset
from models import ClipCaptionModel, ClipCaptionPrefix, MappingType
from args import DemoArgs
from bleu import belu_score


# Experimental training loop: the model is fed batches from `dataset`, but
# the loss targets are the token ids of `dataset2`, which is built with
# lang='arabic'.
args = DemoArgs()
lr = 2e-5
warmup_steps = 5000
output_dir= args.out_dir
output_prefix = args.output_prefix
start_epoch = 0
# Fall back to CPU when CUDA is unavailable, as the inference cells do.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = args.bs  # NOTE: not used below; the loaders hard-code their batch sizes
epochs = args.epochs
# exist_ok avoids the separate exists() check.
os.makedirs(output_dir, exist_ok=True)
model = ClipCaptionPrefix(args.prefix_length, lang=args.lang , clip_length=args.prefix_length_clip, prefix_size=args.prefix_dim, num_layers=args.num_layers, mapping_type=args.mapping_type)
model = model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=lr)
dataset = ClipGPTFlickr8kDataset(args.data, args.prefix_length,lang= args.lang, normalize_prefix=args.normalize_prefix)
train_dataloader = DataLoader(dataset, batch_size=5, shuffle=False, drop_last=True)
dataset2 = ClipGPTFlickr8kDataset('./data/embeddings/arabic_CLIP-ViT-B-32_embeddings.pkl', args.prefix_length,lang= 'arabic', normalize_prefix=args.normalize_prefix)
train_dataloader2 = DataLoader(dataset2, batch_size=3, shuffle=False, drop_last=True)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader))
for epoch in range(start_epoch, epochs+start_epoch):
    # Fresh iterators each epoch; both loaders are unshuffled.
    ietr_obj = iter(train_dataloader)
    ietr_obj2 = iter(train_dataloader2)
    print(f">>> Training epoch {epoch} out of {epochs+start_epoch}")
    sys.stdout.flush()
    number_of_batches = len(train_dataloader)
    progress = tqdm(total=len(train_dataloader), desc=output_prefix)
    for idx in range(number_of_batches):
        # The first loader yields 5 rows; only the first 3 are kept so the
        # batch matches the 3 rows drawn from the second loader.
        tokens, mask, prefix = next(ietr_obj)
        tokens= tokens[0:3]
        mask = mask[0:3]
        prefix = prefix[0:3]
        tokens2, _, prefix2 = next(ietr_obj2)
        model.zero_grad()
        tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
        tokens2 = tokens2.to(device)
        outputs = model(tokens, prefix, mask)
        # Drop the prefix positions and shift by one so each remaining
        # position is scored against the next target token.
        logits = outputs.logits[:, dataset.prefix_length - 1: -1]
        # Targets equal to 0 are ignored by the loss (presumably padding -
        # TODO confirm against the dataset).
        loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens2.flatten(), ignore_index=0)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        progress.set_postfix({"loss": loss.item()})
        progress.update()
    progress.close()
    # Save weights plus the pickled args every `save_every` epochs and after
    # the final epoch.
    if epoch % args.save_every == 0 or epoch == epochs - 1 + start_epoch:
        model_path = os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt")
        args_path = model_path.replace('.pt', '_args.pkl')
        torch.save(model.state_dict(), model_path)
        with open(args_path, 'wb') as f:
            pickle.dump(args, f)



In [None]:
from torch.nn import functional as nnf
from torch.utils.data import  DataLoader
from transformers import  AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import os
import sys
import argparse
import pickle
import json
from typing import  Union
from dataset import ClipGPTFlickr8kDataset
from models import ClipCaptionModel, ClipCaptionPrefix, MappingType
from args import DemoArgs
from bleu import belu_score
import torch

# Use the first GPU when present, otherwise fall back to CPU so the cells
# below also run on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def load_model(model_path):
    '''Load a checkpoint together with its args and the epoch to resume at.

    The epoch is parsed from the number between the last '-' of
    ``model_path`` and the following '.', then incremented by one. The args
    come from the sibling file obtained by replacing '.pt' with '_args.pkl'.

    Returns:
        (model, args, epoch_number): the model with weights loaded on CPU,
        the unpickled training args and the parsed epoch plus one.
    '''
    last_segment = model_path.rsplit('-', 1)[-1]
    epoch_number = int(last_segment.partition('.')[0]) + 1

    # NOTE: pickle.load executes arbitrary code; only open trusted files.
    with open(model_path.replace('.pt', '_args.pkl'), 'rb') as args_file:
        args = pickle.load(args_file)

    model = ClipCaptionPrefix(
        args.prefix_length,
        args.lang,
        clip_length=args.prefix_length_clip,
        prefix_size=512,
        num_layers=args.num_layers,
        mapping_type=args.mapping_type,
    )
    weights = torch.load(model_path, map_location='cpu')
    model.load_state_dict(weights)
    return model, args, epoch_number

In [None]:
# Load the checkpoint with the three-value load_model defined above, which
# also returns the saved args and the epoch to resume from, then move the
# model to `device`.
model_path = './checkpoints/english_exp_2-000.pt'
model, args, start_epoch = load_model(model_path)
model = model.to(device)

In [None]:
# Dataset and unshuffled loader built from the checkpoint's own args.
dataset = ClipGPTFlickr8kDataset(args.data, args.prefix_length,lang= args.lang, normalize_prefix=args.normalize_prefix)
train_dataloader = DataLoader(dataset, batch_size=5, shuffle=False, drop_last=True)

# Second dataset over the Arabic embeddings file; only its tokenizer is kept,
# and it is the one used to decode the model predictions in the next cell.
dataset2 = ClipGPTFlickr8kDataset('./data/embeddings/arabic_CLIP-ViT-B-32_embeddings.pkl', args.prefix_length,lang= 'arabic', normalize_prefix=args.normalize_prefix)
tokenizer = dataset2.tokenizer


In [None]:
# For the first 10 batches, run the model on the ground-truth tokens and
# print the per-position argmax predictions decoded with `tokenizer`.
ietr_obj = iter(train_dataloader)
with torch.no_grad():
    for i in range(10):
        tokens, mask, prefix = next(ietr_obj)
        tokens = tokens.to(device)
        mask = mask.to(device)
        prefix = prefix.to(device, dtype=torch.float32)
        outputs = model(tokens, prefix, mask)
        # Skip the prefix positions, shifted by one.
        logits = outputs.logits[:, dataset.prefix_length - 1: -1]
        preds = logits.argmax(dim=-1).tolist()

        for pred in preds:
            print(tokenizer.decode(pred), '-----------------', sep='\n')

In [None]:
import clip
import torch
import pickle 
import random
import numpy as np
import pandas as pd
import transformers
from tqdm import tqdm
from bleu import generate_caption
from inference_gpt import load_model
from multilingual_clip import pt_multilingual_clip
from transformers import AutoTokenizer, GPT2Tokenizer

# Compare, for k randomly chosen images, the CLIP similarity of the image
# with its reference caption against its similarity with the generated one.

# Fixed seed so the same images are selected on every run
random.seed(42)

# model_path = './checkpoints/english_exp_1-029.pt'
model_path = './checkpoints/arabic_exp_2-045.pt'

k = 10

# Infer the language from the checkpoint path ('arabic' wins when both appear)
if 'arabic' in model_path:
    lang = 'arabic'
elif 'english' in model_path:
    lang = 'english'
else:
    raise ValueError(
        f"Cannot infer language from {model_path!r}: "
        "expected 'arabic' or 'english' in the path"
    )

data_path = f'./data/embeddings/{lang}_CLIP-ViT-B-32_embeddings.pkl'
# NOTE: pickle.load executes arbitrary code; only open trusted files.
with open(data_path, 'rb') as f:
    data = pickle.load(f)

image_ids = [data['captions'][i]['image_id'] for i in range(len(data['captions']))]
# Index of the first caption of every distinct image id (despite the name,
# these are positions in data['captions'], not ids).
unique_image_ids = np.unique(image_ids, return_index=True)[1]
n = random.sample(list(unique_image_ids), k)

device = 'cuda' if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
logit_scale = clip_model.logit_scale.exp().float().to('cpu')

multilingual_clip_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-32')
multilingual_tokenizer = transformers.AutoTokenizer.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-32')
multilingual_tokenizer.pad_token = multilingual_tokenizer.eos_token


pretrained_model, prefix_length = load_model(model_path)
pretrained_model = pretrained_model.to(device)
if lang == 'arabic':
    pretrained_tokenizer = AutoTokenizer.from_pretrained("akhooli/gpt2-small-arabic")
else:
    pretrained_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
pretrained_tokenizer.pad_token = pretrained_tokenizer.eos_token


image_embeddings = data['clip_embedding'][n].float()
sample_captions = [data['captions'][i]['caption'] for i in n]
image_ids = [data['captions'][i]['image_id'] for i in n]

predictions = []
for i in tqdm(range(len(image_ids))):
    image_path = f'./data/images/{image_ids[i]}'
    # NOTE(review): this is bleu.generate_caption; its signature is not
    # visible here - confirm the argument order (no `lang` is passed).
    prediction = generate_caption(image_path, pretrained_model ,preprocess, clip_model, pretrained_tokenizer, prefix_length, device)
    predictions.append(prediction)

with torch.no_grad():
    text_embeddings = multilingual_clip_model.forward(sample_captions, multilingual_tokenizer).float()
    # The M-CLIP text encoder must receive ids from its own tokenizer; the
    # GPT-2 tokenizer has a different vocabulary and is only for generation.
    predicted_embeddings = multilingual_clip_model.forward(predictions, multilingual_tokenizer).float()

    # L2-normalise so the dot products below are cosine similarities.
    image_features = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
    text_features  = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
    predicted_text_features  = predicted_embeddings / predicted_embeddings.norm(dim=-1, keepdim=True)


    true_similarities = []
    pred_similarities = []
    true_captions = []
    pred_captions = []
    for i in range(len(image_ids)):
        # Cosine similarity scaled by CLIP's logit scale.
        true_similarity = logit_scale * (image_features[i]* text_features[i]).sum()
        pred_similarity = logit_scale * (image_features[i]* predicted_text_features[i]).sum()
        true_similarities.append(true_similarity.item())
        pred_similarities.append(pred_similarity.item())
        true_captions.append(sample_captions[i])
        pred_captions.append(predictions[i])

    df = pd.DataFrame({
        'image_id': image_ids,
        'true_similarity': true_similarities,
        'predicted_similarity': pred_similarities,
        'true_caption': true_captions,
        'predicted_caption': pred_captions
        })
df

In [None]:
####### Run Inference On Sample Images ################
# Calls main() of the inference_gpt module with the Arabic checkpoint.
import inference_gpt
ckpt_path = './checkpoints/arabic_exp_2-045.pt'
inference_gpt.main(ckpt_path)

In [None]:
from bleu import belu_score

# Run the project's `belu_score` helper (name as spelled in the bleu module)
# on the Arabic checkpoint, then on the English one. Only the value of the
# last call is shown as the cell output.
model_path = './checkpoints/arabic_exp_2-045.pt'
belu_score(model_path)
model_path = './checkpoints/english_exp_1-029.pt'
belu_score(model_path)