<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw3/p2_Image_caption_large.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install timm

In [2]:
import os
import re
import math
import timm
import json
import torch
import collections
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as tr

from PIL import Image
from tqdm import tqdm
from pathlib import Path
from torch import nn, Tensor
from torch.utils.data import DataLoader

In [3]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using: {device}")

Using: cuda


In [4]:
# Mount Google Drive at /content/drive; the inference cell later writes its
# JSON output below /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#### Download dataset and unzip zip file.

In [None]:
!gdown 11rP6KmR5Qwjhx0rfag0b5TZGBTRuPtQR -O hw3_data.zip # 1SUiRrG6zQVtyrVSVh9hOBq5_fX-oV2Lh
!unzip /content/hw3_data.zip

#### Tokenizer (`<|endoftext|>` has id 50256; captions are padded to 250 tokens)

In [34]:
class BPETokenizer:
    """Byte-level BPE tokenizer built from an encoder JSON and a merges file.

    Attributes:
        encoder: token string -> integer id (loaded from ``encoder_file``).
        decoder: integer id -> token string (inverse of ``encoder``).
        bpe_ranks: symbol pair -> merge priority (lower merges first).
        byte_encoder: byte value (0-255) -> single printable character.
        byte_decoder: inverse of ``byte_encoder``.
    """

    def __init__(self, encoder_file, vocab_file):
        """Load the vocabulary and merge table and build the byte<->char maps.

        Args:
            encoder_file: path of the JSON file mapping token strings to ids.
            vocab_file: path of the text file with one merge pair per line.

        Raises:
            AssertionError: if the files do not contain the expected number
                of tokens (50257) or merge rules (49999).
        """
        with open(encoder_file, 'r', encoding='utf-8') as f:
            self.encoder = json.load(f)
        self.decoder = {idx: tok for tok, idx in self.encoder.items()}

        with open(vocab_file, 'r', encoding='utf-8') as f:
            # The first line and the last element of the split are not merge rules.
            merge_lines = f.read().split('\n')[1:-1]
        self.bpe_ranks = {tuple(line.split()): rank for rank, line in enumerate(merge_lines)}
        assert len(self.encoder) == 50257 and len(self.bpe_ranks) == 49999 # (was: len(self.bpe_ranks) == 50000)

        # Bytes in `printable` map to the character with the same code point;
        # every other byte is shifted to a code point starting at 256.
        printable = list(range(33, 127)) + list(range(161, 256))
        remaining = list(range(0, 33)) + list(range(127, 161))
        code_points = printable + [2**8 + offset for offset in range(len(remaining))]
        self.byte_encoder = {byte: chr(cp) for byte, cp in zip(printable + remaining, code_points)}
        self.byte_decoder = {char: byte for byte, char in self.byte_encoder.items()}

    def _bpe(self, token, allowed_special):
        """Split one pre-token into the BPE symbols that `encoder` knows about."""
        if token == '<|endoftext|>':
            assert allowed_special and token in allowed_special
            return [token]

        symbols = tuple(''.join(self.byte_encoder[byte] for byte in token.encode('utf-8')))
        while len(symbols) != 1:
            candidates = {(symbols[k], symbols[k + 1]) for k in range(len(symbols) - 1)}
            best = min(candidates, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if best not in self.bpe_ranks:
                # No adjacent pair is a known merge: the word is final.
                break
            first, second = best

            merged = []
            pos = 0
            while pos < len(symbols):
                try:
                    hit = symbols.index(first, pos)
                except ValueError:
                    # `first` does not occur any more; copy the tail verbatim.
                    merged.extend(symbols[pos:])
                    break
                merged.extend(symbols[pos:hit])
                if hit + 1 < len(symbols) and symbols[hit + 1] == second:
                    merged.append(first + second)
                    pos = hit + 2
                else:
                    merged.append(symbols[hit])
                    pos = hit + 1
            symbols = tuple(merged)
        return symbols

    def encode(self, text, allowed_special=None):
        """Convert ``text`` into a list of token ids.

        Args:
            text: the string to tokenize.
            allowed_special: container of special tokens that may appear in
                ``text``; ``'<|endoftext|>'`` is rejected unless listed here.

        Returns:
            list of integer ids.

        Raises:
            AssertionError: if ``text`` contains ``'<|endoftext|>'`` and it is
                not in ``allowed_special``.
        """
        pattern = (r"""<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d| ?""" +
                   r"""\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+""")
        ids = []
        for piece in re.findall(pattern, text, re.UNICODE):
            ids.extend(self.encoder[symbol] for symbol in self._bpe(piece, allowed_special))
        return ids

    def decode(self, tokens):
        """Turn a sequence of token ids back into text.

        Bytes that do not form valid UTF-8 are replaced rather than raising.
        """
        text = ''.join(self.decoder[token] for token in tokens)
        raw = bytearray(self.byte_decoder[char] for char in text)
        return raw.decode('utf-8', errors='replace')

In [35]:
encoding = BPETokenizer('/content/encoder.json', '/content/vocab.bpe')
prompt = 'a kitchen with a sink and many cooking machines and a pot of food'

text_embedding_len = 250

context = encoding.encode(prompt)
context = [50256] + context + [50256]*(text_embedding_len - len(context) - 1)
# context
encoding.decode(context)

'<|endoftext|>a kitchen with a sink and many cooking machines and a pot of food<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext

#### Define helper functions

In [36]:
def json_load(json_path: str):
    """Parse the UTF-8 encoded JSON file at ``json_path`` and return its content."""
    with open(json_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)

In [37]:
def caption_with_id(json_path: str) -> list:
    """Return one ``{'caption', 'image_id'}`` dict per entry of the file's 'annotations' list.

    Any other field of an annotation is dropped; the order of the file is kept.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        annotations = json.load(handle)['annotations']

    pairs = []
    for entry in annotations:
        pairs.append({'caption': entry['caption'], 'image_id': entry['image_id']})
    return pairs

In [38]:
def id2file_name(json_path: str) -> dict:
    """Build an ``image id -> file name`` lookup from the file's 'images' list."""
    with open(json_path, 'r', encoding='utf-8') as handle:
        images = json.load(handle)['images']
    return dict((entry['id'], entry['file_name']) for entry in images)

In [39]:
encoder_joson_path = '/content/encoder.json'
vocab_bpe_path = '/content/vocab.bpe'
def collate_fn(batch, tokenizer=BPETokenizer(encoder_joson_path, vocab_bpe_path),
               text_len=250, eos_id=50256, ignore_index=-100):
    """Collate dataset samples into fixed-length decoder inputs and targets.

    Args:
        batch: list of dicts with the keys 'img' (tensor), 'caption' (str)
            and 'filename' (str), as produced by ``ImgCaptionDataset``.
        tokenizer: object with an ``encode(str) -> list[int]`` method. The
            default instance is created once, when this function is defined.
        text_len: length every token sequence is padded / truncated to.
        eos_id: id used as start token, stop token and input padding.
        ignore_index: value used to pad the targets (-100 is the default
            ``ignore_index`` of ``nn.CrossEntropyLoss``).

    Returns:
        dict with
            'img': images stacked along a new first dimension,
            'tokenized_captions_train': (B, text_len) decoder input
                ``[eos] + caption + [eos] * pad``,
            'filename': list of file names,
            'tokenized_captions_inf': (B, text_len) targets
                ``caption + [eos] + [ignore_index] * pad`` (the input shifted
                left by one position).
    """
    images = [item['img'] for item in batch]
    filenames = [item['filename'] for item in batch]

    decoder_inputs, targets = [], []
    for item in batch:
        # Keep at most text_len - 1 caption tokens so that the start/stop
        # token still fits; without this an over-long caption produces a
        # longer row and the batch can no longer be stacked.
        ids = tokenizer.encode(item['caption'])[:text_len - 1]
        pad = text_len - len(ids) - 1
        decoder_inputs.append([eos_id] + ids + [eos_id] * pad)
        targets.append(ids + [eos_id] + [ignore_index] * pad)

    return {
        'img': torch.stack(images, dim=0),
        'tokenized_captions_train': torch.tensor(decoder_inputs),
        'filename': filenames,
        'tokenized_captions_inf': torch.tensor(targets),
    }

#### Build Dataset

In [40]:
class ImgCaptionDataset(torch.utils.data.Dataset):
    """Dataset yielding one (image, caption) pair per annotation.

    The dataset length is the number of captions in the annotation file;
    each caption is joined to its image file through the image id.
    """

    def __init__(self, img_dir, json_path, transform) -> None:
        """
        Args:
            img_dir: directory that contains the image files.
            json_path: annotation file with 'annotations' and 'images' lists.
            transform: callable applied to the RGB PIL image.
        """
        super().__init__()
        self.img_dir = img_dir
        self.transform = transform

        # Connect caption -> image_id -> file_name
        self.caption_with_id = caption_with_id(json_path)
        self.id2file_name = id2file_name(json_path)

    def __len__(self) -> int:
        return len(self.caption_with_id)

    def __getitem__(self, idx):
        """Return ``{'img', 'caption', 'filename'}`` for annotation ``idx``.

        'filename' is the image file name without its extension.
        """
        record = self.caption_with_id[idx]
        file_name = self.id2file_name[record['image_id']]
        # The context manager closes the underlying file as soon as the
        # pixels have been converted instead of leaving that to the GC.
        with Image.open(os.path.join(self.img_dir, file_name)) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        return {'img': img, 'caption': record['caption'], 'filename': os.path.splitext(file_name)[0]}

In [41]:
# class ImgDataset(torch.utils.data.Dataset):
#     def __init__(self, root: str, transform) -> None:
#         self.transform = transform
#         self.img_path = [i for i in Path(root).glob("*.jpg")]

#     def __len__(self) -> int:
#         return len(self.img_path)

#     def __getitem__(self, idx):
#         img = Image.open(self.img_path[idx]).convert('RGB')
#         img = self.transform(img)
#         return img, os.path.splitext(self.img_path[idx].name)[0]

#### Build Dataloader

In [42]:
# Both splits use the same deterministic preprocessing.
# NOTE(review): the normalisation constants are kept as-is because the trained
# weights were produced with them — confirm they match what the timm encoder expects.
image_transform = tr.Compose([
    tr.Resize(224),
    tr.CenterCrop(224),
    tr.ToTensor(),
    tr.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_ds = ImgCaptionDataset(
    img_dir='/content/hw3_data/p2_data/images/train',
    json_path='/content/hw3_data/p2_data/train.json',
    transform=image_transform,
)
# val.json is scored against images/val by the evaluation command at the end
# of the notebook, so the validation images are read from that directory.
val_ds = ImgCaptionDataset(
    img_dir='/content/hw3_data/p2_data/images/val',
    json_path='/content/hw3_data/p2_data/val.json',
    transform=image_transform,
)

train_loader = DataLoader(
    train_ds,
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
    num_workers=4,
)
# Validation must iterate over val_ds (not train_ds); order is irrelevant for
# evaluation, so shuffling is disabled.
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=4,
)

#### Config

In [43]:
class Config:
    """Hyper-parameters of the text decoder.

    Attributes:
        n_layer: number of transformer blocks.
        n_head: number of attention heads per block.
        n_embd: embedding width.
        vocab_size: number of token ids.
        block_size: maximum sequence length (size of the position table).
        checkpoint: optional path of a state dict to load, or ``None``.
    """

    def __init__(self, checkpoint=None):
        self.checkpoint = checkpoint
        # Architecture
        self.n_layer, self.n_head, self.n_embd = 12, 12, 768
        # Vocabulary and context window
        self.vocab_size, self.block_size = 50257, 1024

In [44]:
# Decoder hyper-parameters plus the path of the pretrained decoder weights
# that Decoder.__init__ loads into its `transformer` sub-modules.
cfg = Config(checkpoint='/content/hw3_data/p2_data/decoder_model.bin')

#### timm's ViT encoder

In [45]:
# encoder = timm.create_model('vit_large_patch16_224_in21k', pretrained=True) # vit_base_patch16_224_in21k

In [46]:
# for batch in train_loader:
#     img = batch['img']
#     break

In [47]:
# with torch.no_grad():
#     encoder_out = encoder.forward_features(img)

In [48]:
# encoder_out.size()

#### decoder

In [49]:
class Attention(nn.Module):
    """Causal multi-head self-attention.

    A lower-triangular ``bias`` buffer of size ``block_size`` masks out every
    position to the right of the query, so token ``t`` only sees tokens ``<= t``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.c_attn = nn.Linear(cfg.n_embd, 3 * cfg.n_embd)  # fused q/k/v projection
        self.c_proj = nn.Linear(cfg.n_embd, cfg.n_embd)      # output projection
        self.n_head = cfg.n_head
        self.n_embd = cfg.n_embd
        size = cfg.block_size
        causal_mask = torch.tril(torch.ones(size, size))
        self.register_buffer('bias', causal_mask.view(1, 1, size, size))

    def forward(self, x):
        """Attend over ``x`` of shape (batch, context, embedding); same shape out."""
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        scores = scores.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # (B, n_head, T, head_dim) -> (B, T, C)
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)

class CrossAttention(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` for decoder->encoder attention."""

    def __init__(self, cfg):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(cfg.n_embd, cfg.n_head, batch_first=True)

    def forward(self, query, encoder_out):
        """Attend from decoder states to encoder states.

        Args:
            query: decoder-side tensor, (N, L_target, E) with batch first,
                where L_target is the target sequence length.
            encoder_out: encoder-side tensor used as both keys and values,
                (N, L_source, E) with batch first.

        Returns:
            The attention output, shaped like ``query``; the attention
            weights are discarded.
        """
        attended, _ = self.multihead_attn(query=query, key=encoder_out, value=encoder_out)
        return attended

class Block(nn.Module):
    """One decoder layer: self-attention, cross-attention, then an MLP.

    Every sub-layer is applied pre-norm with a residual connection. ``ln_2``
    normalises both the decoder states and the encoder output before the
    cross-attention.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.ln_2 = nn.LayerNorm(width)  # added for the cross-attention
        self.ln_3 = nn.LayerNorm(width)
        self.attn = Attention(cfg)
        self.crs_attn = CrossAttention(cfg)  # added for the cross-attention

        mlp_layers = collections.OrderedDict()
        mlp_layers['c_fc'] = nn.Linear(width, 4 * width)
        mlp_layers['act'] = nn.GELU(approximate='tanh')
        mlp_layers['c_proj'] = nn.Linear(4 * width, width)
        self.mlp = nn.Sequential(mlp_layers)

    def forward(self, x, encoder_out) -> Tensor:
        """Run the layer on decoder states ``x`` given the image features ``encoder_out``."""
        x = x + self.attn(self.ln_1(x))

        queries = self.ln_2(x)
        memory = self.ln_2(encoder_out)
        x = x + self.crs_attn(queries, memory)

        return x + self.mlp(self.ln_3(x))

class Decoder(nn.Module):
    """Image-captioning model: a timm ViT encoder feeding a transformer text decoder.

    The token embedding and the output head share one weight matrix. When
    ``cfg.checkpoint`` is set, its state dict is loaded (non-strictly) into
    ``self.transformer``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.block_size = cfg.block_size
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(cfg.vocab_size, cfg.n_embd), # token embedding
            wpe = nn.Embedding(cfg.block_size, cfg.n_embd), # position embedding
            h = nn.Sequential(*[Block(cfg) for _ in range(cfg.n_layer)]), # stack of n_layer blocks
            ln_f = nn.LayerNorm(cfg.n_embd)
        ))
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # Weight tying: the embedding table and the output head are one parameter.
        self.transformer.wte.weight = self.lm_head.weight
        # timm's ViT encoder (alternatives tried: vit_base_patch16_224_in21k, vit_huge_patch14_224_in21k)
        self.encoder = timm.create_model('vit_large_patch16_224_in21k', pretrained=True)
        # Projects the 1024-wide encoder features to the decoder width
        # (original note: encoder output is [16, 197, 1024]).
        self.linear = nn.Linear(1024, cfg.n_embd)
        # load checkpoint
        if self.cfg.checkpoint is not None:
            # map_location='cpu' lets a checkpoint saved from a GPU be loaded on
            # any host; load_state_dict copies into the module's own tensors.
            state_dict = torch.load(self.cfg.checkpoint, map_location='cpu')
            # These weights are stored transposed relative to nn.Linear.
            transposed = ('.c_attn.weight', '.c_fc.weight', '.c_proj.weight')
            # Build a new dict instead of rewriting the one being iterated.
            state_dict = {
                key: (value.t() if key.endswith(transposed) else value)
                for key, value in state_dict.items()
            }
            self.transformer.load_state_dict(state_dict, strict=False)

    def forward(self, x: Tensor, img: Tensor) -> Tensor:
        """Compute next-token logits.

        Args:
            x: token ids, (batch, seq); truncated to ``block_size`` tokens.
            img: image batch passed to the encoder's ``forward_features``.

        Returns:
            Logits of shape (batch, seq, vocab_size).
        """
        x = torch.narrow(x, 1, 0, min(x.size(1), self.block_size))
        pos = torch.arange(x.size(1), dtype=torch.long, device=x.device).unsqueeze(0)
        x = self.transformer.wte(x) + self.transformer.wpe(pos)
        # No gradients are tracked through the image encoder.
        with torch.no_grad():
            encoder_out = self.encoder.forward_features(img)
        # The projection is the same for every layer, so compute it once.
        memory = self.linear(encoder_out)
        for block in self.transformer.h:
            x = block(x, memory)
        return self.lm_head(self.transformer.ln_f(x))

#### Decoding test

In [50]:
# decoder = Decoder(cfg)

In [51]:
# for batch in train_loader:
#     img = batch['img']
#     tokenized_captions_train = batch['tokenized_captions_train']
#     tokenized_captions_inf = batch['tokenized_captions_inf']
#     break

In [52]:
# pred = decoder(tokenized_captions_train, img)

In [53]:
# pred.size(), tokenized_captions_train.size(), img.size(), tokenized_captions_inf.size()

In [54]:
# encoding.decode(pred[0].argmax(dim=1).tolist())

In [55]:
# loss_fn = nn.CrossEntropyLoss() # ignore_index=50256

# pred = pred.reshape(-1, 50257)
# tokenized_captions_inf = tokenized_captions_inf.reshape(-1)
# loss_fn(pred, tokenized_captions_inf)

#### Training

In [56]:
# def training(dataloader, model, loss_fn, optimizer):

#     size = len(dataloader.dataset) # number of samples
#     num_batches = len(dataloader) # batches per epoch
#     epoch_loss = 0

#     model.train() # to training mode
#     for batch_i, data in enumerate(tqdm(dataloader)):
#         data['img'] = data['img'].to(device, non_blocking=True)
#         data['tokenized_captions_train'] = data['tokenized_captions_train'].to(device, non_blocking=True)
#         data['tokenized_captions_inf'] = data['tokenized_captions_inf'].to(device, non_blocking=True)

#         # zero the parameter gradients
#         optimizer.zero_grad()

#         # Compute prediction loss
#         pred = model(data['tokenized_captions_train'], data['img'])
#         # reshape to (B, C)
#         data['tokenized_captions_inf'] = data['tokenized_captions_inf'].reshape(-1)
#         pred = pred.reshape(-1, 50257)
#         loss = loss_fn(pred, data['tokenized_captions_inf']) # tokenized captions inf

#         # Optimization by gradients
#         loss.backward() # backpropagation to compute gradients
#         optimizer.step() # update model params

#         # write to logs
#         epoch_loss += loss.item() # tensor -> python value
#     return epoch_loss/num_batches

In [57]:
# def testing(dataloader, model, loss_fn):
#     size = len(dataloader.dataset) # number of samples
#     num_batches = len(dataloader) # batches per epoch

#     model.eval() # model to test mode.
#     epoch_loss, epoch_correct = 0, 0

#     # No gradient for test data
#     with torch.no_grad():
#         for batch_i, data in enumerate(tqdm(dataloader)):
#             data['img'] = data['img'].to(device, non_blocking=True)
#             data['tokenized_captions'] = data['tokenized_captions'].to(device, non_blocking=True)

#             # Compute prediction loss
#             pred = model(data['tokenized_captions'], data['img'])
#             # loss = loss_fn(pred, x)
#             epoch_batch_size = pred.size()[0]
#             loss = [loss_fn(pred[i], data['tokenized_captions'][i]) for i in range(epoch_batch_size)]
#             loss = sum(loss)

#             # write to logs
#             epoch_loss += loss.item()
#             # (B, 250, Class)
#             pred_correct = [(pred[i].argmax(dim=1) == data['tokenized_captions'][i]).sum().item() for i in range(epoch_batch_size)]
#             epoch_correct += sum(pred_correct)

#     return epoch_loss/num_batches, epoch_correct/size

##### Freeze parameters

In [58]:
# model = Decoder(cfg).to(device)

# # Freeze parameters
# for name, param in model.named_parameters():
#     param.requires_grad=False
#     # print(f"{name}: {param.requires_grad}")
# # Unfreeze some parameters
# for i in range(12):
#     model.transformer.h[i].ln_2.weight.requires_grad = True
#     model.transformer.h[i].ln_2.bias.requires_grad = True
#     model.transformer.h[i].crs_attn.multihead_attn.in_proj_weight.requires_grad = True
#     model.transformer.h[i].crs_attn.multihead_attn.in_proj_bias.requires_grad = True
#     model.transformer.h[i].crs_attn.multihead_attn.out_proj.weight.requires_grad = True
#     model.transformer.h[i].crs_attn.multihead_attn.out_proj.bias.requires_grad = True
# model.linear.weight.requires_grad = True
# model.linear.bias.requires_grad = True

# trainable_weights = [name for name, param in model.named_parameters() if param.requires_grad == True]
# # list for True
# # for name, param in model.named_parameters():
# #     print(f"{name}: {param.requires_grad}")

In [59]:
# EPOCHS = 7
# loss_fn = nn.CrossEntropyLoss() # ignore_index=50256
# optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# # logs
# logs = {
#     'train_loss': []
# }

# for epoch in tqdm(range(EPOCHS)):
#     train_loss = training(train_loader, model, loss_fn, optimizer)

#     print(f'EPOCH: {epoch:04d} \train_loss: {train_loss:.4f}')

#     logs['train_loss'].append(train_loss)

#     # Save model
#     save_weights = {k: v for k, v in model.state_dict().items() if k in trainable_weights}
#     torch.save(save_weights, f'/content/drive/MyDrive/NTU_DLCV/Hw3/p2_ckpt/trainable_weights_epoch{epoch}_{train_loss:.4f}.pth')
#     print('---------- Model Save ----------')

#### Check that the number of trainable parameters is below 35M

In [60]:
!gdown 1-obsrlcsth-FcJgr1SQ1QNJAq_yxBh9Z -O trainable_weights

Downloading...
From: https://drive.google.com/uc?id=1-obsrlcsth-FcJgr1SQ1QNJAq_yxBh9Z
To: /content/trainable_weights
100% 117M/117M [00:00<00:00, 191MB/s] 


In [61]:
model = Decoder(cfg)

# Freeze every parameter of the model ...
model.requires_grad_(False)

# ... then re-enable only the parts added for captioning: the cross-attention
# (with its LayerNorm) of every block and the encoder->decoder projection.
# Iterating over the blocks themselves keeps this in step with cfg.n_layer.
for block in model.transformer.h:
    block.ln_2.requires_grad_(True)
    block.crs_attn.requires_grad_(True)
model.linear.requires_grad_(True)

  model = create_fn(


model.safetensors:   0%|          | 0.00/1.30G [00:00<?, ?B/s]

In [62]:
# Overlay the fine-tuned weights; strict=False because the file only holds a
# subset of the model's parameters.
trainable_state = torch.load('/content/trainable_weights', map_location=device)
model.load_state_dict(trainable_state, strict=False)

# Only parameters with requires_grad=True are counted here.
n_trainable = sum(params.numel() for params in model.parameters() if params.requires_grad)
print('Total params: ', n_trainable)

Total params:  29154048


#### inference

In [63]:
!gdown 1-obsrlcsth-FcJgr1SQ1QNJAq_yxBh9Z -O trainable_weights # epoch6

Downloading...
From: https://drive.google.com/uc?id=1-obsrlcsth-FcJgr1SQ1QNJAq_yxBh9Z
To: /content/trainable_weights
100% 117M/117M [00:00<00:00, 241MB/s]


In [None]:
# Build the model on the target device and switch it to eval mode, since it is
# only used for inference from here on; then overlay the fine-tuned weights.
model = Decoder(cfg).to(device).eval()
model.load_state_dict(torch.load('/content/trainable_weights', map_location=device), strict=False)

In [None]:
EOS_ID = 50256          # '<|endoftext|>' serves as both start and stop token
MAX_NEW_TOKENS = 250    # upper bound on generated tokens per image

evaluation_dict = {}
model.eval()
with torch.no_grad():
    for data in tqdm(val_loader):
        file_name = data['filename'][0]
        # The dataset has one sample per caption, so an image may show up more
        # than once; greedy decoding would only reproduce the same caption.
        if file_name in evaluation_dict:
            continue

        img = data['img'].to(device)
        tokens = torch.tensor([[EOS_ID]], device=device)

        # Greedy decoding: append the most likely next token until the stop
        # token is produced or the length cap is reached.
        for _ in range(MAX_NEW_TOKENS):
            pred = model(tokens, img)
            next_token = pred[0, -1].argmax(dim=-1)
            if next_token.item() == EOS_ID:
                break
            tokens = torch.cat((tokens, next_token.view(1, 1)), dim=1)

        # Always decode what was generated (minus the start token), so a
        # sequence that never emits the stop token still gets its own caption
        # instead of a stale or undefined one.
        pred_caption = encoding.decode(tokens[0, 1:].tolist())

        evaluation_dict[file_name] = pred_caption
        print('\n', 'file name: ', file_name, 'caption: ', pred_caption)

In [None]:
# Serialise the predictions once (indent=2 only makes the file easier to read) ...
json_string = json.dumps(evaluation_dict, indent=2)

# ... and write the same text to the Drive copy and to the local working copy.
for out_path in (
    '/content/drive/MyDrive/NTU_DLCV/Hw3/p2_output/large_epoch6_output.json',
    'output.json',
):
    with open(out_path, 'w') as json_file:
        json_file.write(json_string)

#### CIDEr & CLIPScore

In [None]:
# !pip install git+https://github.com/openai/CLIP.git

In [None]:
# !python3 /content/p2_evaluate.py --pred_file /content/output.json --annotation_file /content/hw3_data/p2_data/val.json --images_root /content/hw3_data/p2_data/images/val