In [25]:
import os
import json
import random
import wandb
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.models import resnet18
from PIL import Image
from collections import Counter
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize
import nltk
import pandas as pd
import random
from tqdm import tqdm
nltk.download('punkt')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Oscar\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [26]:

# Config
wandb.init(project="Group19_Lab3_flickr30k-captioning")
config = wandb.config
config.batch_size = 64
config.embed_size = 256
config.hidden_size = 512
config.num_layers = 1
config.learning_rate = 3e-4
config.num_epochs = 5
config.max_len = 30


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


In [30]:

# --- Dataset and Preprocessing ---
class Vocabulary:
    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"} #index to string
        self.stoi = {v: k for k, v in self.itos.items()}             #string to index

    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        idx = 4
        for sentence in sentence_list:
            tokens = word_tokenize(sentence.lower())
            frequencies.update(tokens)
        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = idx
                self.itos[idx] = word
                idx += 1

    def numericalize(self, text):
        return [self.stoi.get(word, self.stoi["<UNK>"]) for word in word_tokenize(text.lower())]

class FlickrDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform, vocab, max_len=30):
        self.max_len = max_len
        self.root_dir = root_dir
        self.transform = transform
        self.vocab = vocab

        # Load CSV and group comments by image
        df = pd.read_csv(captions_file, delimiter='|', encoding='utf-8', quotechar='"', escapechar='\\')
        df.dropna(inplace=True)
        df.columns.str.strip() 
        print(df.columns.tolist())
        # Group all captions per image
        self.captions_dict = df.groupby("image_name")[' comment'].apply(list).to_dict()
        self.image_names = list(self.captions_dict.keys())

        # Build vocab using all captions
        all_captions = [caption for captions in self.captions_dict.values() for caption in captions]
        self.vocab.build_vocabulary(all_captions)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.root_dir, image_name)
        try:
            image = Image.open(image_path).convert("RGB")
        except:
            print(f"Error loading image: {image_path}, {e}")
            raise

        if self.transform:
            image = self.transform(image)

        # Sample a random caption for this image
        captions = self.captions_dict[image_name]
        caption = random.choice(captions)

        # Convert caption to tensor of word indices
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])
        # Pad the caption to max_len

        if len(numericalized_caption) < self.max_len:
            numericalized_caption += [self.vocab.stoi["<PAD>"]] * (self.max_len - len(numericalized_caption))
        else:
            numericalized_caption = numericalized_caption[:self.max_len]
        
        

        return image, torch.tensor(numericalized_caption)

def collate_fn(batch, pad_idx):
    images = []
    captions = []
    for img, cap in batch:
        images.append(img)
        captions.append(cap)
    captions = nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=pad_idx)
    return torch.stack(images), captions


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
img_dir = r"C:\Skola\D7047e\Lab3\flickr30k_images\flickr30k_images\flickr30k_images"
cap_dir = r"c:\Skola\D7047e\Lab3\flickr30k_images\flickr30k_images\results_clean.csv"


vocab = Vocabulary(freq_threshold=5)
dataset = FlickrDataset(root_dir=img_dir, captions_file=cap_dir, transform=transform, vocab=vocab)

for i in range(5):
    img, cap = dataset[i]
    print(img.shape, cap)

train_data, val_data, test_data = random_split(dataset, [.8, .1, .1])

train_loader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, vocab.stoi["<PAD>"]))
val_loader = DataLoader(val_data, batch_size=config.batch_size, shuffle=False, collate_fn=lambda x: collate_fn(x, vocab.stoi["<PAD>"]))
test_loader = DataLoader(test_data, batch_size=1, shuffle=False, collate_fn=lambda x: collate_fn(x, vocab.stoi["<PAD>"]))


['image_name', ' comment_number', ' comment']
torch.Size([3, 224, 224]) tensor([ 1,  4, 28, 17, 29, 30, 23, 31, 17, 32, 19, 20,  2,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
torch.Size([3, 224, 224]) tensor([ 1, 60, 28, 53, 61, 55, 32, 62, 63, 20,  2,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
torch.Size([3, 224, 224]) tensor([ 1, 32, 77, 78, 71, 18, 73, 83, 84, 85, 20,  2,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
torch.Size([3, 224, 224]) tensor([ 1, 32, 33, 53, 32, 94, 96, 32, 93,  2,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
torch.Size([3, 224, 224]) tensor([  1,   4,  28,  23, 104,  32, 111,  20,   2,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0])


In [None]:


# --- Encoder and Decoder ---
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        self.resnet = resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        return features

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions[:, :-1])
        features = features.unsqueeze(1)
        inputs = torch.cat((features, embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs


In [None]:

# --- Training ---
encoder = EncoderCNN(config.embed_size).to(device)
decoder = DecoderRNN(config.embed_size, config.hidden_size, len(vocab.stoi), config.num_layers).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"])
params = list(decoder.parameters()) + list(encoder.resnet.fc.parameters())

optimizer = torch.optim.Adam(params, lr=config.learning_rate)


def train():
    encoder.train()
    decoder.train()
    for epoch in range(config.num_epochs):
        total_loss = 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch}")
        print("We are on this epoch:",epoch)
        for imgs, caps in loop:
            print("Bitch")
            print("on this imgs:",imgs.size())
            imgs, caps = imgs.to(device), caps.to(device)
            features = encoder(imgs)
            outputs = decoder(features, caps)
            loss = criterion(outputs.reshape(-1, outputs.size(2)), caps[:, 1:].reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            print(total_loss)
            loop.set_postfix(loss=loss.item())

        wandb.log({"epoch": epoch, "train_loss": total_loss / len(train_loader)})
        print(f"\nEpoch {epoch}, Loss: {total_loss / len(train_loader):.4f}")
        evaluate(val_loader, encoder, decoder)


# --- Evaluation ---
def evaluate(loader, encoder, decoder):
    encoder.eval()
    decoder.eval()
    total_bleu = 0
    loop = tqdm(loader, desc="Evaluating")

    with torch.no_grad():
        for imgs, caps in loop:
            imgs = imgs.to(device)
            features = encoder(imgs)
            output_ids = generate_caption(decoder, features, vocab)

            reference = [word_tokenize(vocab.itos[idx.item()]) for idx in caps[0] if idx.item() not in [vocab.stoi["<PAD>"], vocab.stoi["<SOS>"]]]
            candidate = [vocab.itos[idx] for idx in output_ids if idx not in [vocab.stoi["<PAD>"], vocab.stoi["<SOS>"], vocab.stoi["<EOS>"]]]

            bleu = sentence_bleu([reference], candidate, weights=(0.5, 0.5))
            total_bleu += bleu
            loop.set_postfix(bleu=bleu)

    avg_bleu = total_bleu / len(loader)
    wandb.log({"val_bleu": avg_bleu})
    print(f"\nValidation BLEU score: {avg_bleu:.4f}")
    encoder.train()
    decoder.train()


def generate_caption(decoder, feature, vocab, max_len=20):
    result = []
    input = feature.unsqueeze(1)
    states = None

    for _ in tqdm(range(max_len), desc="Generating Caption", leave=False):
        hiddens, states = decoder.lstm(input, states)
        output = decoder.linear(hiddens.squeeze(1))
        predicted = output.argmax(1)
        result.append(predicted.item())
        input = decoder.embed(predicted).unsqueeze(1)
        if predicted.item() == vocab.stoi["<EOS>"]:
            break

    return result







In [None]:
print(f"Train dataset size: {len(train_loader.dataset)}")
print(f"Train loader batches: {len(train_loader)}")

Train dataset size: 25427
Train loader batches: 398


In [31]:

train()


Epoch 0:   0%|          | 0/398 [00:00<?, ?it/s]

We are on this epoch: 0
Bitch
on this imgs: torch.Size([64, 3, 224, 224])


Epoch 0:   0%|          | 0/398 [00:00<?, ?it/s]


ValueError: Expected input batch_size (1920) to match target batch_size (1856).

In [None]:
generate_caption()