In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from dataset import FlickrDataset, MyCollate
from vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
import os
import pickle

In [None]:
def train_one_epoch(encoder, decoder, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Both modules are put in training mode. For every batch the images are
    encoded, the decoder is given the features together with the ground-truth
    caption tensor, and the loss is taken between decoder step ``t`` and caption
    token ``t + 1`` -- the first token of each caption is never a target.

    Args:
        encoder: module mapping an image batch to the features fed to ``decoder``.
        decoder: module called as ``decoder(features, captions)``; its output is
            indexed below as ``(batch, steps, vocab)``.
        loader: iterable of ``(imgs, captions)`` batches; ``captions`` is indexed
            as ``(batch, seq_len)`` token ids.
        criterion: loss taking flattened ``(N, vocab)`` logits and ``(N,)`` targets.
        optimizer: optimiser stepping the parameters being trained.
        device: device every batch is moved to before the forward pass.

    Returns:
        The unweighted mean of the per-batch losses, as a Python float.

    Raises:
        ValueError: if ``loader`` yields no batches (previously a bare
            ``ZeroDivisionError``).
    """
    encoder.train()
    decoder.train()
    total_loss = 0.0
    num_batches = 0

    for imgs, captions in loader:
        imgs, captions = imgs.to(device), captions.to(device)
        optimizer.zero_grad()

        features = encoder(imgs)
        outputs = decoder(features, captions)

        # Step t is scored against captions[:, t + 1]. Keep only the steps for
        # which both a decoder output and a "next" token exist; the clamp at 0
        # covers captions with fewer than two tokens.
        steps = max(0, min(outputs.size(1), captions.size(1) - 1))
        outputs = outputs[:, :steps, :]
        targets = captions[:, 1:steps + 1]

        loss = criterion(outputs.reshape(-1, outputs.size(2)), targets.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches actually seen instead of len(loader): works for any
    # iterable and gives a clear error instead of dividing by zero.
    if num_batches == 0:
        raise ValueError("train_one_epoch: loader yielded no batches")
    return total_loss / num_batches

In [None]:
def validate(encoder, decoder, loader, criterion, device):
    """Evaluate the model on ``loader`` without gradients; return the mean batch loss.

    Both modules are put in eval mode. The output/target alignment is the same
    as in ``train_one_epoch``: decoder step ``t`` is scored against caption
    token ``t + 1``.

    Args:
        encoder: module mapping an image batch to the features fed to ``decoder``.
        decoder: module called as ``decoder(features, captions)``; its output is
            indexed below as ``(batch, steps, vocab)``.
        loader: iterable of ``(imgs, captions)`` batches; ``captions`` is indexed
            as ``(batch, seq_len)`` token ids.
        criterion: loss taking flattened ``(N, vocab)`` logits and ``(N,)`` targets.
        device: device every batch is moved to before the forward pass.

    Returns:
        The unweighted mean of the per-batch losses, as a Python float.

    Raises:
        ValueError: if ``loader`` yields no batches (previously a bare
            ``ZeroDivisionError``).
    """
    encoder.eval()
    decoder.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for imgs, captions in loader:
            imgs, captions = imgs.to(device), captions.to(device)
            features = encoder(imgs)
            outputs = decoder(features, captions)

            # Step t is scored against captions[:, t + 1]; truncate both tensors
            # to the steps where an output and a next token both exist.
            steps = max(0, min(outputs.size(1), captions.size(1) - 1))
            outputs = outputs[:, :steps, :]
            targets = captions[:, 1:steps + 1]

            loss = criterion(outputs.reshape(-1, outputs.size(2)), targets.reshape(-1))
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validate: loader yielded no batches")
    return total_loss / num_batches

In [None]:
def _load_captions(captions_file):
    """Return every caption in ``captions_file``, lower-cased, in file order.

    Two line layouts are recognised:
      * tab-separated: ``<image-id><TAB><caption>``
      * comma-separated: ``<image>,<caption>``; an ``image,caption`` header on
        the first line is skipped
    Lines matching neither layout are ignored.
    """
    captions = []
    with open(captions_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f):
            stripped = line.strip()
            if line_no == 0 and stripped.lower().startswith('image,caption'):
                continue  # CSV header row
            # Test for a tab FIRST: a tab-separated caption may itself contain
            # commas, and splitting such a line on ',' would keep only the
            # caption fragment after its first comma.
            if '\t' in stripped:
                parts = stripped.split('\t', 1)
            elif ',' in stripped:
                # NOTE(review): a double-quoted CSV caption (if any) keeps its
                # quote characters here; confirm this matches FlickrDataset.
                parts = stripped.split(',', 1)
            else:
                continue
            if len(parts) == 2:
                captions.append(parts[1].lower())
    return captions


def main(images_path="/home/sahil_duwal/Projects/ImageCap/flickr8k/images",
         captions_file="/home/sahil_duwal/Projects/ImageCap/flickr8k/captions.txt",
         save_dir="checkpoints",
         num_epochs=10,
         batch_size=32,
         lr=3e-4,
         val_fraction=0.1,
         seed=42):
    """Train the EncoderCNN/DecoderRNN captioning model and keep the best checkpoint.

    Pickles the vocabulary to ``<save_dir>/vocab.pkl`` and splits the dataset
    into train/validation. Trains for ``num_epochs`` epochs and writes the
    encoder/decoder ``state_dict``s to ``<save_dir>/best_encoder.pth`` and
    ``<save_dir>/best_decoder.pth`` whenever the validation loss improves.

    Every argument defaults to the value that used to be hard-coded here, so a
    bare ``main()`` call still works. The difference is that the split is now
    seeded and therefore reproducible.

    Args:
        images_path: image root passed to ``FlickrDataset``. The default is an
            absolute, machine-specific path; override it on other machines.
        captions_file: captions file; read here to build the vocabulary and
            also passed to ``FlickrDataset``.
        save_dir: output directory (created if missing).
        num_epochs: number of training epochs.
        batch_size: upper bound on the DataLoader batch size.
        lr: Adam learning rate.
        val_fraction: fraction of dataset items held out for validation. At
            least one item is held out and at least one is kept for training.
        seed: seeds torch's global RNG (model init, shuffling) and the
            train/val split. Pass ``None`` to leave RNGs unseeded.

    Raises:
        ValueError: if the dataset has fewer than two items, which leaves no
            room for both a training and a validation split.
    """
    os.makedirs(save_dir, exist_ok=True)
    if seed is not None:
        torch.manual_seed(seed)

    # NOTE(review): there is no transforms.Normalize here. If EncoderCNN wraps
    # an ImageNet-pretrained backbone it likely expects ImageNet mean/std
    # normalisation -- confirm in model.py and keep inference in sync.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # NOTE(review): the vocabulary sees every caption, including those that
    # end up in the validation split.
    vocab = Vocabulary(freq_threshold=5)
    vocab.build_vocab(_load_captions(captions_file))
    with open(os.path.join(save_dir, "vocab.pkl"), 'wb') as f:
        pickle.dump(vocab, f)

    dataset = FlickrDataset(images_path, captions_file, vocab, transform=transform)
    if len(dataset) < 2:
        raise ValueError(
            f"need at least 2 dataset items for a train/val split, got {len(dataset)}")

    # NOTE(review): random_split works on dataset items. If FlickrDataset yields
    # one item per caption, captions of the same image can land in both splits
    # and make the validation loss optimistic -- verify.
    val_size = min(len(dataset) - 1, max(1, int(val_fraction * len(dataset))))
    train_size = len(dataset) - val_size
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], **split_kwargs)

    pad_idx = vocab.stoi['<PAD>']
    train_loader = DataLoader(train_dataset, batch_size=min(batch_size, len(train_dataset)),
                              shuffle=True, collate_fn=MyCollate(pad_idx))
    val_loader = DataLoader(val_dataset, batch_size=min(batch_size, len(val_dataset)),
                            shuffle=False, collate_fn=MyCollate(pad_idx))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = EncoderCNN(embed_size=256).to(device)
    decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

    # Padding positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        train_loss = train_one_epoch(encoder, decoder, train_loader, criterion, optimizer, device)
        val_loss = validate(encoder, decoder, val_loader, criterion, device)
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(encoder.state_dict(), os.path.join(save_dir, 'best_encoder.pth'))
            torch.save(decoder.state_dict(), os.path.join(save_dir, 'best_decoder.pth'))
            print("Saved new best model")

In [None]:
if __name__ == "__main__":
    # In a notebook kernel (and under `python <script>.py`) __name__ is "__main__",
    # so executing this cell starts training; importing the code as a module does not.
    main()