In [1]:
import os
import urllib.request
import zipfile
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from nltk.tokenize import word_tokenize
from collections import Counter

In [2]:
# ===========================
# 0️⃣ Download and extract Flickr8k
# ===========================
def download_flickr8k(dataset_dir="Flickr8k"):
    """Download and unpack the Flickr8k image and caption archives.

    Both archives are extracted directly into ``dataset_dir``. An archive is
    skipped when the path it is expected to produce is already present, so
    calling this function again does not download the data a second time.

    Args:
        dataset_dir: Directory that receives the extracted files. It is
            created if it does not exist.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or ``zipfile.ZipFile`` raise
        on a failed download or an unreadable archive; the temporary zip file
        is removed before the error propagates.
    """
    os.makedirs(dataset_dir, exist_ok=True)

    # archive key -> (download URL, path inside dataset_dir that the notebook
    # reads after extraction). The archives unpack straight into dataset_dir,
    # not into a per-archive sub-folder, so these are the paths to check.
    archives = {
        "images": (
            "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip",
            "Flicker8k_Dataset",
        ),
        "captions": (
            "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip",
            "Flickr8k.token.txt",
        ),
    }

    for key, (url, expected_name) in archives.items():
        # NOTE: a partially extracted image folder also counts as present;
        # delete it by hand to force a fresh download.
        if os.path.exists(os.path.join(dataset_dir, expected_name)):
            continue

        zip_path = os.path.join(dataset_dir, f"{key}.zip")
        try:
            print(f"📥 Downloading {key} dataset...")
            urllib.request.urlretrieve(url, zip_path)

            print(f"📂 Extracting {key} dataset...")
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(dataset_dir)
        finally:
            # Remove the ZIP whether or not download/extraction succeeded.
            if os.path.exists(zip_path):
                os.remove(zip_path)

    print("✅ Dataset downloaded & extracted!")

# Fetch the dataset into the default "Flickr8k" directory.
download_flickr8k()

📥 Downloading images dataset...
📂 Extracting images dataset...
📥 Downloading captions dataset...
📂 Extracting captions dataset...
✅ Dataset downloaded & extracted!


In [3]:
import os

captions_file = "Flickr8k/Flickr8k.token.txt"
image_dir = "Flickr8k/Flicker8k_Dataset"  # Directory that holds the images
captions = {}  # image file name -> list of lower-cased captions
text = []      # every kept caption, in file order (used to train the tokenizer)

# Each useful line has the form "<image>#<n>\t<caption>".
# The encoding is given explicitly so the result does not depend on the platform default.
with open(captions_file, "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split("\t")

        # Skip blank or malformed lines instead of failing on parts[1].
        if len(parts) < 2:
            continue

        img_name = parts[0].split("#")[0]

        # Strip a trailing ".1" from the file name if present
        if img_name.endswith(".1"):
            img_name = img_name[:-2]

        # Keep only captions whose image file actually exists
        img_path = os.path.join(image_dir, img_name)
        if not os.path.exists(img_path):
            print(f"⚠️ File không tồn tại: {img_name}")  # Warn about the missing file
            continue

        caption = parts[1].lower()
        text.append(caption)
        captions.setdefault(img_name, []).append(caption)

print("Số lượng ảnh hợp lệ:", len(captions))


⚠️ File không tồn tại: 2258277193_586949ec62.jpg
⚠️ File không tồn tại: 2258277193_586949ec62.jpg
⚠️ File không tồn tại: 2258277193_586949ec62.jpg
⚠️ File không tồn tại: 2258277193_586949ec62.jpg
⚠️ File không tồn tại: 2258277193_586949ec62.jpg
Số lượng ảnh hợp lệ: 8091


In [4]:
# Total number of caption strings kept across all images.
print(len(text))

40455


In [5]:
# Inspect the list of captions stored for one image file name.
captions['209605542_ca9cc52e7b.jpg']

['a climber wearing a red headband is pulling himself up some grey rocks high above some green foliage .',
 'a man in a headband climbing a rock .',
 'a man with a red headband climbing a rock cliff looming over greenery .',
 'man climbing a sheet rock face .',
 'man in red headband climbing a rock']

In [6]:
# First caption in the flat caption list.
text[0]

'a child in a pink dress is climbing up a set of stairs in an entry way .'

In [7]:
import os
from tokenizers import Tokenizer, pre_tokenizers, trainers, models

# Trainer configuration: vocabulary size cap, minimum word frequency, and the
# special tokens that are reserved for padding and unknown words.
trainer = trainers.WordLevelTrainer(
    special_tokens=["<pad>", "<unk>"],
    min_frequency=2,
    vocab_size=10000,
)

# Word-level tokenizer; words missing from the vocabulary map to "<unk>".
tokenizer = Tokenizer(models.WordLevel(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

# Fit the vocabulary on the collected captions, then persist it to disk.
tokenizer.train_from_iterator(text, trainer)
tokenizer.save("tokenizer.json")

# Word -> id dictionary extracted from the trained tokenizer.
vocab = tokenizer.get_vocab()


def word_to_id(word):
    """Return the vocabulary id of ``word``, or the "<unk>" id if it is absent."""
    return vocab.get(word, vocab["<unk>"])

In [8]:
from transformers import PreTrainedTokenizerFast

# Wrap the saved "tokenizer.json" in the transformers fast-tokenizer interface
# and declare which tokens act as padding and unknown.
# Note: this rebinds `tokenizer`, replacing the object that was trained above.
tokenizer = PreTrainedTokenizerFast(
    pad_token="<pad>",
    unk_token="<unk>",
    tokenizer_file="tokenizer.json",
)

In [9]:
# Sanity-check the loaded tokenizer on a short sentence.
tokenizer("i go to school")

{'input_ids': [1427, 526, 21, 750], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}

In [10]:
# ===========================
# 1️⃣ Load Flickr8k dataset
# ===========================
class Flickr8kDataset(Dataset):
    """Flickr8k images paired with one randomly chosen caption per access.

    Args:
        img_dir: Directory containing the image files.
        captions: Mapping from image file name to a list of caption strings.
        transform: Optional callable applied to the RGB PIL image.
        tokenizer: Optional callable used to encode a caption. When omitted,
            the notebook-level ``tokenizer`` is looked up at access time.
        max_length: Length every encoded caption is padded/truncated to.
    """

    def __init__(self, img_dir, captions, transform=None, tokenizer=None, max_length=20):
        self.img_dir = img_dir
        self.transform = transform
        self.captions = captions
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Fixed ordering of image names so integer indices are stable.
        self.img_names = list(self.captions.keys())

    def __len__(self):
        """Number of images (not captions) in the dataset."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'caption': ...}`` for the image at ``idx``.

        ``caption`` is the ``input_ids`` row produced by the tokenizer for one
        caption drawn at random from those available for the image.
        """
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)

        # Close the file handle as soon as the pixels have been converted.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")

        if self.transform:
            image = self.transform(image)

        # Draw the caption with torch's RNG: it follows torch.manual_seed and is
        # seeded separately in each DataLoader worker, unlike NumPy's global RNG.
        options = self.captions[img_name]
        pick = int(torch.randint(len(options), (1,)).item())
        caption = options[pick]

        encode = self.tokenizer if self.tokenizer is not None else tokenizer
        encoded_caption = encode(caption, padding="max_length",
                                 truncation=True, max_length=self.max_length,
                                 return_tensors="pt")['input_ids'][0]
        return {
            'image': image,
            'caption': encoded_caption
        }

In [11]:
# Shrink every image to 8x8, convert it to a tensor, then rescale each of the
# three RGB channels with mean 0.5 / std 0.5 (the per-channel stats are spelled
# out instead of relying on a single value being broadcast).
transform = transforms.Compose([
    transforms.Resize((8, 8)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Use the same `image_dir` the caption-loading cell checked files against,
# rather than a hard-coded absolute "/content/..." path that only exists on Colab.
dataset = Flickr8kDataset(
    img_dir=image_dir,
    captions=captions,
    transform=transform
)

In [12]:
# Fetch the first dataset item: a dict with the transformed image and one
# randomly chosen, encoded caption.
sample = next(iter(dataset))

# Displaying the dict prints the whole image tensor; the shape checks in the
# following cells are the compact view.
sample

{'image': tensor([[[-0.0824, -0.3412, -0.2471, -0.2706, -0.3647, -0.4196, -0.2549,
           -0.5294],
          [-0.2863, -0.5137, -0.1451, -0.2078, -0.4118, -0.4431, -0.2157,
           -0.5373],
          [-0.0902, -0.1765, -0.1608, -0.2471, -0.3490, -0.3255, -0.1059,
           -0.4667],
          [ 0.4510,  0.1529, -0.0824, -0.2471, -0.2863, -0.3333, -0.0902,
           -0.2235],
          [ 0.2157, -0.1059, -0.1529, -0.2235, -0.2392, -0.2627, -0.2314,
           -0.3490],
          [-0.5843, -0.7333, -0.5373, -0.3412, -0.2706, -0.2078, -0.0667,
           -0.3098],
          [-0.2549, -0.4431, -0.4902, -0.4824, -0.6863, -0.5686, -0.3176,
           -0.3333],
          [ 0.3255, -0.0353, -0.3255, -0.2863,  0.0196, -0.4118, -0.5529,
           -0.1216]],
 
         [[-0.0353, -0.4275, -0.3490, -0.3882, -0.5373, -0.5373, -0.3333,
           -0.5137],
          [-0.3333, -0.5922, -0.3255, -0.3647, -0.4980, -0.4902, -0.2784,
           -0.4824],
          [-0.2627, -0.3098, -0.3961, 

In [13]:
# Shape of the transformed image tensor.
sample['image'].shape

torch.Size([3, 8, 8])

In [14]:
# Shape of the encoded caption (padded/truncated to max_length).
sample['caption'].shape

torch.Size([20])

In [15]:
# ===========================
# 2️⃣ Word Embeddings + LSTM
# ===========================
class embedding_text(nn.Module):
    """Encode a batch of token-id sequences into one vector per caption.

    Token ids are embedded, passed through a single LSTM, and the final
    hidden state of the last LSTM layer is returned.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each token embedding (LSTM input size).
        hidden_dim: Size of the LSTM hidden state (output size).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, batch_first=True)

    def forward(self, captions):
        """Return the last layer's final hidden state for ``captions``."""
        token_vectors = self.embedding(captions)
        # nn.LSTM returns (outputs, (h_n, c_n)); only h_n is needed here.
        _outputs, final_state = self.lstm(token_vectors)
        final_hidden = final_state[0]
        # Index -1 selects the last LSTM layer.
        return final_hidden[-1]

# ===========================
# 3️⃣ DCGAN Model
# ===========================

class Generator(nn.Module):
    """Map a noise vector plus a caption embedding to a 3x8x8 image.

    Args:
        latent_dim: Size of the noise vector.
        embed_dim: Size of the caption embedding.
    """

    def __init__(self, latent_dim, embed_dim):
        super().__init__()
        layers = [
            # (batch, latent_dim + embed_dim) -> (batch, 128*4*4 = 2048)
            nn.Linear(latent_dim + embed_dim, 128 * 4 * 4),
            nn.ReLU(inplace=True),
            # (batch, 2048) -> (batch, 128, 4, 4)
            nn.Unflatten(1, (128, 4, 4)),
            # (batch, 128, 4, 4) -> (batch, 64, 8, 8)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # (batch, 64, 8, 8) -> (batch, 3, 8, 8)
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
            # Squash pixel values into [-1, 1].
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, noise, caption_embed):
        """Generate images conditioned on ``caption_embed``."""
        # Concatenate along the feature axis: (batch, latent_dim + embed_dim).
        conditioned = torch.cat([noise, caption_embed], dim=1)
        return self.model(conditioned)




class Discriminator(nn.Module):
    """Score an (image, caption embedding) pair as real or generated.

    The convolutional stack flattens to 512 features, which fixes the
    expected input at 3x8x8 images.

    Args:
        embed_dim: Size of the caption embedding appended to the image features.
    """

    def __init__(self, embed_dim):
        super().__init__()
        conv_layers = [
            # (batch, 3, 8, 8) -> (batch, 64, 4, 4)
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # (batch, 64, 4, 4) -> (batch, 128, 2, 2)
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # (batch, 128, 2, 2) -> (batch, 512)
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # (batch, 512 + embed_dim) -> (batch, 1)
        self.fc = nn.Linear(512 + embed_dim, 1)

    def forward(self, img, caption_embed):
        """Return a probability in (0, 1) for each pair in the batch."""
        features = self.cnn(img)
        # Append the caption embedding to the image features.
        joint = torch.cat([features, caption_embed], dim=1)
        logits = self.fc(joint)
        return torch.sigmoid(logits)


In [16]:
import nltk

# NOTE(review): none of the cells visible here call an nltk tokenizer
# (`word_tokenize` is imported at the top but not used), so this download
# looks unnecessary; confirm before removing.
nltk.download('punkt_tab')

# Size reported by the tokenizer; used as the embedding table size when the
# text encoder is built.
len(tokenizer)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


5167

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Model
generator = Generator(64, 256).to(device)        # latent_dim=64, embed_dim=256
discriminator = Discriminator(256).to(device)    # embed_dim=256

# The encoder instance reuses the class name `embedding_text`. Resolve the
# class first so that re-running this cell builds a fresh encoder instead of
# calling the previous instance with constructor arguments (a TypeError).
text_encoder_cls = embedding_text if isinstance(embedding_text, type) else type(embedding_text)
embedding_text = text_encoder_cls(len(tokenizer), 256, 256).to(device)

# NOTE(review): only the generator and discriminator parameters are handed to
# an optimizer; the text encoder's weights are never updated by this setup.
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()

In [18]:
# Show the text encoder's layer structure.
embedding_text

embedding_text(
  (embedding): Embedding(5167, 256)
  (lstm): LSTM(256, 256, batch_first=True)
)

In [19]:
# Show the generator's layer structure.
generator

Generator(
  (model): Sequential(
    (0): Linear(in_features=320, out_features=2048, bias=True)
    (1): ReLU(inplace=True)
    (2): Unflatten(dim=1, unflattened_size=(128, 4, 4))
    (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): ConvTranspose2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): Tanh()
  )
)

In [20]:
# Show the discriminator's layer structure.
discriminator

Discriminator(
  (cnn): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Flatten(start_dim=1, end_dim=-1)
  )
  (fc): Linear(in_features=768, out_features=1, bias=True)
)

In [21]:
# Shuffled batches of 128 items from the dataset.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Pull a single batch for the shape checks below.
batch_sample = next(iter(dataloader))

# Displaying the dict prints the batch tensors themselves.
batch_sample

{'image': tensor([[[[ 0.4902,  0.6235,  0.6941,  ...,  0.5843,  0.4745,  0.4431],
           [ 0.4980,  0.6000,  0.6627,  ...,  0.5294,  0.5137,  0.4824],
           [ 0.3961,  0.2471,  0.5451,  ...,  0.7255,  0.7255,  0.6941],
           ...,
           [-0.9294, -0.8902, -0.1843,  ...,  0.0510,  0.1451,  0.1294],
           [-0.9059, -0.9294, -0.6157,  ..., -0.4745,  0.1059,  0.0667],
           [-0.6157, -0.4980, -0.6471,  ..., -0.2078,  0.2000,  0.0745]],
 
          [[ 0.8431,  0.9059,  0.9373,  ...,  0.9137,  0.8588,  0.8118],
           [ 0.8275,  0.8824,  0.9059,  ...,  0.8510,  0.8196,  0.7804],
           [ 0.6392,  0.4353,  0.6863,  ...,  0.8196,  0.8118,  0.7804],
           ...,
           [-0.8745, -0.8510, -0.2471,  ...,  0.0039,  0.1373,  0.1216],
           [-0.9216, -0.9059, -0.6000,  ..., -0.4745,  0.1137,  0.0667],
           [-0.8431, -0.5765, -0.6471,  ..., -0.2706,  0.1529,  0.0588]],
 
          [[ 0.9608,  0.9765,  0.9843,  ...,  0.9922,  0.9922,  0.9922],
    

In [22]:
# Shape of the stacked image tensors in the batch.
batch_sample['image'].shape

torch.Size([128, 3, 8, 8])

In [23]:
# Shape of the stacked encoded captions in the batch.
batch_sample['caption'].shape

torch.Size([128, 20])

In [24]:
# Run the text encoder on the sample batch's caption ids.
caption_embeddings = embedding_text(batch_sample['caption'].to(device))

# One encoder output row per caption in the batch.
caption_embeddings.shape

torch.Size([128, 256])

In [25]:
# One 64-dimensional noise vector per image in the sample batch
# (64 is the latent size the generator was constructed with).
noise = torch.randn(batch_sample['image'].size(0), 64, device=device)
fake_images = generator(noise, caption_embeddings)

# Generated images should have the same shape as the real batch.
fake_images.shape

torch.Size([128, 3, 8, 8])

In [26]:
num_epochs = 5
latent_dim = 64  # must match the latent size the generator was built with

for epoch in range(num_epochs):
    for batch in dataloader:
        # Real images for this batch.
        images = batch['image'].to(device)
        # Padded token-id sequences (not one-hot vectors). Named `caption_ids`
        # so the notebook-level `captions` dict is not overwritten.
        caption_ids = batch['caption'].to(device)
        batch_size = images.size(0)

        caption_embeddings = embedding_text(caption_ids)
        #  (batch_size, 256)
        noise = torch.randn(batch_size, latent_dim, device=device)
        #  (batch_size, 64)

        fake_images = generator(noise, caption_embeddings)
        #  (batch_size, 3, 8, 8)
        real_labels = torch.ones(batch_size, 1, device=device)
        #  (batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1, device=device)
        #  (batch_size, 1)

        # ---- Discriminator step ----
        # Both the generated images and the caption embeddings are detached so
        # this backward pass stays inside the discriminator. The gradients of
        # the discriminator's own parameters are unaffected, and the generator
        # / text-encoder graph is left intact for the generator step, so
        # `retain_graph=True` is no longer needed.
        d_caption_embeddings = caption_embeddings.detach()
        real_loss = criterion(discriminator(images, d_caption_embeddings), real_labels)
        fake_loss = criterion(discriminator(fake_images.detach(), d_caption_embeddings), fake_labels)

        d_loss = real_loss + fake_loss
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step ----
        # The generator is rewarded when the updated discriminator labels its
        # images as real.
        g_loss = criterion(discriminator(fake_images, caption_embeddings), real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    # Reports the losses of the last batch of the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")


Epoch [1/5], D Loss: 1.3159996271133423, G Loss: 0.9297963380813599
Epoch [2/5], D Loss: 1.241675853729248, G Loss: 0.8472822904586792
Epoch [3/5], D Loss: 1.1022919416427612, G Loss: 1.2151697874069214
Epoch [4/5], D Loss: 1.2709544897079468, G Loss: 0.8273096680641174
Epoch [5/5], D Loss: 1.3431909084320068, G Loss: 0.8363375067710876
