Dùng tạm mô hình tự code để cho nó đúng đầu vào đầu ra đã

In [None]:
import numpy as np
from numpy import array
import pandas as pd
import matplotlib.pyplot as plt
import string
import os
from PIL import Image
import glob
from pickle import dump, load
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import math

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
BATCH_SIZE = 64 # 3.2GB VRAM f32 (2.9 Dedicated + 0.2 Shared)
device = 'cuda'


In [2]:
image_transforms = transforms.Compose([
    transforms.Resize((324, 324)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Dùng cho ImageNet
])

In [3]:
import util
data_folder_path = "Flickr8k/Flicker8k_Dataset"
train_dataset = util.Flickr8kDataset(
    data_folder_path,
    "Data_bert/train_set_bert.pkl",
    image_transforms,
    device=device
)
trainloader = train_dataset.get_dataloader(batch_size=BATCH_SIZE, num_workers=2, shuffle=False)
test_dataset = util.Flickr8kDataset(
    data_folder_path,
    "Data_bert/test_set_bert.pkl",
    image_transforms,
    device=device
)
testloader = test_dataset.get_dataloader(batch_size=BATCH_SIZE, num_workers=2, shuffle=False)
print(len(train_dataset)) # :)) sao có tận 30k ids ảnh trong train_set_bert.pkl, một ảnh có nhiều caption à ?
print(len(test_dataset)) 
sample_image, sample_caption = test_dataset[0]
print(sample_image.shape, sample_caption.shape)

6000
1000
30000
5000
torch.Size([3, 324, 324]) torch.Size([46])


In [4]:
class EncoderCNNTest(nn.Module):
    def __init__(self, output_size: int):
        super(EncoderCNN, self).__init__()
        self.CNN = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, stride=3), # 324 * 324
            nn.LeakyReLU(0.1),
            nn.Conv2d(8, 16, kernel_size=3, stride=3), # 108 * 108
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 32, kernel_size=3, stride=1), # 36 * 36
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=3, stride=3), # 12 * 12
            nn.Conv2d(32, 64, kernel_size=3, stride=1), # 10 * 10
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=3, stride=1), # 8* 8 ?? 7 * 7
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, out_features=512),
            nn.LeakyReLU(0.1),
            nn.Linear(512, output_size)
        )
    def forward(self, images: torch.Tensor):
        return self.CNN(images)
class DecoderRNNTest(nn.Module):
    def __init__(self, embed_size, vocab_size, hidden_size, input_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size + input_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.relu = nn.ReLU()
    def forward(self, features: torch.Tensor, captions: torch.Tensor, hiden_state: tuple[torch.Tensor, torch.Tensor]): 
        # features : image_features : [bsz, embed]
        # captions : [bsz, seq]
        # hidden : [1, bsz, embed]
        # print("Input dim")
        # print(features.shape)
        # print(captions.shape)
        # print(hiden_state[0].shape)
        embeddings = self.embed(captions) # [bsz, seq, embed]
        features = features.unsqueeze(1).expand(-1, embeddings.shape[1], -1) # [bsz, seq, embed]
        combined = torch.cat((embeddings, features), dim=2) # [bsz, seq, embed*2]
        # print(combined.shape, hiden_state[0].shape, hiden_state[1].shape)
        output, hidden = self.lstm(combined, hiden_state) # [bsz, seq, hid]
        output = self.relu(output)
        output = self.linear(output) #[batch_size, seq_len, vocab_size]
        # print("End")

        return output, hidden
class EncoderCNN(nn.Module):
    def __init__(self, output_size: int):
        super(EncoderCNN, self).__init__()
        self.inception_model = models.inception_v3(pretrained=True)
        #self.inception_model.fc = torch.nn.Identity()
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1000, output_size)
        for name, param in self.inception_model.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
    def forward(self, images: torch.Tensor):
        # print("encoder_model")
        features = self.inception_model(images) #[1, 2048]
        if isinstance(features, tuple):  # Nếu là tuple
            features = features[0] 
        features = self.relu(features)
        features = self.dropout(features)
        features = self.fc(features)
        # print("encoder_output.shape :", features.shape)
        return features #[1, 256]
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, vocab_size, hidden_size, input_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size + input_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, features: torch.Tensor, captions: torch.Tensor, hiden_state: tuple[torch.Tensor, torch.Tensor]): 
        # features : image_features : [bsz, embed]
        # captions : [bsz, seq]
        # hidden : [1, bsz, embed]
        # print("Input dim")
        # print(features.shape)
        # print(captions.shape)
        # print(hiden_state[0].shape)
        embeddings = self.embed(captions) # [bsz, seq, embed]
        features = features.unsqueeze(1).expand(-1, embeddings.shape[1], -1) # [bsz, seq, embed]
        combined = torch.cat((embeddings, features), dim=2) # [bsz, seq, embed*2]
        # print(combined.shape, hiden_state[0].shape, hiden_state[1].shape)
        output, hidden = self.lstm(combined, hiden_state) # [bsz, seq, hid]
        output = self.relu(output)
        output = self.dropout(output)
        output = self.linear(output) #[batch_size, seq_len, vocab_size]
        # print("End")

        return output, hidden


In [5]:
# #Test Encoder
# encoder = EncoderCNN(256).to(device)
# encoder.eval()
# output_en = encoder(train_dataset[0][0].unsqueeze(0))
# print(output_en.shape)
# print(tokenizer.convert_ids_to_tokens(101))

In [6]:
# #test decoder
# vocab_size = tokenizer.vocab_size
# embed_size = 16
# num_layers = 1
# hidden_size = 128
# hiden = torch.zeros(1, 1, hidden_size).to(device)
# cell = torch.zeros(1, 1, hidden_size).to(device)
# hiden_state = (hiden, cell)
# hiden_state[0].shape
# decoder = DecoderRNN(
#     embed_size=embed_size, 
#     vocab_size=vocab_size, 
#     hidden_size=hidden_size, 
#     input_size=256,
#     num_layers=num_layers
# ).to(device)
# decoder.eval()
# state = torch.tensor([101, 102, 103]).unsqueeze(0).to(device)
# # print(output_en.shape)
# # print(state.shape)
# # print(hiden_state[0].shape, hiden_state[1].shape)
# outputs, hidden = decoder(output_en, state, hiden_state)
# # print(outputs.shape)
# # print(hidden[0].shape)
# # print(outputs.shape)


In [None]:
def train_single_batch(
        encoder: nn.Module,
        decoder: nn.Module,
        images: torch.Tensor,
        captions: torch.Tensor,
        encoder_optimizer: torch.optim.Optimizer,
        decoder_optimizer: torch.optim.Optimizer,
        criterion: callable,
        seq_length: int,
        hidden_size: int
):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    encoder.train()
    decoder.train()

    end_seq = torch.tensor([102])
    bsz = images.shape[0]
    # with torch.autocast(device_type=device, dtype=torch.bfloat16):
    encoder_output: torch.Tensor = encoder(images) # [bsz, 256]
    hidden = torch.zeros(1, bsz, hidden_size).to(device)
    cell = torch.zeros(1, bsz, hidden_size).to(device)
    hidden_state = (hidden, cell)
    total_loss = 0
    predicts = []
    decoder_input = torch.tensor([101]).expand(hidden.shape[1]).unsqueeze(1).to(device)
    count = 0
    for i in range(1, seq_length):
        # print("Start")
        # print(encoder_output.shape)
        # print(decoder_input.shape)
        # print(hidden_state[0].shape)
        # with torch.autocast(device_type=device, dtype=torch.bfloat16):
        decoder_output, hidden_state = decoder(encoder_output, decoder_input, hidden_state)
        decoder_output: torch.Tensor
        hidden_state: tuple[torch.Tensor, torch.Tensor]
        # print(decoder_output.shape)
        # print(hidden_state[0].shape)
        loss: torch.Tensor = criterion(decoder_output.squeeze(1), captions[:, i])
        loss.backward(retain_graph=True) # Accumulate grad
        total_loss += loss.item()
        count += 1
        predicted = decoder_output.argmax(dim=2)
        # print(predicted.shape)
        for j in range(predicted.shape[0]-1, -1, -1): # Drop data nếu nó ra end token
            if predicted[j].item() == end_seq.item():
                filter_mask = torch.zeros_like(hidden_state[0])
                # print(filter_mask.shape)
                filter_mask[:,j] = False
                hidden = hidden_state[0][filter_mask.to(torch.bool)]
                cell = hidden_state[1][filter_mask.to(torch.bool)]
                hidden_state = (hidden, cell)
                filter_mask = torch.zeros_like(captions)
                filter_mask[j] = False
                captions = captions[filter_mask.to(torch.bool)]
                filter_mask = torch.zeros_like(encoder_output)
                filter_mask[j] = False
                encoder_output = encoder_output[filter_mask.to(torch.bool)]
        if encoder_output.shape[0] == 0:
            break
        decoder_input = captions[:, i].unsqueeze(1)
    decoder_optimizer.step()
    encoder_optimizer.step()
    return total_loss/count if count != 0 else 0

In [8]:
encoder = EncoderCNN(
    output_size=256
)
decoder = DecoderRNN(
    embed_size=16,
    vocab_size=tokenizer.vocab_size,
    hidden_size=128,
    input_size=256,
    num_layers=1
)
encoder.to(device)
decoder.to(device)
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)
lossf = nn.CrossEntropyLoss(ignore_index=0)
num_epochs = 1
# print(encoder)
# print(decoder)
for epoch in tqdm.trange(num_epochs):
    i = 0
    total_loss = 0
    count = 0
    for images, captions in trainloader:
        loss = train_single_batch(
            encoder=encoder,
            decoder=decoder,
            images=images,
            captions=captions,
            encoder_optimizer=encoder_optimizer,
            decoder_optimizer=decoder_optimizer,
            criterion=lossf,
            seq_length=32,
            hidden_size=128
        )
        total_loss += loss if not math.isnan(loss) else 0
        count += 1 if not math.isnan(loss) else 0
        i+=1
        # print("Finish batch")
        print(f"{loss:.5f} | {i}/{len(trainloader)}")
    print(f"Epoch {epoch+1} | Test loss : {total_loss/count}")
    # random : log(1/30k) ~ 10.31

  0%|          | 0/1 [00:00<?, ?it/s]

nan | 1/469
nan | 2/469
10.33457 | 3/469
10.21784 | 4/469
10.08265 | 5/469
9.90005 | 6/469
9.64418 | 7/469
9.49579 | 8/469
9.29390 | 9/469
9.16287 | 10/469
8.98156 | 11/469
8.85970 | 12/469
8.68585 | 13/469
8.53561 | 14/469
8.43100 | 15/469
8.26295 | 16/469
8.14614 | 17/469
7.92569 | 18/469
7.79030 | 19/469
7.63027 | 20/469
7.43530 | 21/469
7.35094 | 22/469
7.13353 | 23/469
6.99972 | 24/469
6.83039 | 25/469
6.67528 | 26/469
6.52067 | 27/469
6.41031 | 28/469
6.25714 | 29/469
6.07256 | 30/469
5.96062 | 31/469
5.82471 | 32/469
5.62772 | 33/469
5.53043 | 34/469
5.40219 | 35/469
5.21612 | 36/469
5.13587 | 37/469
4.97039 | 38/469
4.89724 | 39/469
4.82156 | 40/469
4.64891 | 41/469
4.52030 | 42/469
4.48322 | 43/469
4.40249 | 44/469
4.28216 | 45/469
4.19617 | 46/469
4.11442 | 47/469
4.04193 | 48/469
3.98858 | 49/469
3.91588 | 50/469
3.87367 | 51/469
3.80503 | 52/469
3.74158 | 53/469
3.67746 | 54/469
3.65070 | 55/469
3.63548 | 56/469
3.60014 | 57/469
3.57552 | 58/469
3.51101 | 59/469
3.52677 | 6

100%|██████████| 1/1 [07:08<00:00, 428.06s/it]

Epoch 1 | Test loss : 3.45287829012949





In [22]:
def test_sample(
        encoder: nn.Module,
        decoder: nn.Module,
        dataset: Dataset,
        index: int,
        seq_length: int,
        hidden_size: int
):
    images, captions = dataset[index]
    images: torch.Tensor
    captions: torch.Tensor
    images = images.unsqueeze(0)
    captions = captions.unsqueeze(0)
    encoder.eval()
    decoder.eval()
    end_seq = torch.tensor([102])
    bsz = images.shape[0]
    # with torch.autocast(device_type=device, dtype=torch.bfloat16):
    encoder_output: torch.Tensor = encoder(images) # [bsz, 256]
    hidden = torch.zeros(1, bsz, hidden_size).to(device)
    cell = torch.zeros(1, bsz, hidden_size).to(device)
    hidden_state = (hidden, cell)
    total_loss = 0
    predicts = [[101] for _ in range(bsz)]
    decoder_input = torch.tensor([101]).expand(hidden.shape[1]).unsqueeze(1).to(device)
    count = 0
    for i in range(1, seq_length):
        decoder_output, hidden_state = decoder(encoder_output, decoder_input, hidden_state)
        decoder_output: torch.Tensor
        hidden_state: tuple[torch.Tensor, torch.Tensor]
        count += 1
        predicted = decoder_output.argmax(dim=2)
        for j in range(predicted.shape[0]-1, -1, -1):
            predicts[j].append(predicted[j].item())
            if predicted[j].item() == end_seq.item():
                filter_mask = torch.zeros_like(hidden_state[0])
                # print(filter_mask.shape)
                filter_mask[:,j] = False
                hidden = hidden_state[0][filter_mask.to(torch.bool)]
                cell = hidden_state[1][filter_mask.to(torch.bool)]
                hidden_state = (hidden, cell)
                filter_mask = torch.zeros_like(captions)
                filter_mask[j] = False
                captions = captions[filter_mask.to(torch.bool)]
                filter_mask = torch.zeros_like(encoder_output)
                filter_mask[j] = False
                encoder_output = encoder_output[filter_mask.to(torch.bool)]
        if encoder_output.shape[0] == 0:
            break
        decoder_input = captions[:, i].unsqueeze(1)
    return predicts, captions
predicts, captions = test_sample(encoder, decoder, train_dataset, 10, 32, 128)
# print(predicts, captions)
predicts = predicts[0]
captions = captions[0]
print(predicts)
print(captions)
detokenize = tokenizer.convert_ids_to_tokens
for i in range(min(len(predicts), captions.shape[0])):
    print(detokenize(predicts[i]), detokenize(captions[i].item()))

[101, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037, 1037]
tensor([ 101, 1037, 1048, 1045, 1056, 1056, 1048, 1041, 1043, 1045, 1054, 1048,
        1039, 1051, 1058, 1041, 1054, 1041, 1040, 1045, 1050, 1052, 1037, 1045,
        1050, 1056, 1055, 1045, 1056, 1055, 1045, 1050, 1042, 1054, 1051, 1050,
        1056, 1051, 1042, 1037, 1052, 1037, 1045, 1050, 1056,  102],
       device='cuda:0')
[CLS] [CLS]
a a
a l
a i
a t
a t
a l
a e
a g
a i
a r
a l
a c
a o
a v
a e
a r
a e
a d
a i
a n
a p
a a
a i
a n
a t
a s
a i
a t
a s
a i
a n
