Import thư viện

In [3]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as T
from Data_loader import FlickrDataset,get_data_loader
from CNN import CNN

In [4]:
import matplotlib.pyplot as plt
def show_image(img, title=None):
    """Imshow for Tensor."""
    
    #unnormalize 
    img[0] = img[0] * 0.229
    img[1] = img[1] * 0.224 
    img[2] = img[2] * 0.225 
    img[0] += 0.485 
    img[1] += 0.456 
    img[2] += 0.406
    
    img = img.numpy().transpose((1, 2, 0))
    
    
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

In [5]:
data_location =  "./Flickr8k"
BATCH_SIZE = 256
# BATCH_SIZE = 6
NUM_WORKER = 4

transforms = T.Compose([
    T.Resize(226),                     
    T.RandomCrop(224),                 
    T.ToTensor(),                               
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
])



dataset =  FlickrDataset(
    root_dir = data_location+"/Flicker8k_Dataset",
    caption_file = data_location+"./captions.txt",
    transform=transforms
)

data_loader = get_data_loader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    shuffle=True,
)

#vocab_size
vocab_size = len(dataset.vocab)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
embed_size=300
vocab_size = len(dataset.vocab)
attention_dim=256
encoder_dim=2048
decoder_dim=512
learning_rate = 0.003

In [8]:
state_dict = torch.load("./CNN.pth")
print(type(state_dict))

<class 'collections.OrderedDict'>


  state_dict = torch.load("./CNN.pth")


In [10]:
class EncoderCNN(nn.Module):
    def __init__(self):
        super(EncoderCNN, self).__init__()
        self.CNN = CNN()
        
        state_dict = torch.load("./CNN.pth", map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        self.CNN.load_state_dict(state_dict)

        for param in self.CNN.parameters():
            param.requires_grad_(False)
        
        modules = list(self.CNN.neural_net.children())[:-6]
        self.CNN = nn.Sequential(*modules)
        

    def forward(self, images):
        print("Input images shape:", images.shape)  # In kích thước đầu vào
        features = self.CNN(images)                                  
        print("Features shape after CNN:", features.shape)  # In kích thước sau khi qua CNN
        features = features.permute(0, 2, 3, 1)                           
        print("Features shape after permute:", features.shape)  # In kích thước sau khi permute
        features = features.view(features.size(0), -1, features.size(-1)) 
        print("Features shape after view:", features.shape)  # In kích thước sau khi view
        return features

In [16]:
# test encoder
# Lấy một ảnh từ tập dữ liệu Flickr8k
for idx, (image, captions) in enumerate(iter(data_loader)):
    # Chỉ lấy một ảnh và thoát khỏi vòng lặp
    image = image[0].unsqueeze(0)  # Chọn ảnh đầu tiên trong batch và thêm chiều batch
    captions = captions[0]  # Lấy caption tương ứng
    break

# Chuyển ảnh sang thiết bị GPU/CPU phù hợp
image = image.to(device)

# Tạo model EncoderCNN và chạy dự đoán
encoder = EncoderCNN().to(device)
features = encoder(image)
print(f"Feature shape: {features.shape}")

# In đặc trưng (features)
print("Extracted Features Shape:", features.shape)
print(features)




Input images shape: torch.Size([1, 3, 224, 224])
Features shape after CNN: torch.Size([1, 2048, 7, 7])
Features shape after permute: torch.Size([1, 7, 7, 2048])
Features shape after view: torch.Size([1, 49, 2048])
Feature shape: torch.Size([1, 49, 2048])
Extracted Features Shape: torch.Size([1, 49, 2048])
tensor([[[0.9515, 1.2581, 1.1516,  ..., 0.3814, 1.0883, 0.9446],
         [0.1763, 0.1483, 0.6640,  ..., 0.3639, 0.5064, 0.6546],
         [0.0000, 0.4172, 0.0000,  ..., 0.0000, 0.0000, 0.9711],
         ...,
         [0.0000, 0.6159, 0.0000,  ..., 1.4045, 0.0000, 0.9981],
         [0.0834, 1.6735, 0.3913,  ..., 0.6191, 0.0000, 0.7090],
         [0.2315, 2.1138, 0.0830,  ..., 1.1078, 0.0000, 1.3832]]],
       device='cuda:0')


In [56]:
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()

        # chuyển đổi chiều của encoder và decoder về cùng chiều với attention để tìm đặc trưng quan trọng đồng nhất
        self.attention_dim = attention_dim
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, decoder_hidden):
        att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim)
        att = torch.tanh(att1 + att2.unsqueeze(1)) # (batch_size, num_pixels, attention_dim)
        print(f"att1 shape: {att1.shape}, att2 shape: {att2.unsqueeze(1).shape}")

        attention_scores = self.full_att(att) # (batch_size, num_pixels, 1)
        attention_scores = attention_scores.squeeze(2) # (batch_size, num_pixels)

        # hệ số của trọng số chú ý
        alpha = F.softmax(attention_scores, dim=1) # (batch_size, num_pixels) pixel càng quan trọng giá trị càng cao
        
        # trọng số chú ý
        attention_weights = encoder_out * alpha.unsqueeze(2) # (batch_size, num_pixels, encoder_out)
        attention_weights = attention_weights.sum(dim=1) # (batch_size, encoder_out)

        return alpha, attention_weights

In [57]:
class DecoderRNN(nn.Module):
    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        super().__init__()
        
        #save the model param
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim
        
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.attention = Attention(encoder_dim,decoder_dim,attention_dim)

        
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  
        self.lstm_cell = nn.LSTMCell(embed_size+encoder_dim,decoder_dim,bias=True)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        
        
        self.fcn = nn.Linear(decoder_dim,vocab_size)
        self.drop = nn.Dropout(drop_prob)
        
        
    
    def forward(self, features, captions):
        
        embeds = self.embedding(captions)
        
        h, c = self.init_hidden_state(features)  
       
        seq_length = len(captions[0])-1 
        batch_size = captions.size(0)
        num_features = features.size(1)
        
        preds = torch.zeros(batch_size, seq_length, self.vocab_size).to(device)
        alphas = torch.zeros(batch_size, seq_length,num_features).to(device)
                
        for s in range(seq_length):
            alpha,context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
                    
            output = self.fcn(self.drop(h))
            
            preds[:,s] = output
            alphas[:,s] = alpha  
        
        
        return preds, alphas
    
    def generate_caption(self,features,max_len=20,vocab=None):
        
        
        
        batch_size = features.size(0)
        h, c = self.init_hidden_state(features)  
        
        alphas = []
        
        #starting input
        word = torch.tensor(vocab.stoi['<SOS>']).view(1,-1).to(device)
        embeds = self.embedding(word)

        
        captions = []
        
        for i in range(max_len):
            alpha,context = self.attention(features, h)
            
            
            #store the apla score
            alphas.append(alpha.cpu().detach().numpy())
            
            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            output = output.view(batch_size,-1)
        
            
            #select the word with most val
            predicted_word_idx = output.argmax(dim=1)
            
            #save the generated word
            captions.append(predicted_word_idx.item())
            
            #end if <EOS detected>
            if vocab.itos[predicted_word_idx.item()] == "<EOS>":
                break
            
            #send generated word as the next caption
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))
        
        #covert the vocab idx to words and return sentence
        return [vocab.itos[idx] for idx in captions],alphas
    
    
    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

In [58]:
class EncoderDecoder(nn.Module):
    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(
            embed_size=embed_size,
            vocab_size = len(dataset.vocab),
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim
        )
        
    def forward(self, images, captions):
        features = self.encoder(images)
        print(f"Feature shape: {features.shape}")
        
        outputs = self.decoder(features, captions)
        print(f"Outputs shape: {outputs.shape}")
        
        return outputs

In [59]:
model = EncoderDecoder(embed_size=300,vocab_size = len(dataset.vocab),attention_dim=256,encoder_dim=2048,decoder_dim=512).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

  state_dict = torch.load("./CNN.pth", map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))


In [60]:
def save_model(model,num_epochs):
    model_state = {
        'num_epochs':num_epochs,
        'embed_size':embed_size,
        'vocab_size':len(dataset.vocab),
        'attention_dim':attention_dim,
        'encoder_dim':encoder_dim,
        'decoder_dim':decoder_dim,
        'state_dict':model.state_dict()
    }

    torch.save(model_state,'image_caption_model.pth')

In [61]:
num_epochs = 5
print_every = 100

for epoch in range(1,num_epochs+1):   
    for idx, (image, captions) in enumerate(iter(data_loader)):
        image,captions = image.to(device),captions.to(device)

        optimizer.zero_grad()

        outputs,attentions = model(image, captions)

        targets = captions[:,1:]
        loss = criterion(outputs.view(-1, vocab_size), targets.reshape(-1))
        
        loss.backward()

        optimizer.step()

        if (idx+1)%print_every == 0:
            print("Epoch: {} loss: {:.5f}".format(epoch,loss.item()))
            
            # sinh chú thích
            model.eval()
            with torch.no_grad():
                dataiter = iter(data_loader).to(device)
                img,_ = next(dataiter).to(device)
                features = model.encoder(img[0:1].to(device))
                caps,alphas = model.decoder.generate_caption(features,vocab=dataset.vocab)
                caption = ' '.join(caps)
                show_image(img[0],title=caption)
                
            model.train()
        
    save_model(model,epoch)

Input images shape: torch.Size([256, 3, 224, 224])
Features shape after CNN: torch.Size([256, 3, 224, 224])
Features shape after permute: torch.Size([256, 224, 224, 3])
Features shape after view: torch.Size([256, 50176, 3])
Feature shape: torch.Size([256, 50176, 3])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x3 and 2048x512)