Dùng tạm mô hình tự code để cho nó đúng đầu vào đầu ra đã

In [None]:
import numpy as np
from numpy import array
import pandas as pd
import matplotlib.pyplot as plt
import string
import os
from PIL import Image
import glob
from pickle import dump, load

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
BATCH_SIZE = 256 # 3.8GB VRAM f32 (3.6 Dedicated + 0.2 Shared)
device = 'cuda'


In [2]:
image_transforms = transforms.Compose([
    transforms.Resize((324, 324)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Dùng cho ImageNet
])

In [3]:
import util
data_folder_path = "Flickr8k/Flicker8k_Dataset"
train_dataset = util.Flickr8kDataset(
    data_folder_path,
    "Data_bert/train_set_bert.pkl",
    image_transforms,
    device=device
)
trainloader = train_dataset.get_dataloader(batch_size=BATCH_SIZE, num_workers=2, shuffle=False)
test_dataset = util.Flickr8kDataset(
    data_folder_path,
    "Data_bert/test_set_bert.pkl",
    image_transforms,
    device=device
)
testloader = test_dataset.get_dataloader(batch_size=BATCH_SIZE, num_workers=2, shuffle=False)
print(len(train_dataset)) # :)) sao có tận 30k ids ảnh trong train_set_bert.pkl, một ảnh có nhiều caption à ?
print(len(test_dataset)) 
sample_image, sample_caption = test_dataset[0]
print(sample_image.shape, sample_caption.shape)

6000
1000
30000
5000
torch.Size([3, 324, 324]) torch.Size([46])


In [4]:
class EncoderCNN(nn.Module):
    def __init__(self):
        super(EncoderCNN, self).__init__()
        self.CNN = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, stride=3), # 324 * 324
            nn.LeakyReLU(0.1),
            nn.Conv2d(8, 16, kernel_size=3, stride=3), # 108 * 108
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 32, kernel_size=3, stride=1), # 36 * 36
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=3, stride=3), # 12 * 12
            nn.Conv2d(32, 64, kernel_size=3, stride=1), # 10 * 10
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=3, stride=1), # 8* 8 ?? 7 * 7
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, out_features=512),
            nn.LeakyReLU(0.1),
            nn.Linear(512, 256)
        )
    def forward(self, images: torch.Tensor):
        return self.CNN(images)
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, vocab_size, hidden_size, input_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size + input_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.relu = nn.ReLU()
    def forward(self, features: torch.Tensor, captions: torch.Tensor, hiden_state: tuple[torch.Tensor, torch.Tensor]): 
        # features : image_features : [bsz, embed]
        # captions : [bsz, seq]
        # hidden : [1, bsz, embed]
        # print("Input dim")
        # print(features.shape)
        # print(captions.shape)
        # print(hiden_state[0].shape)
        embeddings = self.embed(captions) # [bsz, seq, embed]
        features = features.unsqueeze(1).expand(-1, embeddings.shape[1], -1) # [bsz, seq, embed]
        combined = torch.cat((embeddings, features), dim=2) # [bsz, seq, embed*2]
        # print(combined.shape, hiden_state[0].shape, hiden_state[1].shape)
        output, hidden = self.lstm(combined, hiden_state) # [bsz, seq, hid]
        output = self.relu(output)
        output = self.linear(output) #[batch_size, seq_len, vocab_size]
        # print("End")

        return output, hidden

In [5]:
# #Test Encoder
# encoder = EncoderCNN()
# encoder.eval()
# output_en = encoder(train_dataset[0][0].unsqueeze(0))
# print(output_en.shape)
# print(tokenizer.convert_ids_to_tokens(101))

In [6]:
# #test decoder
# vocab_size = tokenizer.vocab_size
# embed_size = 16
# num_layers = 1
# hidden_size = 128
# hiden = torch.zeros(1, 1, hidden_size)
# cell = torch.zeros(1, 1, hidden_size)
# hiden_state = (hiden, cell)
# hiden_state[0].shape
# decoder = DecoderRNN(embed_size, vocab_size, hidden_size, 256, num_layers)
# decoder.eval()
# state = torch.tensor([101, 102, 103]).unsqueeze(0)
# # print(output_en.shape)
# # print(state.shape)
# # print(hiden_state[0].shape, hiden_state[1].shape)
# outputs, hidden = decoder(output_en, state, hiden_state)
# # print(outputs.shape)
# # print(hidden[0].shape)
# # print(outputs.shape)


In [7]:
def train_single_batch(
        encoder: nn.Module,
        decoder: nn.Module,
        images: torch.Tensor,
        captions: torch.Tensor,
        encoder_optimizer: torch.optim.Optimizer,
        decoder_optimizer: torch.optim.Optimizer,
        criterion: callable,
        seq_length: int,
        hidden_size: int
):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    encoder.train()
    decoder.train()

    end_seq = torch.tensor([102])
    bsz = images.shape[0]
    encoder_output: torch.Tensor = encoder(images) # [bsz, 256]
    hidden = torch.zeros(1, bsz, hidden_size).to(device)
    cell = torch.zeros(1, bsz, hidden_size).to(device)
    hidden_state = (hidden, cell)
    total_loss = 0
    predicts = []
    decoder_input = torch.tensor([101]).expand(hidden.shape[1]).unsqueeze(1).to(device)
    for i in range(1, seq_length):
        # print("Start")
        # print(encoder_output.shape)
        # print(decoder_input.shape)
        # print(hidden_state[0].shape)
        decoder_output, hidden_state = decoder(encoder_output, decoder_input, hidden_state)
        decoder_output: torch.Tensor
        hidden_state: tuple[torch.Tensor, torch.Tensor]
        # print(decoder_output.shape)
        # print(hidden_state[0].shape)
        loss: torch.Tensor = criterion(decoder_output.squeeze(1), captions[:, i])
        loss.backward(retain_graph=True) # Acculumate grad
        total_loss += loss.item()
        predicted = decoder_output.argmax(dim=2)
        # print(predicted.shape)
        for j in range(predicted.shape[0]-1, -1, -1):
            if predicted[j].item() == end_seq.item():
                filter_mask = torch.zeros_like(hidden_state[0])
                filter_mask[j] = False
                hidden_state[0] = hidden_state[0][filter_mask]
                hidden_state[1] = hidden_state[1][filter_mask]
                filter_mask = torch.zeros_like(captions)
                filter_mask[j] = False
                captions = captions[filter_mask]
                filter_mask = torch.zeros_like(encoder_output)
                filter_mask[j] = False
                encoder_output = encoder_output[filter_mask]
        if encoder_output.shape[0] == 0:
            break
        decoder_input = captions[:, i].unsqueeze(1)
    decoder_optimizer.step()
    encoder_optimizer.step()
    return total_loss

In [8]:
encoder = EncoderCNN()
decoder = DecoderRNN(
    embed_size=16,
    vocab_size=tokenizer.vocab_size,
    hidden_size=128,
    input_size=256,
    num_layers=1
)
encoder.to(device)
decoder.to(device)
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)
lossf = nn.CrossEntropyLoss(ignore_index=0)
i = 0
total_loss = 0
for images, captions in trainloader:
    loss = train_single_batch(
        encoder=encoder,
        decoder=decoder,
        images=images,
        captions=captions,
        encoder_optimizer=encoder_optimizer,
        decoder_optimizer=decoder_optimizer,
        criterion=lossf,
        seq_length=32,
        hidden_size=128
    )
    total_loss += loss
    i+=1
    print(f"{loss:.5f} | {i}/{len(trainloader)}")

nan | 1/118
nan | 2/118
319.59888 | 3/118
319.06257 | 4/118
317.68958 | 5/118
316.02462 | 6/118
310.33024 | 7/118
304.15399 | 8/118
299.97501 | 9/118
290.47260 | 10/118
282.54886 | 11/118
275.94316 | 12/118
271.06742 | 13/118
265.59258 | 14/118
261.00267 | 15/118
256.44065 | 16/118
251.41568 | 17/118
247.18878 | 18/118
241.80800 | 19/118
237.06761 | 20/118
231.75127 | 21/118
227.21666 | 22/118
222.62193 | 23/118
217.90091 | 24/118
212.78600 | 25/118
208.03147 | 26/118
203.09616 | 27/118
198.04580 | 28/118
193.55909 | 29/118
188.89791 | 30/118
183.91577 | 31/118
179.36756 | 32/118
174.40760 | 33/118
169.84485 | 34/118
165.42201 | 35/118
160.90107 | 36/118
156.90531 | 37/118
152.87543 | 38/118
148.69896 | 39/118
144.85155 | 40/118
141.03374 | 41/118
137.75021 | 42/118
134.26429 | 43/118
130.65340 | 44/118
128.16868 | 45/118
124.79240 | 46/118
122.57115 | 47/118
119.82994 | 48/118
117.34057 | 49/118
115.77596 | 50/118
113.39164 | 51/118
112.13243 | 52/118
110.22292 | 53/118
108.74863 | 54