In [None]:
# PARSE_COCO
# This parse step extracts CLIP image features from the Open-I images and pickles them together with their captions.

import torch
import skimage.io as io
import clip
from PIL import Image
import pickle
import json
import os
from tqdm import tqdm
import argparse

def main(clip_mode_type: str):
  """Encode every Open-I training image with CLIP and pickle the embeddings with their captions.

  Writes ``{"clip_embedding": Tensor, "captions": [record, ...]}`` where each record
  is a Train.jsonl entry plus ``"clip_embedding"``: the row index of its image
  embedding in the tensor. Returns 0 on success.
  """
  # Use the GPU when present instead of failing outright on CPU-only machines.
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  clip_model_name = clip_mode_type.replace('/', '_')
  out_path = f"/kaggle/working/data/open-i/oscar_split_{clip_model_name}_train.pkl"
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
  # The parameter is `clip_mode_type`; the old body used an undefined `clip_model_type`.
  clip_model, preprocess = clip.load(clip_mode_type, device=device, jit=False)
  # Train.jsonl holds one JSON object per line, so json.load() on the whole file fails.
  with open('/kaggle/working/open-i-dataset/Open-I Dataset/Modified-Captions/Train.jsonl', 'r',
            encoding='utf-8-sig') as f:
    data = [json.loads(line) for line in f if line.strip()]
  print("%0d captions loaded from json " % len(data))
  all_embeddings = []
  all_captions = []

  def _save():
    # Re-dump everything collected so far (periodic checkpoint and final save).
    with open(out_path, 'wb') as f:
      pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)

  for i, d in enumerate(tqdm(data)):
    # Records are keyed "id"/"text"/"label"/"img" (see the Train.jsonl key dump), not "image_id".
    img_id = int(d["id"])
    filename = f"/kaggle/input/open-i-dataset/Open-I Dataset/images/{img_id}.jpg"
    image = io.imread(filename)
    image = preprocess(Image.fromarray(image)).unsqueeze(0).to(device)
    with torch.no_grad():
      prefix = clip_model.encode_image(image).cpu()
    d["clip_embedding"] = i
    all_embeddings.append(prefix)
    # Keep the whole record: the training dataset reads "id", "text" and "clip_embedding".
    all_captions.append(d)
    if (i + 1) % 10000 == 0:
      _save()
  _save()
  print('Done')
  print("%0d embeddings saved " % len(all_embeddings))
  return 0


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--clip_model_type', default="ViT-B/32", choices=('RN50', 'RN101', 'RN50x4', 'ViT-B/32'))
  # parse_known_args: inside a Jupyter kernel sys.argv carries the kernel's own
  # flags (e.g. `-f <connection file>`), which parse_args() would reject.
  args, _unknown = parser.parse_known_args()
  # Not wrapped in exit(): in an IPython kernel `exit` is the shell's exit helper
  # and asks the kernel to shut down rather than just returning a status.
  main(args.clip_model_type)

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from enum import Enum
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_scheduler_with_warmup
from tqdm import tqdm
import os
import pickle
import sys
import argparse
import json
from typing import Tuple, Optional, Union

class MappingType(Enum):
  """Kind of network that maps a CLIP embedding to GPT-2 prefix embeddings.

  The values are the strings accepted by the training script's --mapping_type option.
  """
  MLP = 'mlp'  # two-layer MLP projection
  Transformer = 'transformer'  # transformer mapper

class ClipCocoDataset(Dataset):
  """Draft caption dataset: loads CLIP embeddings plus captions and tokenizes the captions.

  Superseded by the ClipCocoDataset in train_yash.py, which adds padding and
  __getitem__; this draft only builds the attributes below.

  Attributes set by __init__:
    prefixes: the "clip_embedding" tensor from the pickle.
    image_ids / captions: the "id" / "text" field of every caption record.
    captions_tokens: list of int64 token tensors, one per caption.
    caption2embedding: for each caption, the row index of its image in ``prefixes``.
    max_seq_len: min(int(mean + 10 * std), int(max)) of the caption token counts.
  """

  def __len__(self) -> int:
    return len(self.captions_tokens)

  def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = "gpt2", normalize_prefix=False):
    self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
    self.prefix_length = prefix_length
    self.normalize_prefix = normalize_prefix

    # NOTE: pickle.load can execute arbitrary code; only load files you produced.
    with open(data_path, 'rb') as f:
      all_data = pickle.load(f)

    print("Data size is %0d" % len(all_data['clip_embedding']))
    sys.stdout.flush()
    self.prefixes = all_data['clip_embedding']
    # Caption records are stored under "captions" with keys "id", "text" and
    # "clip_embedding" (the draft read non-existent "text"/"image_id"/"caption" keys).
    captions_raw = all_data['captions']
    self.image_ids = [caption['id'] for caption in captions_raw]
    self.captions = [caption['text'] for caption in captions_raw]

    tokens_path = f"{data_path[:-4]}_tokens.pkl"
    if os.path.isfile(tokens_path):
      with open(tokens_path, "rb") as f:
        self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f)
    else:
      self.captions_tokens = []
      self.caption2embedding = []
      max_seq_len = 0
      for caption in captions_raw:
        self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['text']), dtype=torch.int64))
        self.caption2embedding.append(caption['clip_embedding'])
        max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0])
      with open(tokens_path, 'wb') as f:
        pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f)

    # Cap the padding length for cached and freshly built tokens alike.
    # std() of a single value is NaN (int() would fail), so use no spread then.
    all_len = torch.tensor([len(tokens) for tokens in self.captions_tokens]).float()
    spread = all_len.std() if len(all_len) > 1 else torch.tensor(0.)
    self.max_seq_len = min(int(all_len.mean() + spread * 10), int(all_len.max()))


SyntaxError: invalid syntax. Perhaps you forgot a comma? (<ipython-input-1-19c4e115feb9>, line 52)

# Training on Open-I

In [None]:
from google.colab import drive
# 'content' is a relative mount point, so Drive ends up under /content/content
# (hence the /content/content/MyDrive/... path used by the unzip cell).
drive.mount('content')

Mounted at content


In [None]:
# Extract the dataset zip from the mounted Drive into /content/open-i; this creates
# "Open-I Dataset/Captions/{Train,Valid,Test}.jsonl" and "Open-I Dataset/images/*.jpg".
!unzip '/content/content/MyDrive/Dataset/Open-I Dataset.zip' -d '/content/open-i'

Archive:  /content/content/MyDrive/Dataset/Open-I Dataset.zip
   creating: /content/open-i/Open-I Dataset/Captions/
  inflating: /content/open-i/Open-I Dataset/Captions/Test.jsonl  
  inflating: /content/open-i/Open-I Dataset/Captions/Train.jsonl  
  inflating: /content/open-i/Open-I Dataset/Captions/Valid.jsonl  
   creating: /content/open-i/Open-I Dataset/images/
  inflating: /content/open-i/Open-I Dataset/images/1.jpg  
  inflating: /content/open-i/Open-I Dataset/images/10.jpg  
  inflating: /content/open-i/Open-I Dataset/images/100.jpg  
  inflating: /content/open-i/Open-I Dataset/images/1001.jpg  
  inflating: /content/open-i/Open-I Dataset/images/1002.jpg  
  inflating: /content/open-i/Open-I Dataset/images/1003.jpg  
  inflating: /content/open-i/Open-I Dataset/images/1004.jpg  
  inflating: /content/open-i/Open-I Dataset/images/1005.jpg  
  inflating: /content/open-i/Open-I Dataset/images/1006.jpg  
  inflating: /content/open-i/Open-I Dataset/images/1008.jpg  
  inflating: /cont

### **Viewing The Keys of the Dataset**

In [None]:
import json

# Peek at the raw training captions: one JSON record per line
# (utf-8-sig transparently drops a leading BOM if the file has one).
train_file_path = "/content/open-i/Open-I Dataset/Captions/Train.jsonl"
with open(train_file_path, 'r', encoding='utf-8-sig') as file:
    data = file.read()

result = list(map(json.loads, data.splitlines()))
print(len(result))  # number of training records

# Inspect the first record and the image file it should map to.
i = 0
d = result[i]
img_id = d["id"]
print(f"Real URL: {d['img']}")
print(f"'{img_id}'")
filename = f"/content/open-i/Open-I Dataset/images/{int(img_id)}.jpg"
print(f"Filename: {filename}")
print(f"{result[i]}")

2483
Real URL: data/preprocessed/openi/train/3603.jpg
'3603'
Filename: /content/open-i/Open-I Dataset/images/3603.jpg
{'id': '3603', 'label': "'No Finding'", 'text': 'Cardiac and mediastinal contours are within normal limits. The lungs are clear. Bony structures are intact.', 'img': 'data/preprocessed/openi/train/3603.jpg'}


### **Modifying the Image Paths**

In [None]:
import json
import os

# Point the `img` field of every caption record at the extracted images and save
# the result as Modified-Captions/<split>.jsonl. All splits are processed here so a
# fresh "Run All" also produces Modified-Captions/Train.jsonl, which the parse
# script reads (previously only Test was rewritten by this cell).
captions_dir = "/content/open-i/Open-I Dataset/Captions"
# No trailing slash: the old value produced paths like ".../images//1420.jpg".
new_image_path = "/content/open-i/Open-I Dataset/images"
modified_output = "/content/open-i/Open-I Dataset/Modified-Captions"
os.makedirs(modified_output, exist_ok=True)

for split in ("Train", "Valid", "Test"):
    with open(f"{captions_dir}/{split}.jsonl", 'r', encoding='utf-8-sig') as file:
        result = [json.loads(jline) for jline in file.read().splitlines() if jline.strip()]

    modified_lines = []
    for entry in result:
        entry['img'] = f"{new_image_path}/{entry['id']}.jpg"
        modified_lines.append(json.dumps(entry))

    # Show a few rewritten records per split (slicing is safe for short files).
    print(f"{split}: {len(modified_lines)} records")
    for line in modified_lines[:5]:
        print(line)

    with open(f"{modified_output}/{split}.jsonl", 'w') as output_file:
        output_file.write('\n'.join(modified_lines))

{"id": "1420", "label": "''", "text": "No focal alveolar consolidation, no definite pleural effusion seen, left hilar calcifications and dense nodule in the left lung suggest a previous granulomatous process. Considering differences in technical factors XXXX stable cardiomediastinal silhouette with normal heart size, bronchovascular crowding without typical findings of pulmonary edema.", "img": "/content/open-i/Open-I Dataset/images//1420.jpg"}
{"id": "2944", "label": "'Enlarged Cardiomediastinum', 'Lung Lesion', 'Lung Opacity'", "text": "There is a large masslike opacity in the right lung base which may represent a lung cancer. Additional evaluation XXXX advised. Right hilum appears prominent and may contain some enlarged lymph XXXX. Some calcified granulomas are seen with within the right lung. Unremarkable mediastinal contour. No effusions.", "img": "/content/open-i/Open-I Dataset/images//2944.jpg"}
{"id": "3371", "label": "''", "text": "Heart size within normal limits. Negative for

### **Parsing the Open-I Dataset to Create Image–Caption Pairs**

In [None]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-vgnde48r
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-vgnde48r
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369497 sha256=f6bf09a7620af815a458153893bbfa05578a4de384767d8938da963e8298aced
  Stored in directory: /tmp/pip-ephem-wheel-cache-g2ognku1/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac13b112deb897d5b50f5ad9a37e4
Successfully built clip
Inst

In [None]:
%%writefile parse_openi_yash.py
import torch
import skimage.io as io
import clip
from PIL import Image
import pickle
import json
import os
from tqdm import tqdm
import argparse


def main(clip_model_type: str):
    """Encode every Open-I training image with CLIP and pickle embeddings plus caption records.

    Output: ``/content/open-i/oscar_split_<model>_train.pkl`` containing
    ``{"clip_embedding": Tensor, "captions": [record, ...]}``. Each record is the
    Modified-Captions/Train.jsonl entry plus ``"clip_embedding"``, the row index of
    its image embedding in the tensor. Returns 0 on success.
    """
    # Use the GPU when present; fall back to CPU instead of crashing.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    clip_model_name = clip_model_type.replace('/', '_')

    # Yash 07/12/23
    out_folder = "/content/open-i/"
    out_path = os.path.join(out_folder, f"oscar_split_{clip_model_name}_train.pkl")
    train_file_path = "/content/open-i/Open-I Dataset/Modified-Captions/Train.jsonl"
    os.makedirs(out_folder, exist_ok=True)
    print("\n\nMatter of output path resolved.\n\n")

    clip_model, preprocess = clip.load(clip_model_type, device=device, jit=False)
    print(f"{type(clip_model)} and {type(preprocess)}")
    with open(train_file_path, 'r', encoding='utf-8-sig') as file:
        data = file.read()
    # Skip blank lines so a stray empty line does not make json.loads fail.
    result = [json.loads(json_line) for json_line in data.splitlines() if json_line.strip()]
    print("\n%0d captions loaded from json \n" % len(result))

    all_embeddings = []
    all_captions = []

    def _save():
        # Dump everything collected so far (periodic checkpoint and final save).
        with open(out_path, 'wb') as f:
            pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)

    for i, d in enumerate(tqdm(result)):
        img_id = int(d["id"])
        filename = f"/content/open-i/Open-I Dataset/images/{img_id}.jpg"

        # Image preprocessing: CLIP's transform expects a PIL image.
        image = io.imread(filename)
        image = preprocess(Image.fromarray(image)).unsqueeze(0).to(device)

        # Image encoding
        with torch.no_grad():
            prefix = clip_model.encode_image(image).cpu()
        # The record remembers which embedding row belongs to it.
        d["clip_embedding"] = i
        all_embeddings.append(prefix)
        all_captions.append(d)

        # Periodic save every 10k images so a crash does not lose everything.
        if (i + 1) % 10000 == 0:
            _save()

    _save()
    print('Done')
    print("%0d embeddings saved " % len(all_embeddings))
    return 0


if __name__ == '__main__':
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--clip_model_type', default="ViT-B/32", choices=('RN50', 'RN101', 'RN50x4', 'ViT-B/32'))
    args = parser.parse_args()
    # sys.exit rather than the site-provided exit(), which is meant for interactive use.
    sys.exit(main(args.clip_model_type))

Writing parse_openi_yash.py


In [None]:
# Run the parse script in a separate Python process; with ViT-B/32 it writes
# /content/open-i/oscar_split_ViT-B_32_train.pkl.
!python parse_openi_yash.py --clip_model_type ViT-B/32



Matter of output path resolved.


100%|███████████████████████████████████████| 338M/338M [00:05<00:00, 66.9MiB/s]
<class 'clip.model.CLIP'> and <class 'torchvision.transforms.transforms.Compose'>

2483 captions loaded from json 

100% 2483/2483 [00:30<00:00, 80.62it/s]
Done
2483 embeddings saved 


### Checking the dimensions of the tensors.

In [None]:
# pickle was never imported in this kernel (the earlier import cells failed with
# SyntaxError), which caused this cell's NameError.
import pickle

# Load the embeddings pickle written by parse_openi_yash.py and check its size.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
data_path = '/content/open-i/oscar_split_ViT-B_32_train.pkl'
with open(data_path, 'rb') as f:
  all_data = pickle.load(f)
print("Data size is %0d" % len(all_data['clip_embedding']))
print(f"Clip Embedding size = {all_data['clip_embedding'].shape}")
# One embedding row per caption record: caption["clip_embedding"] indexes these rows.
assert all_data['clip_embedding'].shape[0] == len(all_data['captions'])


from transformers import GPT2Tokenizer
import sys

# For token dimension
def get_token_dim(gpt2_type : str = 'gpt2', prefix_length : int = 250, normalize_prefix = False,
                  data_path: str = '/content/open-i/oscar_split_ViT-B_32_train.pkl'):
    """Tokenize every caption in the CLIP-embedding pickle and summarise the caption lengths.

    Uses the same token cache as ClipCocoDataset in train_yash.py:
    ``<data_path without extension>_tokens.pkl`` is read if present, else written.

    Args:
        gpt2_type: Hugging Face GPT-2 tokenizer name.
        prefix_length, normalize_prefix: unused; kept for call compatibility.
        data_path: pickle written by parse_openi_yash.py (previously taken from the
            notebook-global ``data_path``).

    Returns:
        (all_len, max_seq_len): a float tensor with each caption's token count, and
        ``min(int(mean + 10 * std), int(max))`` of those counts.
    """
    # Imported locally: the earlier cells importing these failed, so they are not
    # guaranteed to exist in the kernel.
    import os
    import pickle
    import sys
    import torch

    tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)

    # NOTE: pickle.load can execute arbitrary code; only load files you produced.
    with open(data_path, 'rb') as f:
        all_data = pickle.load(f)
    print("Data size is %0d" % len(all_data['clip_embedding']))
    sys.stdout.flush()

    prefixes = all_data['clip_embedding']
    captions_raw = all_data['captions']
    print(f"\nPrefix Type: {type(prefixes)}\n")
    print(f"\nCaptions Type: {type(captions_raw)}\n")
    print(f"\nPrefix: {prefixes.shape}\n")

    tokens_path = f"{data_path[:-4]}_tokens.pkl"
    tokens_dir = os.path.dirname(tokens_path)
    if tokens_dir:  # os.makedirs('') raises for a bare filename
        os.makedirs(tokens_dir, exist_ok=True)
    if os.path.isfile(tokens_path):
        with open(tokens_path, 'rb') as f:
            captions_tokens, caption2embedding, max_seq_len = pickle.load(f)
    else:
        captions_tokens = []
        caption2embedding = []
        max_seq_len = 0
        for caption in captions_raw:
            captions_tokens.append(torch.tensor(tokenizer.encode(caption['text']), dtype=torch.int64))
            caption2embedding.append(caption["clip_embedding"])
            max_seq_len = max(max_seq_len, captions_tokens[-1].shape[0])
        with open(tokens_path, 'wb') as f:
            pickle.dump([captions_tokens, caption2embedding, max_seq_len], f)

    all_len = torch.tensor([len(tokens) for tokens in captions_tokens]).float()
    # std() of a single value is NaN, which int() cannot convert.
    spread = all_len.std() if len(all_len) > 1 else torch.tensor(0.)
    max_seq_len = min(int(all_len.mean() + spread * 10), int(all_len.max()))
    return all_len, max_seq_len

all_len, max_seq_len = get_token_dim()
# all_len holds one token count per caption; max_seq_len is the padding length the
# dataset uses, min(mean + 10*std, max), not necessarily the longest caption.
print(f"\nNumber of tokenized captions: {all_len.shape[0]}\n")
print(f"\nPadded caption length (min(mean + 10*std, max) tokens): {max_seq_len}\n")


NameError: name 'pickle' is not defined

### **Train Open-I Dataset**

In [None]:
!pip install GPUtil

Collecting GPUtil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: GPUtil
  Building wheel for GPUtil (setup.py) ... [?25l[?25hdone
  Created wheel for GPUtil: filename=GPUtil-1.4.0-py3-none-any.whl size=7393 sha256=f1dbf5829da81ac082a738e4da95c9a528da353c0f53195df32e2c183d764600
  Stored in directory: /root/.cache/pip/wheels/a9/8a/bd/81082387151853ab8b6b3ef33426e98f5cbfebc3c397a9d4d0
Successfully built GPUtil
Installing collected packages: GPUtil
Successfully installed GPUtil-1.4.0


In [None]:
%%writefile train_yash.py
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from enum import Enum
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import os
import pickle
import sys
import argparse
import json
from typing import Tuple, Optional, Union

# Yashfinul Haque 06/12/23
import logging

from GPUtil import showUtilization as gpu_usage
from numba import cuda

def free_gpu_cache():
    """Print GPU utilisation, release cached CUDA memory in this process, and print it again.

    NOTE(review): main() calls this at the start of a freshly launched training
    process, so there is little to free here; memory held by the notebook kernel or
    other processes is not released by anything done in this process.
    """
    print("Initial GPU Usage")
    gpu_usage()

    # Releases unused blocks held by PyTorch's caching allocator (this process only).
    torch.cuda.empty_cache()

    # NOTE(review): numba's cuda.close() tears down the CUDA context(s) numba manages
    # in this process; if PyTorch had already initialised CUDA on device 0 this could
    # invalidate its context — confirm before calling this after any torch CUDA work.
    cuda.select_device(0)
    cuda.close()
    cuda.select_device(0)

    print("GPU Usage after emptying the cache")
    gpu_usage()


class MappingType(Enum):
    """Kind of network mapping a CLIP embedding to GPT-2 prefix embeddings.

    Values are the strings accepted by --mapping_type ('mlp' / 'transformer').
    """
    MLP = 'mlp'  # two-layer MLP projection (class MLP)
    Transformer = 'transformer'  # TransformerMapper


class ClipCocoDataset(Dataset):
    """Dataset of ``(tokens, mask, prefix)`` triples for ClipCap training.

    ``data_path`` is the pickle written by parse_openi_yash.py:
    ``{"clip_embedding": Tensor, "captions": [record, ...]}``, where each record has
    ``"id"``, ``"text"`` and ``"clip_embedding"`` (a row index into the tensor).
    Tokenized captions are cached in ``<data_path without extension>_tokens.pkl``.
    """

    def __len__(self) -> int:
        return len(self.captions_tokens)

    def pad_tokens(self, item: int):
        """Return ``(tokens, mask)`` for caption ``item`` padded/truncated to ``max_seq_len``.

        ``tokens`` has padding positions set to 0. ``mask`` is a float tensor of length
        ``prefix_length + max_seq_len``: ones for the prefix and real tokens, zeros
        for padding.

        The cached tensors in ``self.captions_tokens`` are never modified. The old
        version stored the padded tensor back and then zeroed its padding in place,
        so from the second access on ``tokens.ge(0)`` was true everywhere and the
        padding was no longer masked.
        """
        tokens = self.captions_tokens[item][:self.max_seq_len]
        padding = self.max_seq_len - tokens.shape[0]  # >= 0 after the slice
        # torch.cat allocates a new tensor, so the in-place zeroing below is local.
        tokens = torch.cat((tokens, torch.full((padding,), -1, dtype=torch.int64)))
        mask = tokens.ge(0)  # False where we are out of the sequence
        tokens[~mask] = 0
        mask = torch.cat((torch.ones(self.prefix_length), mask.float()), dim=0)  # prefix always attended

        # Lazy %-formatting: tensors are only rendered when DEBUG logging is on.
        logging.debug("This is a debug message in 'pad_tokens'- function: %s,%s", tokens, mask)
        return tokens, mask

    def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
        """Return padded tokens, attention mask and the CLIP embedding for caption ``item``."""
        tokens, mask = self.pad_tokens(item)
        prefix = self.prefixes[self.caption2embedding[item]]
        if self.normalize_prefix:
            prefix = prefix.float()
            prefix = prefix / prefix.norm(2, -1)

        logging.debug("This is a debug message in 'get_item'- function: %s,%s, %s", tokens, mask, prefix)
        return tokens, mask, prefix

    def __init__(self, data_path: str,  prefix_length: int, gpt2_type: str = "gpt2",
                 normalize_prefix=False):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix
        print(f"\nData Path: {data_path}\n")

        # NOTE: pickle.load can execute arbitrary code; only load files produced by the parse step.
        with open(data_path, 'rb') as f:
            all_data = pickle.load(f)

        print("Data size is %0d" % len(all_data['clip_embedding']))
        sys.stdout.flush()

        self.prefixes = all_data['clip_embedding']
        captions_raw = all_data['captions']
        self.image_ids = [caption["id"] for caption in captions_raw]
        self.captions = [caption['text'] for caption in captions_raw]
        print(f"\nPrefix Type: {type(self.prefixes)}\n")
        print(f"\nCaptions Type: {type(captions_raw)}\n")
        print(f"\nPrefix: {self.prefixes}\n")

        tokens_path = f"{data_path[:-4]}_tokens.pkl"
        tokens_dir = os.path.dirname(tokens_path)
        if tokens_dir:  # os.makedirs('') raises for a bare filename
            os.makedirs(tokens_dir, exist_ok=True)
        if os.path.isfile(tokens_path):
            with open(tokens_path, 'rb') as f:
                self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f)
        else:
            self.captions_tokens = []
            self.caption2embedding = []
            max_seq_len = 0
            for caption in captions_raw:
                self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['text']), dtype=torch.int64))
                self.caption2embedding.append(caption["clip_embedding"])
                max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0])
            with open(tokens_path, 'wb') as f:
                pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f)

        # Padding length: mean + 10 std, capped at the longest caption. std() of a
        # single value is NaN (int() would fail), so use no spread in that case.
        all_len = torch.tensor([len(tokens) for tokens in self.captions_tokens]).float()
        spread = all_len.std() if len(all_len) > 1 else torch.tensor(0.)
        self.max_seq_len = min(int(all_len.mean() + spread * 10), int(all_len.max()))




class MLP(nn.Module):
    """Stack of Linear layers with an ``act`` module between consecutive layers.

    ``sizes`` gives the width at each stage: ``(a, b, c)`` builds
    Linear(a→b) → act() → Linear(b→c), with no activation after the last layer.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        modules = []
        for position, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if position > 0:
                # The activation sits between layers: before every Linear except the first.
                modules.append(act())
            modules.append(nn.Linear(fan_in, fan_out, bias=bias))
        self.model = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class MlpTransformer(nn.Module):
    """Feed-forward sub-block of TransformerLayer: fc1 → act → dropout → fc2 → dropout.

    ``out_d`` defaults to ``in_dim`` so the block can be used on a residual path.
    """

    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        if out_d is None:
            out_d = in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention of queries from ``x`` over keys/values from ``y``.

    Args:
        dim_self: feature size of ``x`` and of the output; split across ``num_heads``.
        dim_ref: feature size of ``y`` (the key/value source).
        num_heads: number of attention heads.
        bias: whether the query and key/value projections have a bias.
        dropout: dropout probability applied to the attention weights before they
            mix the values (previously accepted but never applied).
    """

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        """Attend from ``x`` (b, n, dim_self) to ``y`` (b, m, dim_ref); ``y`` defaults to ``x``.

        ``mask`` is boolean, True where a key position must be ignored: either (b, m)
        or a shape that broadcasts against (b, n, m) once a trailing head axis is added.

        Returns:
            ``(out, attention)``: out is (b, n, dim_self); attention is
            (b, n, m, num_heads), softmax-normalised over m, taken before dropout.
        """
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, _ = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', self.dropout(attention), values).reshape(b, n, c)
        out = self.project(out)
        return out, attention


class TransformerLayer(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x), y)`` followed by ``x + mlp(norm2(x))``."""

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward_with_attention(self, x, y=None, mask=None):
        """Like ``forward`` but also returns the attention weights of ``self.attn``."""
        attended, attention = self.attn(self.norm1(x), y, mask)
        residual = x + attended
        return residual + self.mlp(self.norm2(residual)), attention

    def forward(self, x, y=None, mask=None):
        # Same computation as forward_with_attention, discarding the attention map.
        output, _ = self.forward_with_attention(x, y, mask)
        return output


class Transformer(nn.Module):
    """Stack of TransformerLayer blocks.

    With ``enc_dec=False`` every layer attends from ``x`` to ``y`` (the layer falls
    back to ``x`` when ``y`` is None). With ``enc_dec=True`` there are
    ``2 * num_layers`` layers alternating unmasked cross-attention to ``y`` (even
    indices) and masked self-attention over ``x`` (odd indices).
    """

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super().__init__()
        if dim_ref is None:
            dim_ref = dim_self
        self.enc_dec = enc_dec
        depth = 2 * num_layers if enc_dec else num_layers

        def key_dim(index: int) -> int:
            # Odd layers of an encoder-decoder stack attend over x itself.
            return dim_self if enc_dec and index % 2 == 1 else dim_ref

        self.layers = nn.ModuleList(
            TransformerLayer(dim_self, key_dim(index), num_heads, mlp_ratio, act=act, norm_layer=norm_layer)
            for index in range(depth)
        )

    def forward_with_attention(self, x, y=None, mask=None):
        """Run every layer as ``layer(x, y, mask)`` and collect each layer's attention."""
        attentions = []
        for layer in self.layers:
            x, attention = layer.forward_with_attention(x, y, mask)
            attentions.append(attention)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for index, layer in enumerate(self.layers):
            if not self.enc_dec:
                x = layer(x, y, mask)  # self or cross
            elif index % 2 == 0:
                x = layer(x, y)  # cross
            else:
                x = layer(x, x, mask)  # self
        return x


class TransformerMapper(nn.Module):
    """Map a CLIP embedding to ``prefix_length`` GPT-2 input embeddings with a transformer.

    The CLIP vector is projected to ``clip_length`` tokens and concatenated with
    ``prefix_length`` learned constant tokens. After the transformer, only the
    outputs at the learned-constant positions are returned.
    """

    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super().__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        # assumes x is (batch, dim_clip) — TODO confirm for callers other than ClipCaptionModel
        batch_size = x.shape[0]
        clip_tokens = self.linear(x).view(batch_size, self.clip_length, -1)
        learned = self.prefix_const.unsqueeze(0).expand(batch_size, -1, -1)
        sequence = torch.cat((clip_tokens, learned), dim=1)
        return self.transformer(sequence)[:, self.clip_length:]


class ClipCaptionModel(nn.Module):
    """GPT-2 captioner conditioned on a CLIP embedding mapped to ``prefix_length`` input embeddings.

    Args:
        prefix_length: number of GPT-2 embedding slots produced from one CLIP vector.
        clip_length: number of tokens the CLIP vector is projected to inside the
            transformer mapper; defaults to ``prefix_length`` when None (previously
            None made the transformer mapper fail to build).
        prefix_size: size of the CLIP embedding (the training script uses 640 with
            --is_rn, else 512).
        num_layers: depth of the transformer mapper (unused by the MLP mapper).
        mapping_type: MappingType.MLP or MappingType.Transformer.
        gpt2_type: Hugging Face name/path of the GPT-2 checkpoint.
    """

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a (batch_size, prefix_length) int64 tensor of zeros."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        """Run GPT-2 on ``[mapped prefix ; embedded tokens]``.

        Args:
            tokens: (batch, seq) caption token ids.
            prefix: CLIP embeddings, one per caption.
            mask: attention mask over prefix + caption positions (nonzero = attend).
            labels: if not None, GPT-2 also computes its LM loss. Only its presence
                matters: labels are rebuilt from ``tokens`` with the prefix positions
                (and, when ``mask`` is given, masked-out caption positions) set to
                -100, the index Hugging Face's loss ignores.

        Returns:
            The GPT2LMHeadModel output.
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            # -100 instead of the old dummy token 0, which trained GPT-2 to emit
            # token 0 from the prefix positions.
            prefix_labels = torch.full((tokens.shape[0], self.prefix_length), -100,
                                       dtype=torch.int64, device=tokens.device)
            text_labels = tokens
            if mask is not None:
                text_labels = tokens.masked_fill(mask[:, self.prefix_length:] == 0, -100)
            labels = torch.cat((prefix_labels, text_labels), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8, mapping_type: MappingType = MappingType.MLP, gpt2_type: str = 'gpt2'):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained(gpt2_type)
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if mapping_type == MappingType.MLP:
            # prefix_size -> (emb * prefix_length) / 2 -> emb * prefix_length
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                     self.gpt_embedding_size * prefix_length))
        else:
            if clip_length is None:
                clip_length = prefix_length
            self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
                                                  clip_length, num_layers)


class ClipCaptionPrefix(ClipCaptionModel):
    """ClipCaptionModel that exposes only the mapping network for training (--only_prefix).

    ``parameters()`` yields only ``clip_project``'s parameters, so an optimizer built
    from it never updates GPT-2. ``train()`` keeps GPT-2 in eval mode.
    """

    def parameters(self, recurse: bool = True):
        # `recurse` is accepted for signature compatibility but not forwarded.
        projection = self.clip_project
        return projection.parameters()

    def train(self, mode: bool = True):
        super().train(mode)
        # GPT-2 stays in eval mode (e.g. dropout disabled) while the prefix trains.
        self.gpt.eval()
        return self


def save_config(args: argparse.Namespace):
    """Write every attribute of ``args`` as JSON to ``<args.out_dir>/<args.prefix>.json``.

    Enum values (main() converts ``args.mapping_type`` to a MappingType) are stored
    by their ``.value``; any other value json cannot encode is stored as ``str``.
    The output directory is created if it does not exist.
    """
    config = dict(args._get_kwargs())
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)
    out_path = os.path.join(args.out_dir, f"{args.prefix}.json")

    def _encode(value):
        # json.dump calls this only for objects it cannot serialize natively.
        return value.value if isinstance(value, Enum) else str(value)

    with open(out_path, 'w') as outfile:
        json.dump(config, outfile, default=_encode)


def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
    """Rebuild a model from a JSON config written by save_config and load its checkpoint.

    Args:
        config_path: path of the JSON config.
        epoch_or_latest: an int epoch (loads ``<prefix>-<epoch:03d>.pt``) or a suffix
            string (default ``'_latest'`` loads ``<prefix>_latest.pt``) in ``out_dir``.

    Returns:
        ``(model, parser)``. The model keeps its fresh weights (with a message) if the
        checkpoint file is missing.
    """
    with open(config_path) as f:
        config = json.load(f)
    parser = argparse.ArgumentParser()
    parser.set_defaults(**config)
    # Parse an empty argv: the config supplies every value, and the caller's own
    # command-line arguments are unrelated and would be rejected.
    args = parser.parse_args([])
    if type(epoch_or_latest) is int:
        epoch_or_latest = f"-{epoch_or_latest:03d}"
    model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt")
    # Rebuild the architecture the checkpoint was trained with; configs lacking
    # these keys fall back to the previous hard-coded defaults.
    mapping_type = getattr(args, 'mapping_type', 'mlp')
    if not isinstance(mapping_type, MappingType):
        mapping_type = MappingType(mapping_type)
    model_kwargs = dict(
        clip_length=getattr(args, 'prefix_length_clip', None),
        prefix_size=640 if getattr(args, 'is_rn', False) else 512,
        num_layers=getattr(args, 'num_layers', 8),
        mapping_type=mapping_type,
    )
    if args.only_prefix:
        model = ClipCaptionPrefix(args.prefix_length, **model_kwargs)
    else:
        model = ClipCaptionModel(args.prefix_length, **model_kwargs)
    if os.path.isfile(model_path):
        print(f"loading model from {model_path}")
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    else:
        print(f"{model_path} is not exist")
    return model, parser


def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
          lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = ""):
    """Train ``model`` on ``dataset`` with AdamW and a linear warmup/decay schedule.

    Args:
        dataset: yields ``(tokens, mask, prefix)`` per caption.
        model: ClipCaptionModel, or ClipCaptionPrefix (whose parameters() limits the
            optimizer to the mapping network).
        args: namespace providing ``bs``, ``epochs`` and ``save_every``.
        lr, warmup_steps: learning rate and number of scheduler warmup steps.
        output_dir, output_prefix: checkpoints go to
            ``<output_dir>/<output_prefix>-<epoch:03d>.pt`` (every ``save_every``
            epochs and after the last one) and ``<output_prefix>_latest.pt``
            (every 10000 steps within an epoch).

    Returns:
        The trained model, moved to the training device.
    """
    # Use the first GPU when present; fall back to CPU instead of failing.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    batch_size = args.bs
    epochs = args.epochs
    os.makedirs(output_dir, exist_ok=True)
    model = model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    total_steps = epochs * len(train_dataloader)
    # NOTE(review): if warmup_steps >= total_steps the LR never leaves warmup and peaks
    # well below `lr` (e.g. 2483 Open-I captions, bs=40, 10 epochs -> 620 steps vs the
    # default warmup of 5000) — consider scaling warmup to the dataset size.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    # save_config(args)
    for epoch in range(epochs):
        print(f">>> Training epoch {epoch}")
        sys.stdout.flush()
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
            outputs = model(tokens, prefix, mask)
            # Position t predicts token t+1: the last prefix position predicts the first
            # caption token, and the final position has no target.
            logits = outputs.logits[:, dataset.prefix_length - 1: -1]
            # ignore_index=0 skips the padding pad_tokens fills with id 0 (it also skips
            # genuine occurrences of GPT-2 token id 0).
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
            # A single zero_grad per step: model.zero_grad() and optimizer.zero_grad()
            # both cleared the same parameters (model.parameters()).
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            progress.set_postfix({"loss": loss.item()})
            progress.update()
            if (idx + 1) % 10000 == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}_latest.pt"),
                )
        progress.close()
        if epoch % args.save_every == 0 or epoch == epochs - 1:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
            )
    return model


def main():
    """Parse command-line options, build the dataset and caption model, then train.

    Only the CLIP projection is trained when ``--only_prefix`` is set; otherwise
    the projection and GPT-2 are trained together.
    """
    free_gpu_cache()
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default="/content/open-i/oscar_split_ViT-B_32_train.pkl")
    parser.add_argument('--out_dir', default='/content/open-i/checkpoints')
    parser.add_argument('--prefix', default='coco_prefix', help='prefix for saved filenames')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--save_every', type=int, default=1)
    parser.add_argument('--prefix_length', type=int, default=10)
    parser.add_argument('--prefix_length_clip', type=int, default=10)
    parser.add_argument('--bs', type=int, default=40)
    parser.add_argument('--only_prefix', dest='only_prefix', action='store_true')
    # ``choices`` makes argparse reject unknown values with a usage message instead
    # of the dict lookup below failing with a bare KeyError.
    parser.add_argument('--mapping_type', type=str, default='mlp', choices=('mlp', 'transformer'),
                        help='mlp/transformer')
    parser.add_argument('--num_layers', type=int, default=8)
    parser.add_argument('--is_rn', dest='is_rn', action='store_true')
    parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true')
    args = parser.parse_args()

    prefix_length = args.prefix_length
    dataset = ClipCocoDataset(args.data, prefix_length, normalize_prefix=args.normalize_prefix)
    # NOTE(review): 640 is the RN50x4 CLIP embedding width and 512 the ViT-B/32 one;
    # RN50/RN101 features are 1024-d — confirm which ResNet variant --is_rn targets.
    prefix_dim = 640 if args.is_rn else 512
    args.mapping_type = {'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}[args.mapping_type]
    if args.only_prefix:
        model = ClipCaptionPrefix(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
                                  num_layers=args.num_layers, mapping_type=args.mapping_type)
        print("Train only prefix")
    else:
        model = ClipCaptionModel(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
                                  num_layers=args.num_layers, mapping_type=args.mapping_type)
        print("Train both prefix and GPT")
    # Flush in both branches so the mode message appears before training output.
    sys.stdout.flush()
    train(dataset, model, args, output_dir=args.out_dir, output_prefix=args.prefix)


if __name__ == '__main__':
    main()

Writing train_yash.py


In [None]:
!python train_yash.py --out_dir "/content/openi_train/" --prefix_length 250 --prefix_length_clip 250 --mapping_type "transformer" --num_layers 8 --bs 4

Initial GPU Usage
| ID | GPU | MEM |
------------------
|  0 |  0% |  0% |
GPU Usage after emptying the cache
| ID | GPU | MEM |
------------------
|  0 |  3% |  1% |
vocab.json: 100% 1.04M/1.04M [00:00<00:00, 16.5MB/s]
merges.txt: 100% 456k/456k [00:00<00:00, 2.50MB/s]
tokenizer.json: 100% 1.36M/1.36M [00:00<00:00, 22.8MB/s]
config.json: 100% 665/665 [00:00<00:00, 3.63MB/s]

Data Path: /content/open-i/oscar_split_ViT-B_32_train.pkl

Data size is 2483

Prefix Type: <class 'torch.Tensor'>


Captions Type: <class 'list'>


Prefix: tensor([[ 0.1871, -0.0982, -0.2462,  ...,  0.3254, -0.2881,  0.3137],
        [ 0.1575, -0.0314, -0.3894,  ...,  0.2993, -0.3894,  0.1566],
        [ 0.1083, -0.0636, -0.2676,  ...,  0.3933, -0.3679,  0.3047],
        ...,
        [ 0.2118, -0.0068, -0.3044,  ...,  0.3889, -0.5000,  0.2283],
        [ 0.1927, -0.0267, -0.0917,  ...,  0.5625, -0.3472,  0.2290],
        [ 0.0176, -0.0382, -0.1676,  ...,  0.2073, -0.0815, -0.0098]],
       dtype=torch.float16)

mo

# Prediction of ClipCap

In [None]:
# Prediction interface for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/python.md


import clip
import os
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
import sys
from typing import Tuple, List, Union, Optional
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AdamW,
    get_linear_schedule_with_warmup,
)
import skimage.io as io
import PIL.Image

import cog

# import torch

# Terse aliases used in the type hints of this cell.
N = type(None)

# numpy-flavoured aliases
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]

# torch-flavoured aliases
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]

# Checkpoint file for every selectable model name; the key order is also the
# order of the "model" options offered by the Cog predictor.
WEIGHTS_PATHS = {
    "coco": "coco_weights.pt",
    "conceptual-captions": "conceptual_weights.pt",
}

D = torch.device
CPU = D("cpu")


class Predictor(cog.Predictor):
    """Cog predictor that captions an image with one of the checkpoints in WEIGHTS_PATHS."""

    def setup(self):
        """Load CLIP, the GPT-2 tokenizer and every caption checkpoint once, up front."""
        self.device = torch.device("cuda")
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=self.device, jit=False)
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        self.models = {}
        self.prefix_length = 10
        for name, checkpoint in WEIGHTS_PATHS.items():
            captioner = ClipCaptionModel(self.prefix_length)
            # Load on CPU first, then move the whole model to the target device.
            captioner.load_state_dict(torch.load(checkpoint, map_location=CPU))
            self.models[name] = captioner.eval().to(self.device)

    @cog.input("image", type=cog.Path, help="Input image")
    @cog.input(
        "model",
        type=str,
        options=WEIGHTS_PATHS.keys(),
        default="coco",
        help="Model to use",
    )
    @cog.input(
        "use_beam_search",
        type=bool,
        default=False,
        help="Whether to apply beam search to generate the output text",
    )
    def predict(self, image, model, use_beam_search):
        """Caption ``image`` with the checkpoint named ``model``.

        Returns the ``generate2`` caption, or the top-ranked ``generate_beam``
        caption when ``use_beam_search`` is set.
        """
        pixels = io.imread(image)
        captioner = self.models[model]
        clip_input = self.preprocess(PIL.Image.fromarray(pixels)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            features = self.clip_model.encode_image(clip_input)
            features = features.to(self.device, dtype=torch.float32)
            prefix_embed = captioner.clip_project(features).reshape(1, self.prefix_length, -1)
        if not use_beam_search:
            return generate2(captioner, self.tokenizer, embed=prefix_embed)
        return generate_beam(captioner, self.tokenizer, embed=prefix_embed)[0]


class MLP(nn.Module):
    """Chain of ``nn.Linear`` layers with an activation between consecutive layers.

    ``sizes`` lists the feature widths: ``(a, b, c)`` builds
    ``Linear(a, b) -> act() -> Linear(b, c)``. No activation follows the last layer.
    """

    def forward(self, x: T) -> T:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        n_linear = len(sizes) - 1
        stack = []
        for position, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            stack.append(nn.Linear(d_in, d_out, bias=bias))
            is_last = position + 1 >= n_linear
            if not is_last:
                stack.append(act())
        # Stored as ``self.model`` so state_dict keys ("model.0.weight", ...) match
        # existing checkpoints.
        self.model = nn.Sequential(*stack)


class ClipCaptionModel(nn.Module):
    """GPT-2 language model conditioned on a CLIP image embedding.

    ``clip_project`` maps a ``prefix_size``-wide CLIP vector to ``prefix_length``
    GPT-2 input embeddings, which are placed in front of the caption's token
    embeddings before GPT-2 runs.
    """

    # Caching this with functools.lru_cache was left as a FIXME; not applied.
    def get_dummy_token(self, batch_size: int, device: D) -> T:
        """Return an int64 zero tensor of shape (batch_size, prefix_length) on ``device``."""
        shape = (batch_size, self.prefix_length)
        return torch.zeros(shape, dtype=torch.int64, device=device)

    def forward(
        self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None
    ):
        """Run GPT-2 on ``[projected prefix ; embedded tokens]``.

        Args:
            tokens: caption token ids, presumably (batch, seq_len) — TODO confirm.
            prefix: CLIP embeddings passed through ``clip_project``.
            mask: optional attention mask forwarded to GPT-2 (presumably spanning
                prefix + token positions — confirm).
            labels: acts only as a switch; when given, targets are rebuilt as
                ``[zeros(prefix_length) ; tokens]``.

        Returns:
            The GPT-2 output object.
        """
        token_embeds = self.gpt.transformer.wte(tokens)
        prefix_embeds = self.clip_project(prefix).view(
            -1, self.prefix_length, self.gpt_embedding_size
        )
        inputs_embeds = torch.cat((prefix_embeds, token_embeds), dim=1)
        if labels is not None:
            # NOTE(review): Hugging Face GPT-2 ignores only label -100, so these zero
            # ids at the prefix positions are trained as real targets — confirm intended.
            filler = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((filler, tokens), dim=1)
        return self.gpt(inputs_embeds=inputs_embeds, labels=labels, attention_mask=mask)

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super().__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained("gpt2")
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        projected_size = self.gpt_embedding_size * prefix_length
        # Long prefixes get a single linear layer ("not enough memory" for the MLP);
        # short ones use a two-layer MLP with a half-width hidden layer.
        self.clip_project = (
            nn.Linear(prefix_size, projected_size)
            if prefix_length > 10
            else MLP((prefix_size, projected_size // 2, projected_size))
        )


class ClipCaptionPrefix(ClipCaptionModel):
    """ClipCaptionModel variant in which only the CLIP projection is trainable.

    ``parameters()`` yields only ``clip_project``'s parameters, so an optimizer
    built from it never updates GPT-2. ``train()`` always leaves GPT-2 in eval mode.
    """

    def parameters(self, recurse: bool = True):
        # ``recurse`` is accepted for signature compatibility and ignored.
        projection = self.clip_project
        return projection.parameters()

    def train(self, mode: bool = True):
        super().train(mode)
        self.gpt.eval()
        return self


def generate_beam(
    model,
    tokenizer,
    beam_size: int = 5,
    prompt=None,
    embed=None,
    entry_length=67,
    temperature=1.0,
    stop_token: str = ".",
):
    """Beam-search decode ``beam_size`` captions from a prefix embedding or a text prompt.

    Args:
        model: object exposing ``.gpt`` (called with ``inputs_embeds=`` and returning
            ``.logits``), ``.gpt.transformer.wte`` and ``.parameters()``; it is put in
            eval mode.
        tokenizer: object with ``encode(str) -> list[int]`` and ``decode(ids) -> str``.
        beam_size: beams kept at every step (and number of returned texts).
        prompt: text prompt, encoded and embedded when ``embed`` is None. The prompt
            text is included at the start of every returned caption.
        embed: input embeddings; take priority over ``prompt``. The first dimension
            is expanded to ``beam_size`` on the first step, so it presumably has batch
            size 1, e.g. (1, prefix_length, emb_dim) — TODO confirm.
        entry_length: maximum number of tokens to generate.
        temperature: logits divisor; values <= 0 are treated as 1.0.
        stop_token: a beam stops once it emits the first token id of this string.

    Returns:
        ``beam_size`` decoded strings, ordered from highest to lowest average
        per-token log-probability of the generated part.
    """
    model.eval()
    # Only the first token id of ``stop_token`` is used as the stop signal.
    stop_token_index = tokenizer.encode(stop_token)[0]
    tokens = None  # token ids per beam (prompt ids first, if any)
    scores = None  # cumulative log-probability per beam; None before the first step
    prompt_length = 0  # number of prompt ids at the start of each row of ``tokens``
    device = next(model.parameters()).device
    # Tokens generated per beam; no longer incremented once a beam has stopped.
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    with torch.no_grad():
        if embed is not None:
            generated = embed
        else:
            tokens = torch.tensor(tokenizer.encode(prompt))
            tokens = tokens.unsqueeze(0).to(device)
            prompt_length = tokens.shape[1]
            generated = model.gpt.transformer.wte(tokens)
        for _ in range(entry_length):
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits
            # Next-token log-probabilities at the last position of every beam.
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()
            if scores is None:
                # First step: seed each beam with one of the top-``beam_size`` tokens
                # and replicate the single input sequence once per beam.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                # A stopped beam can only be extended by token 0 at zero cost, so its
                # score is unchanged; the filler id is trimmed when decoding.
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                # Rank every (beam, token) continuation by its length-normalized score.
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                    beam_size, -1
                )
                # Split the flat index into (source beam, token id).
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                # Turn the averages back into cumulative sums for the next step.
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]
            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
                generated.shape[0], 1, -1
            )
            generated = torch.cat((generated, next_token_embed), dim=1)
            # Bool addition acts as logical OR: mark beams that just emitted the stop token.
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break
    # NOTE(review): ``scores`` is still None here when ``entry_length`` is 0.
    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    # ``seq_lengths`` counts generated tokens only, so keep the prompt ids in front
    # of them; filler ids appended after a beam stopped fall past the cut.
    output_texts = [
        tokenizer.decode(output[: prompt_length + int(length)])
        for output, length in zip(output_list, seq_lengths)
    ]
    order = scores.argsort(descending=True)
    return [output_texts[i] for i in order]


def generate2(
    model,
    tokenizer,
    tokens=None,
    prompt=None,
    embed=None,
    entry_count=1,
    entry_length=67,  # maximum number of generated tokens
    top_p=0.8,
    temperature=1.0,
    stop_token: str = ".",
):
    """Decode a caption token by token from a prefix embedding or a text prompt.

    Nucleus (top-p) filtering is applied to the logits, but the next token is then
    picked with ``argmax``. The single most likely token is never filtered out, so
    decoding is effectively greedy.

    Args:
        model: object exposing ``.gpt`` (called with ``inputs_embeds=`` and returning
            ``.logits``), ``.gpt.transformer.wte`` and ``.parameters()``; it is put in
            eval mode.
        tokenizer: object with ``encode(str) -> list[int]`` and ``decode(list[int]) -> str``.
        tokens: optional prompt token ids, shape (1, n); embedded when ``embed`` is None
            and always included at the start of the output.
        prompt: text prompt, encoded when both ``embed`` and ``tokens`` are None.
        embed: input embeddings (presumably (1, prefix_length, emb_dim)); take priority.
        entry_count: number of captions decoded; every entry restarts from the same
            inputs and only the first caption is returned.
        entry_length: maximum number of tokens to generate per entry.
        top_p: nucleus threshold (see the note above).
        temperature: logits divisor; values <= 0 are treated as 1.0.
        stop_token: generation stops after the first token id of this string is produced.

    Returns:
        The decoded text of the first entry.
    """
    model.eval()
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device
    prompt_tokens = tokens

    with torch.no_grad():
        for _ in range(entry_count):
            # Restart every entry from the caller's prompt so entries do not
            # accumulate each other's tokens.
            entry_tokens = prompt_tokens
            if embed is not None:
                generated = embed
            else:
                if entry_tokens is None:
                    entry_tokens = torch.tensor(tokenizer.encode(prompt))
                    entry_tokens = entry_tokens.unsqueeze(0).to(device)
                generated = model.gpt.transformer.wte(entry_tokens)

            for _ in range(entry_length):
                logits = model.gpt(inputs_embeds=generated).logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    nnf.softmax(sorted_logits, dim=-1), dim=-1
                )
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the token crossing the threshold is kept, and always
                # keep the most likely token.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if entry_tokens is None:
                    entry_tokens = next_token
                else:
                    entry_tokens = torch.cat((entry_tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            # reshape(-1) instead of squeeze(): a single generated token must still
            # decode as a one-element list rather than a 0-d array.
            output_list = [] if entry_tokens is None else entry_tokens.reshape(-1).tolist()
            generated_list.append(tokenizer.decode(output_list))

    return generated_list[0]

In [None]:


def compute_ranks(args, results, labels, idx_list):
  """Rank the ground-truth text of every image (image-to-text retrieval).

  Scores are grouped into consecutive blocks of ``args.eval_len_size`` candidate
  texts per image. Within each block, candidates are sorted by descending
  similarity, and the position of the first candidate labelled 1 is the rank.

  Args:
    args: namespace providing ``eval_len_size`` (candidate texts per image).
    results: per-pair similarity scores; only the first ``len(idx_list)`` are used.
    labels: per-pair 0/1 labels (1 = matching text), in the same order as ``results``.
    idx_list: per-pair sample ids, in the same order as ``results``.

  Returns:
    ``(i2t_ranks, aligned_list)``:
    - ``i2t_ranks`` has one 0-based rank per image, or ``eval_len_size`` when no
      candidate is labelled 1.
    - ``aligned_list`` has ``[id, rank]`` per image, where ``id`` is the matched
      candidate's id, or the lowest-scoring candidate's id when nothing matched.
  """
  num_txt_per_img = args.eval_len_size
  labels = np.reshape(np.array(labels), [-1, num_txt_per_img])
  similarities = np.reshape(
      np.array([results[i] for i in range(len(idx_list))]), [-1, num_txt_per_img])
  idx_list = np.reshape(idx_list, [-1, num_txt_per_img])

  # Only image-to-text ranks are computed here (no text-to-image direction).
  i2t_ranks, aligned_list = [], []
  for lab, sim, idx in zip(labels, similarities, idx_list):
    inds = np.argsort(sim)[::-1]  # candidate positions, most similar first
    rank = num_txt_per_img
    ind = inds[-1]
    for r, candidate in enumerate(inds):
      if lab[candidate] == 1:
        rank, ind = r, candidate
        break
    aligned_list.append([idx[ind], rank])
    i2t_ranks.append(rank)
  print(f'length of i2t ranks:{len(i2t_ranks)}')
  return i2t_ranks, aligned_list
"""
Here the raw similarity scores are post-processed into retrieval metrics.
"""
def compute_recall_precision(args, results, labels, idx_list):
  """Image-to-text recall@k and precision@k for k in (1, 5, 10).

  Scores are grouped into blocks of ``args.eval_len_size`` candidates per image.
  For each image and each k:
  - recall = (positives in the top k) / (all positives);
  - precision = (positives in the top k) / k.
  Both are then averaged over images.

  Args:
    args: namespace providing ``eval_len_size`` (candidate texts per image).
    results: per-pair similarity scores; only the first ``len(idx_list)`` are used.
    labels: per-pair 0/1 labels, in the same order as ``results``.
    idx_list: per-pair sample ids (only its length is used here).

  Returns:
    ``{'i2t_recall': {...}, 'i2t_precision': {...}}``, each mapping
    "R@1"/"R@5"/"R@10" to a value rounded to 3 decimals. The precision dict also
    uses "R@k" keys; this is kept for compatibility with existing consumers.
  """
  num_text_per_img = args.eval_len_size
  labels = np.reshape(np.array(labels), [-1, num_text_per_img])
  similarities = np.reshape(
      np.array([results[i] for i in range(len(idx_list))]), [-1, num_text_per_img])

  # Each image's candidate labels, reordered by descending similarity (computed once).
  sorted_labels = [lab[np.argsort(sim)[::-1]] for lab, sim in zip(labels, similarities)]

  ranks = [1, 5, 10]
  recall, precision = [], []
  for k in ranks:
    r_list, p_list = [], []
    for sorted_label in sorted_labels:
      top = sorted_label[:k].sum()
      # NOTE(review): an image with no positive candidate gives 0/0 = nan here,
      # which propagates into the mean — confirm every image has a positive.
      bottom = sorted_label.sum()
      r_list.append(top / bottom)
      p_list.append(top / k)
    recall.append(np.mean(np.array(r_list)))
    precision.append(np.mean(np.array(p_list)))

  results = {
      'i2t_recall': {"R@1": round(recall[0], 3), "R@5": round(recall[1], 3), "R@10": round(recall[2], 3)},
      'i2t_precision': {"R@1": round(precision[0], 3), "R@5": round(precision[1], 3), "R@10": round(precision[2], 3)}
      }
  return results


def compute_mrr(ranks):
  """Mean reciprocal rank of 0-based ``ranks`` (each rank r contributes 1 / (r + 1)).

  Intermediate values and the final score are printed for inspection.
  """
  one_based = np.array(ranks, dtype=float) + 1
  print('ranks + 1:', one_based)
  reciprocal = np.reciprocal(one_based)
  mrr_score = np.mean(reciprocal)
  print(f'reciprocal_ranks:{reciprocal}')
  print(f'mrr_score: {mrr_score}')
  return mrr_score

def evaluate(args, test_results, test_labels, idx_list):
  """Compute all image-to-text retrieval metrics for one evaluation run.

  Returns:
    ``(eval_result, aligned_list, mrr_score, recall_precision_results)``:
    - ``eval_result["i2t_retrieval"]["R@k"]`` is the fraction of images whose
      ground-truth text ranks within the top k (k = 1, 5, 10);
    - ``aligned_list`` and ``mrr_score`` come from ``compute_ranks`` / ``compute_mrr``;
    - ``recall_precision_results`` comes from ``compute_recall_precision``.
  """
  i2t_ranks, aligned_list = compute_ranks(args, test_results, test_labels, idx_list)
  recall_precision_results = compute_recall_precision(args, test_results, test_labels, idx_list)

  i2t_accs = [sum(r < k for r in i2t_ranks) / len(i2t_ranks) for k in (1, 5, 10)]
  eval_result = {"i2t_retrieval": {"R@1": i2t_accs[0], "R@5": i2t_accs[1], "R@10": i2t_accs[2]}}
  mrr_score = compute_mrr(i2t_ranks)

  return eval_result, aligned_list, mrr_score, recall_precision_results


In [None]:
def test(args, model, eval_dataset):
  """Run ``model`` over ``eval_dataset`` without gradients and collect its predictions.

  Each batch is indexed as
  ``(cls_tok, input_txt, attn_mask, input_img, segment, sep_tok, label, idx)``.
  The model is called as
  ``model(cls_tok, input_txt, attn_mask, segment, input_img, sep_tok)`` and must
  return (batch, num_classes) logits with at least two classes.

  Returns:
    ``(results_lst, labels, eval_losses, idx_lst)``:
    - ``results_lst``: the class-1 softmax probability of every sample, as CPU 0-d tensors;
    - ``labels``: the labels;
    - ``eval_losses``: the per-batch cross-entropy losses;
    - ``idx_lst``: the sample ids.
  """
  model.eval()
  labels = []
  results_lst = []
  idx_lst = []
  eval_losses = []
  softmax = nn.Softmax(dim=1)
  criterion = nn.CrossEntropyLoss()
  eval_data_iter = tqdm(eval_dataset,
                        total=len(eval_dataset),
                        bar_format='{l_bar}{r_bar}')
  with torch.no_grad():
    for batch in eval_data_iter:
      cls_tok = batch[0].to(args.device)
      input_txt = batch[1].to(args.device)
      attn_mask = batch[2].to(args.device)
      input_img = batch[3].to(args.device)
      segment = batch[4].to(args.device)
      sep_tok = batch[5].to(args.device)
      label = batch[6].tolist()
      sample_ids = batch[7].tolist()
      logits = model(cls_tok, input_txt, attn_mask, segment, input_img, sep_tok)

      labels.extend(label)
      idx_lst.extend(sample_ids)
      eval_loss = criterion(logits, torch.tensor(label).to(args.device))
      eval_losses.append(eval_loss.item())
      # Probability of the positive class (index 1) for every sample in the batch.
      positive_probs = softmax(logits)[:, 1]
      results_lst.extend(p.to(torch.device("cpu")) for p in positive_probs)

  return results_lst, labels, eval_losses, idx_lst