<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw3/p2_Image_caption.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install timm

In [2]:
import os
import re
import timm
import json
import torch
import numpy as np
import torchvision.transforms as tr

from PIL import Image
from pathlib import Path
from torch.utils.data import DataLoader

#### Download dataset and unzip zip file.

In [None]:
!gdown 11rP6KmR5Qwjhx0rfag0b5TZGBTRuPtQR -O hw3_data.zip
!unzip /content/hw3_data.zip

#### Tokenizer: captions start with `<|endoftext|>` (id 50256) and are padded with it to 250 tokens

In [4]:
class BPETokenizer:
    """Byte-level BPE tokenizer backed by an encoder JSON and a merges file.

    ``encoder`` maps symbol strings to integer ids and ``decoder`` is its
    inverse.  ``bpe_ranks`` maps a pair of symbols to its merge priority
    (lower merges first).  ``byte_encoder`` / ``byte_decoder`` translate
    between raw byte values and the single characters used as BPE symbols.
    """

    def __init__(self, encoder_file, vocab_file):
        """Load the id table from *encoder_file* and the merges from *vocab_file*.

        The first and the last line of the merges file are discarded.
        Raises ``AssertionError`` when the table sizes are not the expected
        50257 ids / 49999 merges.
        """
        with open(encoder_file, 'r', encoding='utf-8') as handle:
            self.encoder = json.load(handle)
        self.decoder = {}
        for symbol, index in self.encoder.items():
            self.decoder[index] = symbol

        with open(vocab_file, 'r', encoding='utf-8') as handle:
            merge_lines = handle.read().split('\n')[1:-1]
        self.bpe_ranks = {}
        for rank, line in enumerate(merge_lines):
            self.bpe_ranks[tuple(line.split())] = rank

        # Author's note kept from the original: an alternative check was
        # len(self.bpe_ranks) == 50000.
        assert len(self.encoder) == 50257 and len(self.bpe_ranks) == 49999

        # Bytes in `direct` map to the character with the same code point;
        # the remaining bytes are assigned code points from 256 upwards.
        direct = list(range(33, 127)) + list(range(161, 256))
        remapped = list(range(0, 33)) + list(range(127, 161))
        code_points = direct + [2**8 + offset for offset in range(len(remapped))]
        self.byte_encoder = {
            byte: chr(code) for byte, code in zip(direct + remapped, code_points)
        }
        self.byte_decoder = {char: byte for byte, char in self.byte_encoder.items()}

    def _merge(self, symbols):
        """Repeatedly apply the best-ranked merge to *symbols* until none applies."""
        while len(symbols) != 1:
            candidates = set(zip(symbols, symbols[1:]))
            best = min(candidates, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if best not in self.bpe_ranks:
                break
            left, right = best
            merged = []
            pos = 0
            last = len(symbols) - 1
            # Single left-to-right scan: fuse every non-overlapping (left, right).
            while pos <= last:
                if pos < last and symbols[pos] == left and symbols[pos + 1] == right:
                    merged.append(left + right)
                    pos += 2
                else:
                    merged.append(symbols[pos])
                    pos += 1
            symbols = tuple(merged)
        return symbols

    def _translate(self, token, allowed_special):
        """Turn one pre-tokenised chunk into its sequence of BPE symbols."""
        if token == '<|endoftext|>':
            # The special token is only accepted when explicitly allowed.
            assert allowed_special and token in allowed_special
            return [token]
        symbols = tuple(self.byte_encoder[byte] for byte in token.encode('utf-8'))
        return self._merge(symbols)

    def encode(self, text, allowed_special=None):
        """Return the list of token ids for *text*.

        ``'<|endoftext|>'`` in *text* raises ``AssertionError`` unless it is
        contained in *allowed_special*.  A symbol missing from the encoder
        table raises ``KeyError``.
        """
        pattern = (r"""<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d| ?""" +
                   r"""\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+""")
        ids = []
        for chunk in re.findall(pattern, text, re.UNICODE):
            for symbol in self._translate(chunk, allowed_special):
                ids.append(self.encoder[symbol])
        return ids

    def decode(self, tokens):
        """Return the text for a sequence of token ids.

        Byte sequences that are not valid UTF-8 are substituted with the
        replacement character rather than raising.
        """
        symbols = ''.join(self.decoder[token] for token in tokens)
        raw = bytearray(self.byte_decoder[char] for char in symbols)
        return raw.decode('utf-8', errors='replace')

In [139]:
# Quick sanity check of the tokenizer on one hand-written caption.
encoding = BPETokenizer('/content/encoder.json', '/content/vocab.bpe')
prompt = 'a kitchen with a sink and many cooking machines and a pot of food'

# Target length, in tokens, of every padded caption.
text_embedding_len = 250

context = encoding.encode(prompt)
# Prepend id 50256 and fill the tail with the same id up to text_embedding_len.
# NOTE(review): a prompt longer than text_embedding_len - 1 tokens is not
# truncated here (the pad count just becomes non-positive).
context = [50256] + context + [50256]*(text_embedding_len - len(context) - 1)
# encoding.decode(context)

list

#### Define helper functions

In [44]:
def json_load(json_path: str):
    """Parse the UTF-8 encoded JSON file at *json_path* and return its content."""
    with open(json_path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)

In [45]:
def caption_with_id(json_path: str) -> list:
    """Return one ``{'caption', 'image_id'}`` dict per entry of ``annotations``.

    The JSON file at *json_path* must contain an ``annotations`` list whose
    items carry ``caption`` and ``image_id`` keys; other keys are dropped.
    The order of the annotations is preserved.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        payload = json.load(handle)
    pairs = []
    for entry in payload['annotations']:
        pairs.append({'caption': entry['caption'], 'image_id': entry['image_id']})
    return pairs

In [46]:
def id2file_name(json_path: str) -> dict:
    """Map every image ``id`` to its ``file_name`` from the ``images`` list.

    The JSON file at *json_path* must contain an ``images`` list whose items
    carry ``id`` and ``file_name`` keys.  When an id occurs more than once,
    the last occurrence wins.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        payload = json.load(handle)
    lookup = {}
    for image in payload['images']:
        lookup[image['id']] = image['file_name']
    return lookup

In [57]:
encoder_joson_path = '/content/encoder.json'
vocab_bpe_path = '/content/vocab.bpe'
def collate_fn(batch, tokenizer=BPETokenizer(encoder_joson_path, vocab_bpe_path),
               text_len=250, pad_token=50256):
    """Collate dataset items into one batch with fixed-length token ids.

    Args:
        batch: list of dicts with the keys ``'img'`` (tensor), ``'caption'``
            (str) and ``'filename'`` (str).
        tokenizer: object whose ``encode(str)`` returns a list of token ids.
            The default is built once, when this function is defined, so the
            vocabulary files are not re-read for every batch.
        text_len: length of every token sequence that is returned.
        pad_token: id used both as the leading token and as padding.

    Returns:
        dict with ``'img'`` (images stacked along a new first axis),
        ``'tokenized_captions'`` (list of 1-D integer tensors, each of length
        ``text_len``) and ``'filename'`` (list of str).

    Captions with more than ``text_len - 1`` tokens are truncated so that every
    sequence has exactly ``text_len`` entries; previously such captions
    produced over-long sequences.  A truncated caption has no trailing
    ``pad_token``.
    """
    images = [item['img'] for item in batch]
    captions = [item['caption'] for item in batch]
    filenames = [item['filename'] for item in batch]

    # One slot is reserved for the leading pad_token.
    max_caption_tokens = text_len - 1

    tokenized_captions = []
    for caption in captions:
        ids = tokenizer.encode(caption)[:max_caption_tokens]
        padded = [pad_token] + ids + [pad_token] * (max_caption_tokens - len(ids))
        tokenized_captions.append(torch.tensor(padded))

    return {
        'img': torch.stack(images, dim=0),
        'tokenized_captions': tokenized_captions,
        'filename': filenames,
    }

#### Build Dataset

In [58]:
class ImgCaptionDataset(torch.utils.data.Dataset):
    """Dataset with one item per caption annotation.

    Each item is a dict holding the transformed RGB image (``'img'``), the
    caption text (``'caption'``) and the image file name without its
    extension (``'filename'``).
    """

    def __init__(self, img_dir, json_path, transform) -> None:
        """Index the annotations in *json_path*; images are read from *img_dir*."""
        super().__init__()
        self.img_dir, self.transform = img_dir, transform

        # caption -> image_id -> file_name
        self.caption_with_id = caption_with_id(json_path)
        self.id2file_name = id2file_name(json_path)

    def __len__(self) -> int:
        """Number of caption annotations."""
        return len(self.caption_with_id)

    def __getitem__(self, idx):
        """Load and transform the image that belongs to annotation *idx*."""
        record = self.caption_with_id[idx]
        file_name = self.id2file_name[record['image_id']]
        image_path = os.path.join(self.img_dir, file_name)
        img = self.transform(Image.open(image_path).convert('RGB'))
        stem, _extension = os.path.splitext(file_name)
        return {'img': img, 'caption': record['caption'], 'filename': stem}

In [59]:
class ImgDataset(torch.utils.data.Dataset):
    """Caption-free dataset over every ``*.jpg`` file directly inside a folder.

    Items are ``(transformed_image, file_stem)`` tuples.  The files are kept
    in sorted order so that an index refers to the same image on every run
    and on every file system (plain ``Path.glob`` order is unspecified).
    """

    def __init__(self, root: str, transform) -> None:
        """Collect the ``*.jpg`` paths under *root*; *transform* is applied per image."""
        super().__init__()
        self.transform = transform
        self.img_path = sorted(Path(root).glob("*.jpg"))

    def __len__(self) -> int:
        """Number of images found."""
        return len(self.img_path)

    def __getitem__(self, idx):
        """Return the transformed RGB image and its file name without extension."""
        path = self.img_path[idx]
        img = Image.open(path).convert('RGB')
        img = self.transform(img)
        return img, os.path.splitext(path.name)[0]

#### Build Dataloader

In [60]:
# Training set: one item per caption annotation in train.json.
train_ds = ImgCaptionDataset(
    img_dir='/content/hw3_data/p2_data/images/train',
    json_path='/content/hw3_data/p2_data/train.json',
    transform=tr.Compose([
        tr.Resize(224),
        tr.CenterCrop(224),
        tr.ToTensor(),
        # NOTE(review): confirm these mean/std values match the preprocessing
        # configuration of the pretrained timm encoder used below.
        tr.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

# collate_fn tokenizes and pads the captions while the batch is assembled.
trainloader = DataLoader(
    train_ds,
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
    num_workers=1,
)

#### timm's ViT encoder

In [100]:
# Image encoder: a ViT created through timm with pretrained weights downloaded.
encoder = timm.create_model('vit_base_patch16_224_in21k', pretrained=True)

  model = create_fn(


In [101]:
# Take only the first batch from the loader for interactive inspection.
for batch in trainloader:
    i = batch['img']                 # stacked image tensor
    t = batch['tokenized_captions']  # list of per-caption token tensors
    f = batch['filename']            # list of file names without extension
    break

In [104]:
# Run the encoder on the sampled images without tracking gradients.
with torch.no_grad():
    output = encoder.forward_features(i)

In [105]:
output.shape # 14 * 14 patches + 1 extra (class) token = 197 tokens per image

torch.Size([16, 197, 768])

#### decoder

In [71]:
import math
import collections
import torch
from torch import nn, Tensor
import torch.nn.functional as F

class Config:
    """Hyper-parameters of the decoder plus an optional checkpoint path."""

    def __init__(self, checkpoint=None):
        """Store the fixed model sizes; *checkpoint* is kept as given."""
        settings = (
            ('n_layer', 12),         # number of stacked blocks
            ('n_head', 12),          # attention heads per block
            ('n_embd', 768),         # embedding width
            ('vocab_size', 50257),   # number of token ids
            ('block_size', 1024),    # maximum sequence length
            ('checkpoint', checkpoint),
        )
        for name, value in settings:
            setattr(self, name, value)

class Attention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask."""

    def __init__(self, cfg):
        """Create the fused q/k/v projection, the output projection and the mask.

        Reads ``cfg.n_embd``, ``cfg.n_head`` and ``cfg.block_size``.
        """
        super().__init__()
        self.c_attn = nn.Linear(cfg.n_embd, 3 * cfg.n_embd)
        self.c_proj = nn.Linear(cfg.n_embd, cfg.n_embd)
        self.n_head = cfg.n_head
        self.n_embd = cfg.n_embd
        limit = cfg.block_size
        # Entry (i, j) is 1 where position i may look at position j (j <= i).
        causal = torch.tril(torch.ones(limit, limit))
        self.register_buffer('bias', causal.view(1, 1, limit, limit))

    def forward(self, x):
        """Apply masked self-attention to *x* of size (batch, context, embedding)."""
        batch, steps, width = x.size()
        head_dim = width // self.n_head

        def split_heads(tensor):
            # (batch, steps, width) -> (batch, heads, steps, head_dim)
            return tensor.view(batch, steps, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        # Positions where the mask is 0 get -inf and thus zero weight after softmax.
        hidden = self.bias[:, :, :steps, :steps] == 0
        weights = F.softmax(scores.masked_fill(hidden, float('-inf')), dim=-1)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, steps, width)
        return self.c_proj(merged)

class CrossAttention(nn.Module):
    """Multi-head cross-attention from decoder states to encoder features.

    The attention module is created once in ``__init__`` so that its weights
    are registered parameters (trained, saved, moved with ``.to()``), instead
    of being re-created with random weights on every forward pass.
    """

    def __init__(self, cfg):
        """Build the attention layer from ``cfg.n_embd`` and ``cfg.n_head``."""
        super().__init__()
        self.n_head = cfg.n_head
        self.n_embd = cfg.n_embd
        # batch_first=True: all inputs and outputs are (batch, length, embed).
        self.multihead_attn = nn.MultiheadAttention(
            self.n_embd, self.n_head, batch_first=True
        )

    def forward(self, x, encoder_output=None):
        """Attend from *x* (queries) to *encoder_output* (keys and values).

        Args:
            x: decoder states, (batch, target_len, n_embd).
            encoder_output: encoder features, (batch, source_len, n_embd).

        Returns:
            Tensor of size (batch, target_len, n_embd).

        Raises:
            ValueError: if *encoder_output* is not given.  There is no silent
                fallback, because attending to *x* itself without a causal
                mask would expose future tokens.
        """
        if encoder_output is None:
            raise ValueError('CrossAttention.forward requires encoder_output')
        # Q comes from the decoder; K and V come from the encoder.
        attn_output, _ = self.multihead_attn(
            query=x, key=encoder_output, value=encoder_output, need_weights=False
        )
        return attn_output

class Block(nn.Module):
    """One decoder layer: pre-norm causal self-attention, then a pre-norm MLP.

    Both sub-layers are wrapped in residual connections.  A cross-attention
    sub-layer (``ln_2`` + ``CrossAttention``) between the two is planned but
    currently disabled.
    """

    def __init__(self, cfg):
        """Create the norms, the attention layer and the MLP (4x hidden width)."""
        super().__init__()
        width = cfg.n_embd
        self.ln_1 = nn.LayerNorm(width)
        # self.ln_2 / self.crs_attn: reserved for the disabled cross-attention.
        self.ln_3 = nn.LayerNorm(width)
        self.attn = Attention(cfg)
        feed_forward = collections.OrderedDict()
        feed_forward['c_fc'] = nn.Linear(width, 4 * width)
        feed_forward['act'] = nn.GELU(approximate='tanh')
        feed_forward['c_proj'] = nn.Linear(4 * width, width)
        self.mlp = nn.Sequential(feed_forward)

    def forward(self, x):
        """Return *x* after the attention and MLP residual branches."""
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_3(after_attention))

class Decoder(nn.Module):
    """Token decoder: embeddings, a stack of ``Block`` layers and an LM head.

    The token embedding matrix and the output projection share one weight
    tensor.
    """

    def __init__(self, cfg):
        """Build the model from *cfg* and optionally load ``cfg.checkpoint``.

        Reads ``cfg.n_layer``, ``cfg.n_embd``, ``cfg.vocab_size``,
        ``cfg.block_size`` and ``cfg.checkpoint`` (a path or ``None``).
        """
        super().__init__()
        self.cfg = cfg
        self.block_size = cfg.block_size
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(cfg.vocab_size, cfg.n_embd), # token embedding
            wpe = nn.Embedding(cfg.block_size, cfg.n_embd), # position embedding
            h = nn.Sequential(*[Block(cfg) for _ in range(cfg.n_layer)]), # Nx
            ln_f = nn.LayerNorm(cfg.n_embd)
        ))
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # Weight tying: the embedding reuses the output projection's parameter.
        self.transformer.wte.weight = self.lm_head.weight
        # load checkpoint
        if self.cfg.checkpoint is not None:
            # map_location='cpu' lets a checkpoint saved on a GPU load on a
            # CPU-only host; load_state_dict then copies into the parameters.
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            state_dict = torch.load(self.cfg.checkpoint, map_location='cpu')
            # These weights are stored transposed relative to nn.Linear.
            transposed = ('.c_attn.weight', '.c_fc.weight', '.c_proj.weight')
            state_dict = {
                key: value.t() if key.endswith(transposed) else value
                for key, value in state_dict.items()
            }
            # NOTE(review): strict=False silently skips keys that do not match a
            # module name (e.g. a checkpoint 'ln_2' vs. Block's 'ln_3') - confirm
            # that every intended weight is actually loaded.
            self.transformer.load_state_dict(state_dict, strict=False)

    def forward(self, x: Tensor):
        """Return next-token logits for the token ids *x*.

        *x* is (batch, length); sequences longer than ``block_size`` are cut
        to their first ``block_size`` positions.  The result is
        (batch, length, vocab_size).
        """
        x = torch.narrow(x, 1, 0, min(x.size(1), self.block_size))
        pos = torch.arange(x.size(1), dtype=torch.long, device=x.device).unsqueeze(0)
        x = self.transformer.wte(x) + self.transformer.wpe(pos)
        x = self.lm_head(self.transformer.ln_f(self.transformer.h(x)))
        return x

In [106]:
# Build the decoder and load the weights stored in decoder_model.bin.
cfg = Config(checkpoint='/content/hw3_data/p2_data/decoder_model.bin')
decoder = Decoder(cfg)

In [108]:
# Grab the tokenized captions of the first batch only.
for batch in trainloader:
    t = batch['tokenized_captions']
    break

In [None]:
# Display the padded token ids of the first caption in the batch.
t[0]

In [152]:
# Add a leading batch axis to the second caption so it can be fed to the decoder.
unsqueeze_tokenized_captions = (t[1].unsqueeze(0))
unsqueeze_tokenized_captions.shape

torch.Size([1, 250])

In [153]:
# Forward the single caption through the decoder to get per-position logits.
output = decoder(unsqueeze_tokenized_captions)

In [154]:
# Display the size of the logits returned by the decoder.
output.shape

torch.Size([1, 250, 50257])

In [155]:
# Greedy choice: highest logit and its token id at every position.
max_values, indices = torch.max(output, dim=-1)

In [157]:
# Turn the predicted token ids of the single sequence back into text.
encoding.decode(list(indices[0].numpy()))

'\n. who the missionboard, the surf,\nTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheTheThe'