<a href="https://colab.research.google.com/github/TarnNished/deep_learning_final/blob/main/data_and_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://colab.research.google.com/github/TarnNished/deep_learning_final/blob/main/data_and_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [160]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [161]:
import os
import json
import random
from collections import Counter
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image

import matplotlib.pyplot as plt
from tqdm import tqdm


In [162]:

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

CONFIG = {
    "image_size": 224,
    "embedding_dim": 256,
    "hidden_dim": 512,
    "num_layers": 1,
    "batch_size": 32,
    "lr": 1e-3,
    "epochs": 20,
    "max_len": 30,
    "min_word_freq": 1
}


In [163]:
if torch.cuda.is_available():
    print("using gpu")

using gpu


In [164]:
DATA_ROOT = "/content/drive/MyDrive/caption_data"
IMAGE_DIR = os.path.join(DATA_ROOT, "Images")
CAPTIONS_FILE = os.path.join(DATA_ROOT, "captions.txt")
ARTIFACTS_DIR = "/content/drive/MyDrive/visual-storyteller/artifacts"


In [165]:
torch.backends.cudnn.benchmark = True

def load_captions(image_dir, captions_file):

    samples = []

    with open(captions_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # remove header if exists
    if lines[0].lower().startswith("image"):
        lines = lines[1:]

    for line in lines:
        image_name, caption = line.strip().split(",", 1)
        image_path = os.path.join(image_dir, image_name)

        if os.path.exists(image_path):
            samples.append((image_path, caption.lower()))

    return samples


In [166]:
all_samples = load_captions(IMAGE_DIR, CAPTIONS_FILE)
print("Total (image, caption) pairs:", len(all_samples))


Total (image, caption) pairs: 40455


In [167]:

random.shuffle(all_samples)
train_end = int(0.8 * len(all_samples))
val_end = int(0.9 * len(all_samples))
train_data = all_samples[:train_end]
val_data = all_samples[train_end:val_end]
test_data = all_samples[val_end:]


In [168]:

SPECIAL_TOKENS = {"<pad>":0,"<bos>":1,"<eos>":2,"<unk>":3}

def build_vocab(captions, min_freq):
    counter = Counter()
    for c in captions:
        counter.update(c.split())
    vocab = dict(SPECIAL_TOKENS)
    idx = len(vocab)
    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab

vocab = build_vocab([c for _, c in train_data], CONFIG["min_word_freq"])
ivocab = {i:w for w,i in vocab.items()}
vocab_size = len(vocab)


In [169]:

def encode_caption(caption, vocab, max_len):
    tokens = caption.split()
    encoded = [vocab.get(w, vocab["<unk>"]) for w in tokens]
    encoded = [vocab["<bos>"]] + encoded + [vocab["<eos>"]]
    encoded = encoded[:max_len]
    encoded += [vocab["<pad>"]] * (max_len - len(encoded))
    return torch.tensor(encoded)


In [170]:

class CaptionDataset(Dataset):
    def __init__(self, data, vocab, transform):
        self.data = data
        self.vocab = vocab
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, caption = self.data[idx]
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        cap = encode_caption(caption, self.vocab, CONFIG["max_len"])
        return img, cap


In [171]:

transform = transforms.Compose([
    transforms.Resize((CONFIG["image_size"], CONFIG["image_size"])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

train_loader = DataLoader(
    CaptionDataset(train_data, vocab, transform),
    batch_size=32,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True
)

val_loader = DataLoader(
    CaptionDataset(val_data, vocab, transform),
    batch_size=32,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True
)

In [172]:

class Encoder(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        resnet = models.resnet18(pretrained=True)

        # Keep spatial map: (B, 512, 7, 7)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        self.conv_proj = nn.Conv2d(512, embed_dim, kernel_size=1)

    def forward(self, images):
        """
        images: (B, 3, 224, 224)
        returns: (B, 49, embed_dim)
        """
        feats = self.backbone(images)         # (B, 512, 7, 7)
        feats = self.conv_proj(feats)         # (B, E, 7, 7)

        feats = feats.flatten(2)              # (B, E, 49)
        feats = feats.permute(0, 2, 1)        # (B, 49, E)

        return feats


In [173]:
class Attention(nn.Module):
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.encoder_att = nn.Linear(embed_dim, hidden_dim)
        self.decoder_att = nn.Linear(hidden_dim, hidden_dim)
        self.full_att = nn.Linear(hidden_dim, 1)

    def forward(self, encoder_feats, hidden_state):
        """
        encoder_feats: (B, 49, E)
        hidden_state: (B, H)
        """
        att1 = self.encoder_att(encoder_feats)           # (B, 49, H)
        att2 = self.decoder_att(hidden_state).unsqueeze(1)  # (B, 1, H)

        scores = self.full_att(torch.tanh(att1 + att2))  # (B, 49, 1)
        alpha = torch.softmax(scores, dim=1)             # (B, 49, 1)

        context = (encoder_feats * alpha).sum(dim=1)     # (B, E)
        return context, alpha


In [174]:
class DecoderWithAttention(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.attention = Attention(embed_dim, hidden_dim)

        self.lstm = nn.LSTMCell(embed_dim * 2, hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size)

        self.init_h = nn.Linear(embed_dim, hidden_dim)
        self.init_c = nn.Linear(embed_dim, hidden_dim)

    def init_hidden(self, encoder_feats):
        mean_feats = encoder_feats.mean(dim=1)
        h = self.init_h(mean_feats)
        c = self.init_c(mean_feats)
        return h, c

    def forward(self, encoder_feats, captions):
        """
        encoder_feats: (B, 49, E)
        captions: (B, T)
        """
        B, T = captions.shape
        embeddings = self.embedding(captions)

        h, c = self.init_hidden(encoder_feats)

        outputs = []

        for t in range(T - 1):
            context, _ = self.attention(encoder_feats, h)
            lstm_input = torch.cat([embeddings[:, t], context], dim=1)

            h, c = self.lstm(lstm_input, (h, c))
            outputs.append(self.fc(h))

        return torch.stack(outputs, dim=1)


In [175]:

encoder = Encoder(CONFIG["embedding_dim"]).to(DEVICE)
decoder = DecoderWithAttention(
    vocab_size,
    CONFIG["embedding_dim"],
    CONFIG["hidden_dim"]
).to(DEVICE)

# Freeze all CNN layers first
for p in encoder.backbone.parameters():
    p.requires_grad = False

# Unfreeze last ResNet blocks (fine-tuning)
for p in encoder.backbone[-1:].parameters():
    p.requires_grad = True

criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=1e-4
)


In [176]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("Current device:", torch.cuda.current_device())
print("Device name:", torch.cuda.get_device_name(0))

CUDA available: True
Current device: 0
Device name: Tesla T4


In [177]:
def evaluate(loader, encoder, decoder):
    encoder.eval()
    decoder.eval()
    total_loss = 0

    with torch.no_grad():
        for images, captions in loader:
            images, captions = images.to(DEVICE), captions.to(DEVICE)

            feats = encoder(images)
            outputs = decoder(feats, captions)

            loss = criterion(
                outputs.reshape(-1, vocab_size),
                captions[:, 1:].reshape(-1)
            )
            total_loss += loss.item()

    return total_loss / len(loader)




# Training loop with validation
train_losses = []
val_losses = []

for epoch in range(CONFIG["epochs"]):
    encoder.backbone.eval()
    encoder.backbone[-1:].train()
    decoder.train()

    total_loss = 0
    for images, captions in tqdm(train_loader):
        images, captions = images.to(DEVICE), captions.to(DEVICE)

        optimizer.zero_grad()

        feats = encoder(images)
        outputs = decoder(feats, captions)

        loss = criterion(
            outputs.reshape(-1, vocab_size),
            captions[:, 1:].reshape(-1)
        )

        loss.backward()
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), 5.0)
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch {epoch+1}: Training loss = {avg_train_loss:.4f}")

    # Validation step: evaluate model on the validation set
    val_loss = evaluate(val_loader, encoder, decoder)  # Use val_loader for validation
    val_losses.append(val_loss)
    print(f"Epoch {epoch+1}: Validation loss = {val_loss:.4f}")


100%|██████████| 1012/1012 [04:31<00:00,  3.72it/s]

Epoch 1: Training loss = 4.4739





Epoch 1: Validation loss = 3.8180


100%|██████████| 1012/1012 [04:32<00:00,  3.71it/s]


Epoch 2: Training loss = 3.5740
Epoch 2: Validation loss = 3.4013


100%|██████████| 1012/1012 [04:31<00:00,  3.72it/s]


Epoch 3: Training loss = 3.2001
Epoch 3: Validation loss = 3.1608


100%|██████████| 1012/1012 [04:31<00:00,  3.73it/s]


Epoch 4: Training loss = 2.9455
Epoch 4: Validation loss = 3.0045


100%|██████████| 1012/1012 [04:33<00:00,  3.70it/s]


Epoch 5: Training loss = 2.7475
Epoch 5: Validation loss = 2.9024


100%|██████████| 1012/1012 [04:26<00:00,  3.80it/s]


Epoch 6: Training loss = 2.5833
Epoch 6: Validation loss = 2.8281


100%|██████████| 1012/1012 [04:28<00:00,  3.77it/s]


Epoch 7: Training loss = 2.4411
Epoch 7: Validation loss = 2.7790


100%|██████████| 1012/1012 [04:26<00:00,  3.80it/s]


Epoch 8: Training loss = 2.3129
Epoch 8: Validation loss = 2.7454


100%|██████████| 1012/1012 [04:25<00:00,  3.81it/s]


Epoch 9: Training loss = 2.1943
Epoch 9: Validation loss = 2.7126


100%|██████████| 1012/1012 [04:26<00:00,  3.80it/s]


Epoch 10: Training loss = 2.0844
Epoch 10: Validation loss = 2.7002


100%|██████████| 1012/1012 [04:26<00:00,  3.79it/s]


Epoch 11: Training loss = 1.9814
Epoch 11: Validation loss = 2.6992


100%|██████████| 1012/1012 [04:25<00:00,  3.81it/s]


Epoch 12: Training loss = 1.8822
Epoch 12: Validation loss = 2.6958


100%|██████████| 1012/1012 [04:31<00:00,  3.73it/s]


Epoch 13: Training loss = 1.7872
Epoch 13: Validation loss = 2.7060


100%|██████████| 1012/1012 [04:29<00:00,  3.76it/s]


Epoch 14: Training loss = 1.6979
Epoch 14: Validation loss = 2.7139


100%|██████████| 1012/1012 [04:30<00:00,  3.74it/s]


Epoch 15: Training loss = 1.6116
Epoch 15: Validation loss = 2.7259


100%|██████████| 1012/1012 [04:26<00:00,  3.79it/s]


Epoch 16: Training loss = 1.5297
Epoch 16: Validation loss = 2.7569


100%|██████████| 1012/1012 [04:25<00:00,  3.81it/s]


Epoch 17: Training loss = 1.4498
Epoch 17: Validation loss = 2.7741


100%|██████████| 1012/1012 [04:24<00:00,  3.82it/s]


Epoch 18: Training loss = 1.3731
Epoch 18: Validation loss = 2.7888


100%|██████████| 1012/1012 [04:32<00:00,  3.71it/s]


Epoch 19: Training loss = 1.3011
Epoch 19: Validation loss = 2.8200


100%|██████████| 1012/1012 [04:35<00:00,  3.67it/s]


Epoch 20: Training loss = 1.2324
Epoch 20: Validation loss = 2.8482


In [178]:

os.makedirs("artifacts", exist_ok=True)
torch.save({"encoder":encoder.state_dict(),
            "decoder":decoder.state_dict()}, "artifacts/model.pt")
json.dump(vocab, open("artifacts/vocab.json","w"))
json.dump(CONFIG, open("artifacts/config.json","w"))
print("Saved artifacts")


Saved artifacts
