### Finetune the CLIP model on Flickr30K, CoCo dataset

In [1]:
import torch
import torch.nn as nn
from transformers import (
    CLIPProcessor, 
    CLIPModel,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

from datasets import load_dataset
from zeroshot_retrieval import zeroshot_evaluate
from tqdm import tqdm
from icecream import ic

In [2]:
pretrained_model_path = "openai/clip-vit-base-patch16"
save_path = "saved_clip_model\\clip-vit-base-p16-finetuned_proj_512x512"


In [None]:
dataset = load_dataset(
    'nlphuji/flickr30k',
    cache_dir='cache',
    keep_in_memory=True,
)['test'] # type: ignore

train_dataset = dataset.filter(lambda x: x["split"] == "train")
test_dataset = dataset.filter(lambda x: x["split"] == "test")
val_dataset = dataset.filter(lambda x: x["split"] == "val")

In [3]:
# torch.save(train_dataset, "data/flickr30k/train_dataset.pt")
# torch.save(test_dataset, "data/flickr30k/test_dataset.pt")
# torch.save(val_dataset, "data/flickr30k/val_dataset.pt")

train_dataset = torch.load("data/flickr30k/train_dataset.pt")
val_dataset = torch.load("data/flickr30k/val_dataset.pt")
test_dataset = torch.load("data/flickr30k/test_dataset.pt")

In [3]:
model = CLIPModel.from_pretrained(pretrained_model_path, cache_dir="cache").cuda()
# model.visual_projection = nn.Linear(in_features=768, out_features=768, bias=False)
# model.text_projection = nn.Linear(in_features=512, out_features=768, bias=False)

processor = CLIPProcessor.from_pretrained(pretrained_model_path, cache_dir="cache")

In [5]:
# tmp = []
def tokenize_function(examples, max_seq_length=77):
    # First, obtain the sentences. 
    # Here, one image corresponds to five sentences, and we will combine them into one sentence
    # tmp.append(examples)
    text = [" ".join('%s' %a for a in sentence) for sentence in examples['caption']]

    image = [image.convert("RGB") for image in examples['image']]
    processed_data = processor(
        images=image, 
        text=text,
        padding='max_length',
        max_length=max_seq_length,
        truncation=True,
        return_attention_mask=True, 
        return_tensors='pt',
    )
    batch = {
        'input_ids':processed_data['input_ids'],
        # 'token_type_ids':processed_data['token_type_ids'],
        'attention_mask':processed_data['attention_mask'],
        'pixel_values':processed_data['pixel_values'],
    }

    return batch

In [None]:
# tokenlizing
tokenized_train_dataset = train_dataset.map(
        tokenize_function, 
        batched=True, 
        remove_columns=['image', 'split','caption','sentids', 'img_id', 'filename'], 
        # keep_in_memory = True,
        batch_size = 500,
    )
torch.save(tokenized_train_dataset, "tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/tokenized_train_batched_500.pt")

In [8]:
tokenized_eval_dataset = val_dataset.map(
        tokenize_function, 
        batched=True, 
        remove_columns=['image', 'split','caption','sentids', 'img_id', 'filename'],
        # keep_in_memory = True,
        batch_size = 500,
    )

torch.save(tokenized_eval_dataset, "tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/tokenized_eval_batched_500.pt")

Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

In [None]:
tokenized_test_dataset = test_dataset.map(
        tokenize_function, 
        batched=True, 
        remove_columns=['image', 'split','caption','sentids', 'img_id', 'filename'],
        # keep_in_memory = True,
        batch_size = 500,
    )
torch.save(tokenized_test_dataset, "tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/tokenized_test_batched_500.pt")

In [5]:
tokenized_train_dataset = torch.load("tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/tokenized_train_batched_500.pt")
tokenized_eval_dataset = torch.load("tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/tokenized_eval_batched_500.pt")
tokenized_test_dataset = torch.load("tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/tokenized_test_batched_500.pt")

In [11]:
tokenized_train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'pixel_values'],
    num_rows: 29000
})

In [None]:
# # tmp = []
# # tmp1 = []
# def collate_fn_epoch(examples):  # evaluation_strategy = "epoch"
#     pixel_values = torch.stack([torch.tensor(example["pixel_values"]) for example in examples], dim=0)
#     input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)
#     attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long)
#     return {
#         "pixel_values": pixel_values,
#         "input_ids": input_ids,
#         "attention_mask": attention_mask,
#         "return_loss": True,
#     }

In [6]:
def collate_fn_steps(examples):  # evaluation_strategy = "steps"
    pixel_values = torch.stack([torch.tensor(example["pixel_values"]) for example in examples], dim=0).squeeze(0)
    input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long).squeeze(1)
    attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long).squeeze(1)
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "return_loss": True,
    }

In [None]:
# #FREEZE
# for i,j in model.named_parameters():
#     if 'prompt' not in i:
#         j.requires_grad=False

In [None]:
earlystop = EarlyStoppingCallback(early_stopping_patience=10)

training_args = TrainingArguments(
    output_dir            = save_path,
    evaluation_strategy   = "steps",
    learning_rate         = 3e-6,
    weight_decay          = 1e-6,
    save_steps            = 50,
    eval_steps            = 50,
    num_train_epochs      = 20,
    save_strategy         = "steps",
    remove_unused_columns = False,
    warmup_steps          = 50,
    per_device_train_batch_size = 128,
    per_device_eval_batch_size  = 128,
    lr_scheduler_type     = "cosine",
    label_smoothing_factor= 0.05,
    # auto_find_batch_size  = True,
    metric_for_best_model = "eval_loss",  # for earlystop
    # metric_for_best_model = "eval_acc",  # for earlystop
    load_best_model_at_end = True,
    save_total_limit = 1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    # compute_metrics=compute_metrics,
    data_collator=collate_fn_steps,
    callbacks=[earlystop],
)

train_result = trainer.train()
results = zeroshot_evaluate(model=model,dataloader=test_dataset,processor=processor,max_seq_length=77,recall_k_list=[1,5,10],device=torch.device("cuda")) # type: ignore

ic(results)

In [8]:
trainer.save_model()

In [9]:
# For KeyboardInterrupt usage
test_dataset = torch.load("data/flickr30k/test_dataset.pt")
zeroshot_evaluate(model=model,dataloader=test_dataset,processor=processor,max_seq_length=77,recall_k_list=[1,5,10],device=torch.device("cuda")) 

0it [00:00, ?it/s]

1000it [01:03, 15.68it/s]


{'image_retrieval_recall@1': 0.7523999810218811,
 'text_retrieval_recall@1': 0.8970000147819519,
 'image_retrieval_recall@5': 0.9343999624252319,
 'text_retrieval_recall@5': 0.9880000352859497,
 'image_retrieval_recall@10': 0.9666000008583069,
 'text_retrieval_recall@10': 0.9930000305175781}

### Zero-shot Cross-Modal Retrieval on Flickr30k

In [None]:
from zeroshot_retrieval import zeroshot_evaluate
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from transformers import (
    CLIPProcessor, 
    CLIPModel,
    CLIPConfig,
)

test_dataset = torch.load("data/flickr30k/test_dataset.pt")
pretrained_model_path = "openai/clip-vit-base-patch32"
save_path = "saved_clip_model\\clip-vit-base-p32-finetuned_proj_768x768"

config = CLIPConfig().from_pretrained(save_path)

# if 512x512
# model = CLIPModel(config).from_pretrained(save_path, cache_dir="cache").cuda()

# if 768x768
model = CLIPModel(config).cuda()
model.visual_projection = torch.nn.Linear(in_features=768, out_features=768, bias=False).cuda()
model.text_projection = torch.nn.Linear(in_features=512, out_features=768, bias=False).cuda()
pretrained_dict = torch.load(save_path+'\\pytorch_model.bin', map_location=device)
model.load_state_dict(pretrained_dict)

processor = CLIPProcessor.from_pretrained(pretrained_model_path, cache_dir="cache")


results = zeroshot_evaluate(model=model,dataloader=test_dataset,processor=processor,max_seq_length=77,recall_k_list=[1,5,10],device=torch.device("cuda")) # type: ignore
print(results)

### Build Feature Pool

In [1]:
import torch
from transformers import (
    CLIPProcessor, 
    CLIPModel,
    AutoConfig,
)

from datasets import load_dataset

# pretrained_model_path = "openai/clip-vit-base-patch32"
# save_path = "output\\clip-finetuned_proj_768x768"
pretrained_model_path = "openai/clip-vit-base-patch16"
save_path = "saved_clip_model\\clip-vit-base-p16-finetuned_proj_512x512"

# dataset = load_dataset('nlphuji/flickr30k',cache_dir='cache',keep_in_memory=False,)['test'] 
# train_dataset = dataset.filter(lambda x: x["split"] == "train")
train_dataset = torch.load("data/flickr30k/train_dataset.pt")

processor = CLIPProcessor.from_pretrained(pretrained_model_path, cache_dir="cache")

config = AutoConfig.from_pretrained(save_path)
model = CLIPModel(config).from_pretrained(save_path, cache_dir="cache").cuda()

# model = CLIPModel(config).cuda()
# model.visual_projection = torch.nn.Linear(in_features=768, out_features=768, bias=False)
# model.text_projection = torch.nn.Linear(in_features=512, out_features=768, bias=False)
# pretrained_dict = torch.load(save_path+"/pytorch_model.bin")
# model_dict = model.state_dict()
# model_dict.update(pretrained_dict)
# model.load_state_dict(model_dict)


def tokenize_func(examples, max_seq_length=77):
    # tmp.append(examples)
    text = [" ".join('%s' %a for a in sentence) for sentence in examples['caption']]

    image = [image.convert("RGB") for image in examples['image']]
    processed_data = processor(
        images=image, 
        text=text,
        padding='max_length',
        max_length=max_seq_length,
        truncation=True,
        return_attention_mask=True, 
        return_tensors='pt',
    )
    return processed_data


In [None]:
# tokenlizing
tokenized_train_dataset = train_dataset.map(
        tokenize_func, 
        batched=True, 
        remove_columns=['image', 'split','caption','sentids', 'img_id', 'filename'], 
        # keep_in_memory = True,
        batch_size = 500,
    )

In [3]:
torch.save(tokenized_train_dataset, "tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/tokenized_train_dataset.pt")

In [4]:
# tokenized_train_dataset = torch.load("tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/tokenized_train_dataset.pt")
pixel_values = torch.stack([torch.tensor(example["pixel_values"]) for example in tokenized_train_dataset], dim=0).squeeze(0)

In [5]:
input_ids = torch.tensor([example["input_ids"] for example in tokenized_train_dataset], dtype=torch.long).squeeze(1)

In [6]:
attention_mask = torch.tensor([example["attention_mask"] for example in tokenized_train_dataset], dtype=torch.long).squeeze(1)

In [7]:
torch.save(pixel_values, "tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/pixel_values.pt")
torch.save(input_ids, "tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/input_ids.pt")
torch.save(attention_mask, "tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/attention_mask.pt")

In [None]:
pixel_values = torch.load("tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/pixel_values.pt")
input_ids = torch.load("tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/input_ids.pt")
attention_mask = torch.load("tokenized_dataset/flickr30k_maxseqlen_77_512x512_p16/attention_mask.pt")

In [None]:
pixel_values.shape

torch.Size([29000, 3, 224, 224])

In [5]:
input_ids.shape

torch.Size([29000, 77])

In [9]:
from tqdm import tqdm

In [10]:
batch_images_emb_list = []
batch_texts_emb_list = []
model.cuda()
model.eval()
with torch.no_grad():
    ret_dict = {}
    for pvalue, ids, atm in tqdm(zip(pixel_values, input_ids, attention_mask)):
        ret_dict = model(pixel_values=pvalue.unsqueeze(0).cuda(), 
                        input_ids=ids.unsqueeze(0).cuda(),
                        attention_mask=atm.unsqueeze(0).cuda(),
                        output_hidden_states=True,
                        )
        batch_images_emb_list.append(ret_dict['image_embeds'])    # [1,512] 
        batch_texts_emb_list.append(ret_dict['text_embeds'])      # [1,512] 

29000it [09:44, 49.63it/s]


In [47]:
len(batch_texts_emb_list)

29000

In [12]:
all_flickr30k_trainset_img_features = torch.stack(batch_images_emb_list, dim=0).squeeze(1)
all_flickr30k_trainset_text_features = torch.stack(batch_texts_emb_list, dim=0).squeeze(1)

In [17]:
all_flickr30k_trainset_img_features.shape

torch.Size([29000, 768])

In [7]:
# torch.save(all_flickr30k_testset_img_features, "data/flickr30k/all_flickr30k_testset_img_features.pt")
# torch.save(all_flickr30k_testset_text_features, "data/flickr30k/all_flickr30k_testset_text_features.pt")

In [54]:
# torch.save(all_flickr30k_trainset_img_features, "data/flickr30k/all_flickr30k_trainset_img_features.pt")
# torch.save(all_flickr30k_trainset_text_features, "data/flickr30k/all_flickr30k_trainset_text_features.pt")

In [13]:
torch.save(all_flickr30k_trainset_img_features, "data/flickr30k/flickr30k_trainset_img_features_512x512_p16.pt")
torch.save(all_flickr30k_trainset_text_features, "data/flickr30k/flickr30k_trainset_text_features_512x512_p16.pt")