<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw3/p2_Image_caption_huge_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install timm
!pip install loralib
!pip install peft

In [None]:
import os
import re
import math
import timm
import json
import torch
import collections
import numpy as np
import loralib as lora
import torch.nn.functional as F
import torchvision.transforms as tr

from PIL import Image
from tqdm import tqdm
from pathlib import Path
from torch import nn, Tensor
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using: {device}")

Using: cuda


In [None]:
# Mount Google Drive under /content/drive (google.colab is only importable inside Colab).
# Later cells write checkpoints and inference outputs below /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#### Download and extract the dataset

In [None]:
!gdown 1SUiRrG6zQVtyrVSVh9hOBq5_fX-oV2Lh -O hw3_data.zip # 11rP6KmR5Qwjhx0rfag0b5TZGBTRuPtQR
!unzip /content/hw3_data.zip

[1;30;43m串流輸出內容已截斷至最後 5000 行。[0m
  inflating: hw3_data/p2_data/images/train/000000226677.jpg  
  inflating: __MACOSX/hw3_data/p2_data/images/train/._000000226677.jpg  
  inflating: hw3_data/p2_data/images/train/2832654970.jpg  
  inflating: __MACOSX/hw3_data/p2_data/images/train/._2832654970.jpg  
  inflating: hw3_data/p2_data/images/train/000000343951.jpg  
  inflating: __MACOSX/hw3_data/p2_data/images/train/._000000343951.jpg  
  inflating: hw3_data/p2_data/images/train/000000531622.jpg  
  inflating: __MACOSX/hw3_data/p2_data/images/train/._000000531622.jpg  
  inflating: hw3_data/p2_data/images/train/000000176946.jpg  
  inflating: __MACOSX/hw3_data/p2_data/images/train/._000000176946.jpg  
  inflating: hw3_data/p2_data/images/train/000000147543.jpg  
  inflating: __MACOSX/hw3_data/p2_data/images/train/._000000147543.jpg  
  inflating: hw3_data/p2_data/images/train/000000273138.jpg  
  inflating: __MACOSX/hw3_data/p2_data/images/train/._000000273138.jpg  
  inflating: hw3_data/p2

#### Tokenizer (`<|endoftext|>` is token id 50256; captions are padded to 250 tokens)

In [None]:
class BPETokenizer:
    """Byte-level BPE tokenizer driven by an `encoder.json` / `vocab.bpe` pair.

    Text is split into pre-tokens with a regex. Each pre-token's UTF-8 bytes are mapped
    to printable stand-in characters, then merged greedily by merge rank. The resulting
    symbols are looked up in the encoder table to get integer ids.
    """

    def __init__(self, encoder_file, vocab_file):
        """Load the token table and the ranked merge list.

        Args:
            encoder_file: path to a JSON object mapping token strings to integer ids.
            vocab_file: path to the merge list, one "a b" pair per line. The first line
                (a header) and the last split element (presumably the empty string after
                the trailing newline) are dropped.

        Raises:
            AssertionError: if the files do not have the sizes checked below.
        """
        with open(encoder_file, 'r', encoding='utf-8') as f:
            self.encoder = json.load(f)
        # Inverse table used by decode(): id -> token string.
        self.decoder = {v:k for k,v in self.encoder.items()}
        with open(vocab_file, 'r', encoding='utf-8') as f:
            vocab = f.read().split('\n')[1:-1]
        # Merge pair -> rank; encode() always applies the lowest-ranked available pair first.
        self.bpe_ranks = {tuple(line.split()): i for i, line in enumerate(vocab)}
        # Size check for the files this notebook ships with. The trailing note suggests that
        # 50000 merges were originally expected, but this file yields 49999 after slicing.
        assert len(self.encoder) == 50257 and len(self.bpe_ranks) == 49999 # len(self.bpe_ranks) == 50000
        # Map every byte 0..255 to a printable character so BPE symbols are ordinary strings:
        # bytes 33-126 and 161-255 map to themselves; the other 67 bytes (0-32, 127-160)
        # map to chr(256), chr(257), ... in order.
        # NOTE(review): GPT-2's reference bytes_to_unicode also treats byte 173 (soft hyphen)
        # as non-printable. Here range(161, 256) includes it, so text containing U+00AD may
        # tokenize differently from GPT-2. Confirm against the encoder file.
        bs = list(range(33, 127)) + list(range(161, 256))
        xs = list(range(0, 33)) + list(range(127, 161))
        cs = bs[:] + [2**8 + i for i in range(len(xs))]
        self.byte_encoder = dict(zip(bs + xs, [chr(n) for n in cs]))
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}

    def encode(self, text, allowed_special=None):
        """Convert ``text`` into a list of integer token ids.

        Args:
            text: the string to tokenize.
            allowed_special: collection of special-token strings that may appear literally
                in ``text``. '<|endoftext|>' is only accepted if it is in this collection.

        Returns:
            list[int] of token ids.

        Raises:
            AssertionError: if '<|endoftext|>' occurs in ``text`` but is not allowed.
            KeyError: if a merged symbol is missing from the encoder table.
        """
        # Pre-tokenize: the special token, English contractions, optional-space-prefixed
        # word / digit / punctuation runs, and whitespace runs.
        # NOTE(review): `\w` also matches digits and '_', so ` ?\d+` is rarely reached. GPT-2's
        # own pattern uses \p{L} / \p{N}; check whether this approximation is intended.
        tokens = re.findall(r"""<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d| ?""" +
                            r"""\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+""", text, re.UNICODE)
        def translate(token):
            # Return the BPE symbols for one pre-token.
            if token == '<|endoftext|>':
                assert allowed_special and token in allowed_special
                return [token]
            # Start from one stand-in character per UTF-8 byte.
            word = tuple(''.join(self.byte_encoder[byte] for byte in token.encode('utf-8')))
            while len(word) != 1:
                # Pick the adjacent pair with the lowest merge rank; unknown pairs rank as inf.
                pairs = set((word[i], word[i+1]) for i in range(len(word)-1))
                bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
                if bigram not in self.bpe_ranks:
                    # No adjacent pair has a known merge: BPE is finished for this token.
                    break
                a, b = bigram
                # Rebuild `word`, replacing each left-to-right, non-overlapping occurrence
                # of (a, b) with the merged symbol a+b.
                new_word = []
                i = 0
                while i < len(word):
                    # Copy everything up to the next occurrence of `a` unchanged.
                    j = word.index(a, i) if a in word[i:] else len(word)
                    new_word.extend(word[i:j])
                    i = j
                    if i < len(word):
                        # Merge when `a` is directly followed by `b`; otherwise keep `a` alone.
                        j = 2 if i < len(word)-1 and word[i] == a and word[i+1] == b else 1
                        new_word.append(a+b if j == 2 else word[i])
                        i += j
                word = tuple(new_word)
            return word
        return [self.encoder[_] for token in tokens for _ in translate(token)]

    def decode(self, tokens):
        """Inverse of encode(): map ids back to text.

        Ids map to symbol strings, symbols map back to raw bytes, and the bytes are decoded
        as UTF-8. Invalid byte sequences become U+FFFD (errors='replace').
        """
        tokens = [self.decoder[token] for token in tokens]
        buffer = bytearray([self.byte_decoder[c] for c in ''.join(tokens)])
        return buffer.decode('utf-8', errors='replace')

In [None]:
# Shared tokenizer instance, used to turn predicted ids back into caption text.
encoding = BPETokenizer(encoder_file='/content/encoder.json', vocab_file='/content/vocab.bpe')

#### Annotation-loading helpers

In [None]:
def json_load(json_path: str):
    """Read the UTF-8 JSON file at ``json_path`` and return the parsed object."""
    with open(json_path, encoding='utf-8') as handle:
        return json.load(handle)

In [None]:
def caption_with_id(json_path: str) -> list:
    """Return one ``{'caption', 'image_id'}`` dict per entry of the file's ``annotations`` list.

    Entries keep their file order; any other annotation fields are dropped.
    """
    with open(json_path, encoding='utf-8') as handle:
        annotations = json.load(handle)['annotations']
    pairs = []
    for ann in annotations:
        pairs.append({'caption': ann['caption'], 'image_id': ann['image_id']})
    return pairs

In [None]:
def id2file_name(json_path: str) -> dict:
    """Map each image ``id`` in the file's ``images`` list to its ``file_name``.

    If an id appears more than once, the last entry wins.
    """
    with open(json_path, encoding='utf-8') as handle:
        images = json.load(handle)['images']
    return dict((entry['id'], entry['file_name']) for entry in images)

In [None]:
encoder_joson_path = '/content/encoder.json'
vocab_bpe_path = '/content/vocab.bpe'
def collate_fn(batch, tokenizer=BPETokenizer(encoder_joson_path, vocab_bpe_path), text_len=250):
    """Stack images and turn captions into fixed-length teacher-forcing tensors.

    For a caption with token ids c1..cn, and EOT = '<|endoftext|>' (id 50256):
        tokenized_captions_train: [EOT, c1, ..., cn, EOT, EOT, ...]    (decoder input)
        tokenized_captions_inf:   [c1, ..., cn, EOT, -100, -100, ...]  (targets)
    Both rows have exactly ``text_len`` entries, and input position i is paired with
    target position i. -100 is the default ``ignore_index`` of ``nn.CrossEntropyLoss``,
    so padding does not contribute to the loss.

    Captions longer than ``text_len - 1`` tokens are truncated. Otherwise the padding
    count would go negative, the rows would differ in length, and they could not be stacked.

    Args:
        batch: list of dataset items with keys 'img' (tensor), 'caption' (str) and
            'filename' (str).
        tokenizer: object with an ``encode(str) -> list[int]`` method. The default is
            constructed once, when this definition runs, and reused by every call.
        text_len: length of every tokenized row.

    Returns:
        dict with 'img' (images stacked on dim 0), 'tokenized_captions_train' and
        'tokenized_captions_inf' (int64 tensors of shape (len(batch), text_len)) and
        'filename' (list of str).
    """
    eot, ignore = 50256, -100
    images = [item['img'] for item in batch]
    captions = [item['caption'] for item in batch]
    filenames = [item['filename'] for item in batch]

    # Reserve one slot for the leading (input) / trailing (target) EOT token.
    token_rows = [tokenizer.encode(caption)[:text_len - 1] for caption in captions]
    train_rows = [[eot] + toks + [eot] * (text_len - 1 - len(toks)) for toks in token_rows]
    target_rows = [toks + [eot] + [ignore] * (text_len - 1 - len(toks)) for toks in token_rows]

    return {
        'img': torch.stack(images, dim=0),
        'tokenized_captions_train': torch.tensor(train_rows, dtype=torch.long),
        'filename': filenames,
        'tokenized_captions_inf': torch.tensor(target_rows, dtype=torch.long),
    }

#### Build Dataset

In [None]:
class ImgCaptionDataset(torch.utils.data.Dataset):
    """Map-style dataset with one sample per caption annotation.

    An image referenced by several annotations is returned once per annotation, each
    time with the corresponding caption.

    Args:
        img_dir: directory containing the image files named in the JSON ``images`` list.
        json_path: annotation file with ``annotations`` (caption, image_id) and
            ``images`` (id, file_name) lists.
        transform: callable applied to the RGB PIL image.
    """

    def __init__(self, img_dir, json_path, transform) -> None:
        super().__init__()
        self.img_dir = img_dir
        self.transform = transform

        # Connect caption -> image_id -> file_name
        self.caption_with_id = caption_with_id(json_path)
        self.id2file_name = id2file_name(json_path)

    def __len__(self) -> int:
        return len(self.caption_with_id)

    def __getitem__(self, idx):
        """Return {'img': transformed image, 'caption': str, 'filename': name without extension}."""
        record = self.caption_with_id[idx]
        file_name = self.id2file_name[record['image_id']]
        # Close the file handle deterministically. convert() returns a new, fully loaded
        # image, so it remains usable after the file is closed.
        with Image.open(os.path.join(self.img_dir, file_name)) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        return {'img': img, 'caption': record['caption'], 'filename': os.path.splitext(file_name)[0]}

#### Build Dataloader

In [None]:
# Deterministic preprocessing shared by both splits: resize the shorter side to 224,
# center-crop to 224x224, convert to a float tensor and normalize per channel.
# NOTE(review): these are ImageNet mean/std. timm's original in21k ViT weights may have been
# trained with a different normalization (e.g. mean=std=0.5). Compare with
# timm.data.resolve_data_config(model=...) before retraining; existing checkpoints were
# produced with the values below.
image_transform = tr.Compose([
    tr.Resize(224),
    tr.CenterCrop(224),
    tr.ToTensor(),
    tr.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_ds = ImgCaptionDataset(
    img_dir='/content/hw3_data/p2_data/images/train',
    json_path='/content/hw3_data/p2_data/train.json',
    transform=image_transform,
)
val_ds = ImgCaptionDataset(
    img_dir='/content/hw3_data/p2_data/images/val',
    json_path='/content/hw3_data/p2_data/val.json',
    transform=image_transform,
)

# Pinned host memory is what lets the `.to(device, non_blocking=True)` copies in the
# training loop run asynchronously; it is only useful when a GPU is present.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_ds,
    batch_size=4,
    collate_fn=collate_fn,
    shuffle=True,
    num_workers=8,
    pin_memory=PIN_MEMORY,
)
# Evaluation results do not depend on order, so iterate the validation set deterministically.
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=6,
    pin_memory=PIN_MEMORY,
)



#### Config

In [None]:
class Config:
    """Hyper-parameters of the caption decoder (dimensions match GPT-2 small).

    Args:
        checkpoint: optional path to pretrained decoder weights. ``Decoder`` skips
            loading when this is ``None``.
    """

    def __init__(self, checkpoint=None):
        # 12 transformer blocks, 12 attention heads, 768-wide embeddings.
        self.n_layer, self.n_head, self.n_embd = 12, 12, 768
        # 50257 token ids (50256 is '<|endoftext|>'); block_size is the longest sequence
        # the position embedding and causal mask can cover.
        self.vocab_size, self.block_size = 50257, 1024
        self.checkpoint = checkpoint

In [None]:
# Decoder configuration, pointing at the pretrained decoder weights shipped with the dataset.
cfg = Config('/content/hw3_data/p2_data/decoder_model.bin')

#### Model: LoRA-adapted ViT encoder + GPT-2 decoder with cross-attention

In [None]:
def LoRA_model(r=4, dropout=0.5, model='vit_base_patch16_224', lora_alpha=16,
               target_modules=("qkv",), modules_to_save=("classifier",)):
    """Create a pretrained timm model and wrap it with PEFT LoRA adapters.

    Args:
        r: LoRA rank.
        dropout: dropout on the LoRA branch (``lora_dropout``).
        model: timm model name; pretrained weights are downloaded by ``timm.create_model``.
        lora_alpha: LoRA scaling numerator (PEFT scales the update by ``lora_alpha / r``).
        target_modules: names of the sub-modules that receive LoRA adapters.
        modules_to_save: modules kept fully trainable and saved with the adapters, or
            ``None`` for no such modules.
            NOTE(review): timm ViTs call their head ``head``, not ``classifier``, so the
            default appears to match nothing. It is kept for backward compatibility.

    Returns:
        The ``PeftModel`` produced by ``get_peft_model``.
    """
    model_timm = timm.create_model(model, pretrained=True)
    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=list(target_modules),
        lora_dropout=dropout,
        bias="none",
        modules_to_save=list(modules_to_save) if modules_to_save is not None else None,
    )
    return get_peft_model(model_timm, config)

In [None]:
class Attention(nn.Module):
    """Causal multi-head self-attention using GPT-2 parameter names.

    ``c_attn`` produces the concatenated query/key/value projections, and ``c_proj``
    projects the merged heads. The ``bias`` buffer is a lower-triangular
    (1, 1, block_size, block_size) mask, so position t only attends to positions <= t.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg.n_embd
        # These parameter/buffer names are the keys used when loading the checkpoint.
        self.c_attn = nn.Linear(width, 3 * width)
        self.c_proj = nn.Linear(width, width)
        self.n_head = cfg.n_head
        self.n_embd = width
        span = cfg.block_size
        lower = torch.tril(torch.ones(span, span))
        self.register_buffer('bias', lower.view(1, 1, span, span))

    def forward(self, x):
        """Apply masked self-attention to ``x`` of shape (batch, seq_len, n_embd)."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = map(split_heads, self.c_attn(x).split(self.n_embd, dim=2))
        scores = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1)))
        future = self.bias[:, :, :seq_len, :seq_len] == 0
        probs = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)
        merged = (probs @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)

class CrossAttention(nn.Module):
    """Multi-head attention from decoder states (queries) to encoder features (keys/values)."""

    def __init__(self, cfg):
        super().__init__()
        # The attribute name `multihead_attn` is used elsewhere in the notebook to address
        # these parameters, so keep it.
        self.multihead_attn = nn.MultiheadAttention(cfg.n_embd, cfg.n_head, batch_first=True)

    def forward(self, query, encoder_out):
        """Attend from ``query`` to ``encoder_out``.

        Args:
            query: decoder states of shape (N, L, E). N is the batch size, L the target
                sequence length and E == cfg.n_embd (batch_first layout).
            encoder_out: features used as both keys and values, shape (N, S, E), where S
                is the source sequence length.

        Returns:
            The attended output of shape (N, L, E). The attention weights are discarded.
        """
        attended, _ = self.multihead_attn(query, encoder_out, encoder_out)
        return attended

class Block(nn.Module):
    """One decoder layer: causal self-attention, cross-attention to encoder features, MLP.

    Each sub-layer is pre-normalized and wrapped in a residual connection. The same
    ``ln_2`` LayerNorm normalizes both the decoder stream and ``encoder_out`` before the
    cross-attention. ``ln_3`` precedes the MLP.
    NOTE(review): in GPT-2 checkpoints ``ln_2`` is the pre-MLP norm; check which of
    ``ln_2`` / ``ln_3`` the pretrained weights should populate.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg.n_embd
        # Construction order and attribute names are kept as-is. The names are
        # checkpoint/freeze keys, and the order fixes the random-initialization sequence.
        self.ln_1 = nn.LayerNorm(width)
        self.ln_2 = nn.LayerNorm(width)
        self.ln_3 = nn.LayerNorm(width)
        self.attn = Attention(cfg)
        self.crs_attn = CrossAttention(cfg)
        self.mlp = nn.Sequential(collections.OrderedDict([
            ('c_fc', nn.Linear(width, 4 * width)),
            ('act', nn.GELU(approximate='tanh')),
            ('c_proj', nn.Linear(4 * width, width)),
        ]))

    def forward(self, x, encoder_out) -> Tensor:
        """Run the three residual sub-layers on ``x``, attending to ``encoder_out``."""
        hidden = x + self.attn(self.ln_1(x))
        hidden = hidden + self.crs_attn(self.ln_2(hidden), self.ln_2(encoder_out))
        return hidden + self.mlp(self.ln_3(hidden))

class Decoder(nn.Module):
    """Image-captioning model: LoRA-adapted timm ViT encoder plus a GPT-2 style decoder.

    The decoder is ``cfg.n_layer`` ``Block`` layers (causal self-attention,
    cross-attention to linearly projected ViT features, MLP). The token embedding
    ``wte`` shares its weight with the output projection ``lm_head``.

    Args:
        cfg: object exposing ``n_layer``, ``n_head``, ``n_embd``, ``vocab_size``,
            ``block_size`` and ``checkpoint`` (path to pretrained decoder weights or None).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.block_size = cfg.block_size
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(cfg.vocab_size, cfg.n_embd), # token embedding
            wpe = nn.Embedding(cfg.block_size, cfg.n_embd), # position embedding
            h = nn.Sequential(*[Block(cfg) for _ in range(cfg.n_layer)]), # Nx
            ln_f = nn.LayerNorm(cfg.n_embd)
        ))
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        # Weight tying: input embedding and output projection share one matrix.
        self.transformer.wte.weight = self.lm_head.weight
        # timm's ViT encoder (vit_base_patch16_224_in21k, vit_large_patch16_224_in21k, )
        # self.encoder = timm.create_model('vit_huge_patch14_224_in21k', pretrained=True)
        self.encoder = LoRA_model(r=4, dropout=0.5, model='vit_huge_patch14_224_in21k')
        # Projects ViT-H features (width 1280) to the decoder width.
        self.linear = nn.Linear(1280, cfg.n_embd) # [16, 257, 1280]
        if self.cfg.checkpoint is not None:
            self._load_decoder_checkpoint(self.cfg.checkpoint)

    def _load_decoder_checkpoint(self, path):
        """Load pretrained GPT-2 weights from ``path`` into ``self.transformer``.

        Two adjustments are applied to the raw checkpoint:
        * ``c_attn`` / ``c_fc`` / ``c_proj`` weights are transposed. The checkpoint
          presumably stores them in GPT-2's Conv1D (in, out) layout, while
          ``nn.Linear`` expects (out, in).
        * ``h.<i>.ln_2.*`` is renamed to ``h.<i>.ln_3.*``. In GPT-2, ``ln_2`` is the
          LayerNorm before the MLP; in ``Block`` that role belongs to ``ln_3``.
          ``Block.ln_2`` is the newly added cross-attention norm and starts from its
          default initialization.

        Loading is non-strict: modules without a checkpoint counterpart (e.g. the
        cross-attention layers) keep their initial values.
        """
        state_dict = torch.load(path, map_location='cpu')
        transposed = ('.c_attn.weight', '.c_fc.weight', '.c_proj.weight')
        remapped = {}
        for key, value in state_dict.items():
            if key.endswith(transposed):
                value = value.t()
            remapped[re.sub(r'^(h\.\d+)\.ln_2\.', r'\1.ln_3.', key)] = value
        self.transformer.load_state_dict(remapped, strict=False)

    def forward(self, x: Tensor, img: Tensor) -> Tensor:
        """Return next-token logits for token ids ``x`` conditioned on images ``img``.

        Args:
            x: integer token ids of shape (batch, seq_len). Only the first
                ``block_size`` positions are used.
            img: image batch accepted by the encoder's ``forward_features``.

        Returns:
            Logits of shape (batch, min(seq_len, block_size), vocab_size).
        """
        x = torch.narrow(x, 1, 0, min(x.size(1), self.block_size))
        pos = torch.arange(x.size(1), dtype=torch.long, device=x.device).unsqueeze(0)
        x = self.transformer.wte(x) + self.transformer.wpe(pos)
        # Track gradients through the encoder only when some encoder parameter (e.g. a LoRA
        # adapter) is trainable. Under a blanket no_grad, unfrozen LoRA weights would never
        # receive a gradient. Enabling it stores the ViT activations for backward, which
        # increases GPU memory use; lower the batch size if this runs out of memory.
        encoder_trainable = any(p.requires_grad for p in self.encoder.parameters())
        with torch.set_grad_enabled(torch.is_grad_enabled() and encoder_trainable):
            encoder_out = self.encoder.forward_features(img)
        # The projected encoder features are identical for every block, so compute them once.
        memory = self.linear(encoder_out)
        for block in self.transformer.h:
            x = block(x, memory)
        return self.lm_head(self.transformer.ln_f(x))

#### Training

In [None]:
def training(dataloader, model, loss_fn, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Each batch must provide 'img', 'tokenized_captions_train' (decoder input ids) and
    'tokenized_captions_inf' (target ids, aligned position-by-position with the input).
    Tensors are moved to the global ``device``. The batch dict itself is not modified.

    Args:
        dataloader: iterable of batches with a ``len()``.
        model: callable ``model(input_ids, images)`` that returns logits with the
            vocabulary on the last dimension.
        loss_fn: classification loss applied to flattened logits and targets.
        optimizer: optimizer over the trainable parameters.

    Returns:
        float: average of the per-batch loss values.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)  # batches per epoch
    if num_batches == 0:
        raise ValueError('training() received an empty dataloader')
    epoch_loss = 0.0

    model.train()  # enable dropout etc.
    for data in tqdm(dataloader):
        images = data['img'].to(device, non_blocking=True)
        inputs = data['tokenized_captions_train'].to(device, non_blocking=True)
        targets = data['tokenized_captions_inf'].to(device, non_blocking=True)

        optimizer.zero_grad()

        logits = model(inputs, images)
        # Flatten (batch, seq) so every position is an independent classification.
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    return epoch_loss / num_batches

##### Freeze parameters

In [None]:
model = Decoder(cfg).to(device)

# Start from a fully frozen model, then re-enable only the parts that are fine-tuned.
for param in model.parameters():
    param.requires_grad = False

# Decoder: the cross-attention branch and the LayerNorm (ln_2) that feeds it, in every block.
for block in model.transformer.h:
    block.ln_2.requires_grad_(True)
    block.crs_attn.multihead_attn.requires_grad_(True)

# Encoder: only the LoRA A/B matrices injected into each ViT block's qkv projection.
for vit_block in model.encoder.base_model.model.blocks:
    vit_block.attn.qkv.lora_A.default.weight.requires_grad = True
    vit_block.attn.qkv.lora_B.default.weight.requires_grad = True

# Projection from the ViT feature width to the decoder width.
model.linear.requires_grad_(True)

# Names of the trainable tensors; used later to save only these weights.
trainable_weights = [name for name, param in model.named_parameters() if param.requires_grad]
n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'Trainable tensors: {len(trainable_weights)}, trainable params: {n_trainable:,}')
for name in trainable_weights:
    print(name)

In [None]:
EPOCHS = 10
lr=3e-4
# Target padding is -100, which CrossEntropyLoss ignores by default (ignore_index=-100).
loss_fn = nn.CrossEntropyLoss()
# Give Adam only the tensors being fine-tuned; frozen ones would never get a gradient.
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)

CKPT_DIR = '/content/drive/MyDrive/NTU_DLCV/Hw3/p2_ckpt_huge_LoRA'
# Refuse to silently write to the ephemeral local disk if Drive is not mounted.
assert os.path.isdir('/content/drive/MyDrive'), 'Google Drive is not mounted'
os.makedirs(CKPT_DIR, exist_ok=True)  # torch.save does not create missing directories
trainable_set = set(trainable_weights)  # O(1) membership when filtering the state dict

# logs
logs = {
    'train_loss': []
}

for epoch in tqdm(range(EPOCHS)):
    train_loss = training(train_loader, model, loss_fn, optimizer)

    print(f'EPOCH: {epoch:04d} \ttrain_loss: {train_loss:.4f}')

    logs['train_loss'].append(train_loss)

    # Save only the fine-tuned tensors.
    save_weights = {k: v for k, v in model.state_dict().items() if k in trainable_set}
    torch.save(save_weights, os.path.join(CKPT_DIR, f'trainable_weights_lr{lr}_epoch{epoch}_{train_loss:.4f}.pth'))
    print('---------- Model Save ----------')
    print('---------- Model Save ----------')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 62%|██████▏   | 8263/13262 [33:55<20:17,  4.11it/s][A
 62%|██████▏   | 8264/13262 [33:55<20:37,  4.04it/s][A
 62%|██████▏   | 8265/13262 [33:55<20:28,  4.07it/s][A
 62%|██████▏   | 8266/13262 [33:55<20:35,  4.04it/s][A
 62%|██████▏   | 8267/13262 [33:56<20:35,  4.04it/s][A
 62%|██████▏   | 8268/13262 [33:56<20:37,  4.03it/s][A
 62%|██████▏   | 8269/13262 [33:56<20:57,  3.97it/s][A
 62%|██████▏   | 8270/13262 [33:56<20:46,  4.00it/s][A
 62%|██████▏   | 8271/13262 [33:57<20:23,  4.08it/s][A
 62%|██████▏   | 8272/13262 [33:57<20:26,  4.07it/s][A
 62%|██████▏   | 8273/13262 [33:57<20:28,  4.06it/s][A
 62%|██████▏   | 8274/13262 [33:57<20:19,  4.09it/s][A
 62%|██████▏   | 8275/13262 [33:58<20:28,  4.06it/s][A
 62%|██████▏   | 8276/13262 [33:58<20:30,  4.05it/s][A
 62%|██████▏   | 8277/13262 [33:58<20:11,  4.11it/s][A
 62%|██████▏   | 8278/13262 [33:58<20:14,  4.10it/s][A
 62%|██████▏   | 8279/13262 [33:59<20:3

EPOCH: 0000 	rain_loss: 3.1683


 10%|█         | 1/10 [54:29<8:10:22, 3269.18s/it]

---------- Model Save ----------


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 62%|██████▏   | 8263/13262 [33:53<20:29,  4.07it/s][A
 62%|██████▏   | 8264/13262 [33:54<20:16,  4.11it/s][A
 62%|██████▏   | 8265/13262 [33:54<20:16,  4.11it/s][A
 62%|██████▏   | 8266/13262 [33:54<20:11,  4.12it/s][A
 62%|██████▏   | 8267/13262 [33:54<20:06,  4.14it/s][A
 62%|██████▏   | 8268/13262 [33:55<20:21,  4.09it/s][A
 62%|██████▏   | 8269/13262 [33:55<20:14,  4.11it/s][A
 62%|██████▏   | 8270/13262 [33:55<20:13,  4.11it/s][A
 62%|██████▏   | 8271/13262 [33:55<20:30,  4.06it/s][A
 62%|██████▏   | 8272/13262 [33:56<20:54,  3.98it/s][A
 62%|██████▏   | 8273/13262 [33:56<20:40,  4.02it/s][A
 62%|██████▏   | 8274/13262 [33:56<20:28,  4.06it/s][A
 62%|██████▏   | 8275/13262 [33:56<20:49,  3.99it/s][A
 62%|██████▏   | 8276/13262 [33:57<20:32,  4.04it/s][A
 62%|██████▏   | 8277/13262 [33:57<20:28,  4.06it/s][A
 62%|██████▏   | 8278/13262 [33:57<20:16,  4.10it/s][A
 62%|██████▏   | 8279/13262 [33:57<20:1

EPOCH: 0001 	rain_loss: 2.6264


 20%|██        | 2/10 [1:48:54<7:15:35, 3266.97s/it]

---------- Model Save ----------


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 62%|██████▏   | 8263/13262 [33:55<20:12,  4.12it/s][A
 62%|██████▏   | 8264/13262 [33:55<20:01,  4.16it/s][A
 62%|██████▏   | 8265/13262 [33:55<20:03,  4.15it/s][A
 62%|██████▏   | 8266/13262 [33:55<20:00,  4.16it/s][A
 62%|██████▏   | 8267/13262 [33:56<20:09,  4.13it/s][A
 62%|██████▏   | 8268/13262 [33:56<20:12,  4.12it/s][A
 62%|██████▏   | 8269/13262 [33:56<20:03,  4.15it/s][A
 62%|██████▏   | 8270/13262 [33:56<19:59,  4.16it/s][A
 62%|██████▏   | 8271/13262 [33:57<20:03,  4.15it/s][A
 62%|██████▏   | 8272/13262 [33:57<20:24,  4.08it/s][A
 62%|██████▏   | 8273/13262 [33:57<20:12,  4.11it/s][A
 62%|██████▏   | 8274/13262 [33:57<20:01,  4.15it/s][A
 62%|██████▏   | 8275/13262 [33:58<20:16,  4.10it/s][A
 62%|██████▏   | 8276/13262 [33:58<20:11,  4.12it/s][A
 62%|██████▏   | 8277/13262 [33:58<20:38,  4.02it/s][A
 62%|██████▏   | 8278/13262 [33:58<20:30,  4.05it/s][A
 62%|██████▏   | 8279/13262 [33:59<20:1

EPOCH: 0002 	rain_loss: 2.3981


 30%|███       | 3/10 [2:43:20<6:21:06, 3266.64s/it]

---------- Model Save ----------


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 62%|██████▏   | 8263/13262 [33:53<20:24,  4.08it/s][A
 62%|██████▏   | 8264/13262 [33:53<20:19,  4.10it/s][A
 62%|██████▏   | 8265/13262 [33:54<20:11,  4.12it/s][A
 62%|██████▏   | 8266/13262 [33:54<20:29,  4.06it/s][A
 62%|██████▏   | 8267/13262 [33:54<20:23,  4.08it/s][A
 62%|██████▏   | 8268/13262 [33:54<20:21,  4.09it/s][A
 62%|██████▏   | 8269/13262 [33:55<20:32,  4.05it/s][A
 62%|██████▏   | 8270/13262 [33:55<20:32,  4.05it/s][A
 62%|██████▏   | 8271/13262 [33:55<20:42,  4.02it/s][A
 62%|██████▏   | 8272/13262 [33:55<20:38,  4.03it/s][A
 62%|██████▏   | 8273/13262 [33:56<20:31,  4.05it/s][A
 62%|██████▏   | 8274/13262 [33:56<20:34,  4.04it/s][A
 62%|██████▏   | 8275/13262 [33:56<20:24,  4.07it/s][A
 62%|██████▏   | 8276/13262 [33:56<20:25,  4.07it/s][A
 62%|██████▏   | 8277/13262 [33:57<20:33,  4.04it/s][A
 62%|██████▏   | 8278/13262 [33:57<20:23,  4.07it/s][A
 62%|██████▏   | 8279/13262 [33:57<20:1

EPOCH: 0003 	rain_loss: 2.2040


 40%|████      | 4/10 [3:37:45<5:26:34, 3265.73s/it]

---------- Model Save ----------


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 62%|██████▏   | 8263/13262 [33:55<20:28,  4.07it/s][A
 62%|██████▏   | 8264/13262 [33:55<20:15,  4.11it/s][A
 62%|██████▏   | 8265/13262 [33:55<20:10,  4.13it/s][A
 62%|██████▏   | 8266/13262 [33:56<20:33,  4.05it/s][A
 62%|██████▏   | 8267/13262 [33:56<20:34,  4.04it/s][A
 62%|██████▏   | 8268/13262 [33:56<20:45,  4.01it/s][A
 62%|██████▏   | 8269/13262 [33:56<20:36,  4.04it/s][A
 62%|██████▏   | 8270/13262 [33:57<20:50,  3.99it/s][A
 62%|██████▏   | 8271/13262 [33:57<20:34,  4.04it/s][A
 62%|██████▏   | 8272/13262 [33:57<20:18,  4.09it/s][A
 62%|██████▏   | 8273/13262 [33:57<20:11,  4.12it/s][A
 62%|██████▏   | 8274/13262 [33:58<20:28,  4.06it/s][A
 62%|██████▏   | 8275/13262 [33:58<20:37,  4.03it/s][A
 62%|██████▏   | 8276/13262 [33:58<20:28,  4.06it/s][A
 62%|██████▏   | 8277/13262 [33:58<20:21,  4.08it/s][A
 62%|██████▏   | 8278/13262 [33:58<20:15,  4.10it/s][A
 62%|██████▏   | 8279/13262 [33:59<20:2

EPOCH: 0004 	rain_loss: 2.0279


 50%|█████     | 5/10 [4:32:12<4:32:11, 3266.35s/it]

---------- Model Save ----------


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 62%|██████▏   | 8263/13262 [33:56<20:20,  4.10it/s][A
 62%|██████▏   | 8264/13262 [33:56<20:20,  4.09it/s][A
 62%|██████▏   | 8265/13262 [33:56<20:30,  4.06it/s][A
 62%|██████▏   | 8266/13262 [33:57<20:20,  4.09it/s][A
 62%|██████▏   | 8267/13262 [33:57<20:13,  4.11it/s][A
 62%|██████▏   | 8268/13262 [33:57<20:18,  4.10it/s][A
 62%|██████▏   | 8269/13262 [33:57<20:18,  4.10it/s][A
 62%|██████▏   | 8270/13262 [33:58<20:08,  4.13it/s][A
 62%|██████▏   | 8271/13262 [33:58<20:21,  4.09it/s][A
 62%|██████▏   | 8272/13262 [33:58<20:26,  4.07it/s][A
 62%|██████▏   | 8273/13262 [33:58<20:22,  4.08it/s][A
 62%|██████▏   | 8274/13262 [33:59<20:32,  4.05it/s][A
 62%|██████▏   | 8275/13262 [33:59<20:29,  4.06it/s][A
 62%|██████▏   | 8276/13262 [33:59<20:20,  4.09it/s][A
 62%|██████▏   | 8277/13262 [33:59<20:22,  4.08it/s][A
 62%|██████▏   | 8278/13262 [34:00<20:22,  4.08it/s][A
 62%|██████▏   | 8279/13262 [34:00<20:2

EPOCH: 0005 	rain_loss: 1.8596


 60%|██████    | 6/10 [5:26:43<3:37:51, 3267.76s/it]

---------- Model Save ----------


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 62%|██████▏   | 8263/13262 [34:01<20:18,  4.10it/s][A
 62%|██████▏   | 8264/13262 [34:02<20:15,  4.11it/s][A
 62%|██████▏   | 8265/13262 [34:02<20:23,  4.08it/s][A
 62%|██████▏   | 8266/13262 [34:02<20:33,  4.05it/s][A
 62%|██████▏   | 8267/13262 [34:02<20:24,  4.08it/s][A
 62%|██████▏   | 8268/13262 [34:03<20:34,  4.05it/s][A
 62%|██████▏   | 8269/13262 [34:03<20:39,  4.03it/s][A
 62%|██████▏   | 8270/13262 [34:03<20:40,  4.02it/s][A
 62%|██████▏   | 8271/13262 [34:03<20:40,  4.02it/s][A
 62%|██████▏   | 8272/13262 [34:04<20:42,  4.01it/s][A
 62%|██████▏   | 8273/13262 [34:04<20:43,  4.01it/s][A
 62%|██████▏   | 8274/13262 [34:04<20:42,  4.01it/s][A
 62%|██████▏   | 8275/13262 [34:04<20:50,  3.99it/s][A
 62%|██████▏   | 8276/13262 [34:05<20:49,  3.99it/s][A
 62%|██████▏   | 8277/13262 [34:05<20:49,  3.99it/s][A
 62%|██████▏   | 8278/13262 [34:05<20:46,  4.00it/s][A
 62%|██████▏   | 8279/13262 [34:05<21:0

EPOCH: 0006 	rain_loss: 1.7048


 70%|███████   | 7/10 [6:21:22<2:43:34, 3271.57s/it]

---------- Model Save ----------


[1;30;43m串流輸出內容已截斷至最後 5000 行。[0m
 46%|████▌     | 6126/13262 [25:15<29:12,  4.07it/s][A
 46%|████▌     | 6127/13262 [25:15<29:29,  4.03it/s][A
 46%|████▌     | 6128/13262 [25:15<28:56,  4.11it/s][A
 46%|████▌     | 6129/13262 [25:16<28:46,  4.13it/s][A
 46%|████▌     | 6130/13262 [25:16<28:35,  4.16it/s][A
 46%|████▌     | 6131/13262 [25:16<28:33,  4.16it/s][A
 46%|████▌     | 6132/13262 [25:16<28:51,  4.12it/s][A
 46%|████▌     | 6133/13262 [25:17<28:54,  4.11it/s][A
 46%|████▋     | 6134/13262 [25:17<28:45,  4.13it/s][A
 46%|████▋     | 6135/13262 [25:17<28:33,  4.16it/s][A
 46%|████▋     | 6136/13262 [25:17<28:30,  4.17it/s][A
 46%|████▋     | 6137/13262 [25:18<28:57,  4.10it/s][A
 46%|████▋     | 6138/13262 [25:18<29:13,  4.06it/s][A
 46%|████▋     | 6139/13262 [25:18<28:56,  4.10it/s][A
 46%|████▋     | 6140/13262 [25:18<29:19,  4.05it/s][A
 46%|████▋     | 6141/13262 [25:19<29:25,  4.03it/s][A
 46%|████▋     | 6142/13262 [25:19<29:22,  4.04it/s][A
 46%|████▋   

#### Check that the trainable parameters total fewer than 35M

In [None]:
# !gdown 1-BEb6UWp_WwdfE7T-TulrVUTkPB9ZFIz -O trainable_weights

In [None]:
# NOTE(review): this commented-out freeze block is stale. The current Decoder has no LoRA on
# `attn.c_attn` (that attribute is a plain nn.Linear), so the lora_A / lora_B lines would raise
# AttributeError if uncommented. It also never unfreezes the encoder's qkv LoRA weights, so
# the trainable-parameter count it produces would omit them.
# model = Decoder(cfg)
# # Freeze parameters
# for name, param in model.named_parameters():
#     param.requires_grad=False
#     # print(f"{name}: {param.requires_grad}")
# # Unfreeze some parameters
# for i in range(12):
#     model.transformer.h[i].ln_2.weight.requires_grad = True
#     model.transformer.h[i].ln_2.bias.requires_grad = True
#     model.transformer.h[i].attn.c_attn.lora_A.requires_grad = True
#     model.transformer.h[i].attn.c_attn.lora_B.requires_grad = True
#     model.transformer.h[i].crs_attn.multihead_attn.in_proj_weight.requires_grad = True
#     model.transformer.h[i].crs_attn.multihead_attn.in_proj_bias.requires_grad = True
#     model.transformer.h[i].crs_attn.multihead_attn.out_proj.weight.requires_grad = True
#     model.transformer.h[i].crs_attn.multihead_attn.out_proj.bias.requires_grad = True
# model.linear.weight.requires_grad = True
# model.linear.bias.requires_grad = True

In [None]:
# model.load_state_dict(torch.load('/content/trainable_weights', map_location=device), strict=False)
# print('Total params: ', sum(params.numel() for params in model.parameters() if params.requires_grad))

#### Inference

In [None]:
# NOTE(review): trainable_weights4 and trainable_weights5 use the same Drive file ID
# (1-8wfeBKbkfSOgvysv-x3u6w3om527sBM), so the "epoch 5" evaluation would reuse the epoch-4
# weights. Verify the intended ID.
# !gdown 1-Ko5lsfct2QZ2zdUKZjkJPS4ED-APNNc -O trainable_weights0
# !gdown 1-2688GNZFFLlP4eUKq_QlEC-9rUZtVFF -O trainable_weights1
# !gdown 1-33lmiZiCDRASXsfdkgy9VjzY-DpOWeg -O trainable_weights2
# !gdown 1-4N_nCAyS5LAFnGaxUfstzcsaDnQydKP -O trainable_weights3
# !gdown 1-8wfeBKbkfSOgvysv-x3u6w3om527sBM -O trainable_weights4
# !gdown 1-8wfeBKbkfSOgvysv-x3u6w3om527sBM -O trainable_weights5
# !gdown 1-BEb6UWp_WwdfE7T-TulrVUTkPB9ZFIz -O trainable_weights6

In [None]:
# Greedy decoding over the validation set for checkpoints 6, 5 and 4 (range(6, 3, -1)).
# NOTE(review): fix these before uncommenting:
#   - `pred_caption` is only assigned once a second end token (50256) appears. If none is
#     produced within 250 steps, it is unbound (first image) or stale (previous image).
#   - '\caption: ' contains the unrecognized escape '\c'; '\n' was probably intended.
#   - every step re-runs the whole model, including the ViT encoder, on the full prefix.
# for i in range(6, 3, -1):
#     model = Decoder(cfg).to(device)
#     model.load_state_dict(torch.load(f'/content/trainable_weights{i}', map_location=device), strict=False)
#     print(f'---------- trainable weights {i} is using ----------')

#     evaluation_dict = {}
#     for data in tqdm(val_loader):
#         img = data['img'].to(device)
#         file_name = data['filename']
#         start_token = torch.tensor([[50256]]).to(device)

#         for j in range(250):
#             with torch.no_grad():
#                 pred = model(start_token, img) # , weights
#                 # print(weights.size())

#             out_token = pred.argmax(dim=2)[0][-1]
#             start_token = torch.cat((start_token, out_token.unsqueeze(0).unsqueeze(0)), dim=1)
#             end_token = torch.sum(start_token[0] == 50256).item()
#             if end_token == 2:
#                 pred_token = start_token[start_token != 50256]
#                 pred_token = pred_token.tolist()
#                 pred_caption = encoding.decode(pred_token)
#                 break

#         evaluation_dict[file_name[0]] = pred_caption
#         print('\n', 'file name: ', file_name[0], '\caption: ', evaluation_dict[file_name[0]])

#     json_string = json.dumps(evaluation_dict, indent=2)  # The indent parameter is optional and adds indentation for better readability
#     with open(f'/content/drive/MyDrive/NTU_DLCV/Hw3/p2_output_huge_LoRA/huge_epoch{i}_output.json', 'w') as json_file:
#         json_file.write(json_string)
#     print(f'---------- Epoch{i} huge params Saved ----------')