In [1]:
# 1. Install required libraries (run once)
!pip install torch torchvision nltk pandas pillow tqdm





[notice] A new release of pip is available: 25.0.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
# 2. Imports and NLTK setup
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from collections import Counter
import nltk
from nltk.tokenize import word_tokenize
from tqdm import tqdm

# Download punkt tokenizer data
nltk.download('punkt_tab')



[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\visio\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [3]:
class Vocabulary:
    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        idx = 4
        for sentence in sentence_list:
            if isinstance(sentence, str):
                tokens = word_tokenize(sentence.lower())
                frequencies.update(tokens)

        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.stoi[word] = idx
                self.itos[idx] = word
                idx += 1

    def numericalize(self, text):
        tokenized = word_tokenize(str(text).lower())
        return [
            self.stoi.get(token, self.stoi["<UNK>"])
            for token in tokenized
        ]


In [4]:
class Flickr8kDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None):
        self.df = pd.read_csv(captions_file, header=None, names=["image", "caption"])
        self.root_dir = root_dir
        self.transform = transform
        self.images = self.df["image"]
        self.captions = self.df["caption"]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        caption = self.captions[index]
        image_id = self.images[index]
        image = Image.open(os.path.join(self.root_dir, image_id)).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, caption


In [37]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        return self.fc(features)

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.linear(hiddens)
        return outputs
    def sample(self, features, max_len=20):
        "Generate a caption for an image feature vector."
        output_ids = []
        states = None
        inputs = features.unsqueeze(1)  # Add time dimension: (batch, 1, embed)

        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)  # (batch, 1, hidden)
            outputs = self.linear(hiddens.squeeze(1))    # (batch, vocab_size)
            predicted = outputs.argmax(1)                # (batch,)
            output_ids.append(predicted.item())

            # Stop if <end> token is generated
            #if predicted.item() == vocab.stoi["<end>"]:
             #   break

            # Prepare input for next step
            inputs = self.embed(predicted).unsqueeze(1)

        return output_ids

class ImageCaptioningModel(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

    def forward(self, images, captions):
        features = self.encoder(images)
        return self.decoder(features, captions)


In [6]:
def train(model, loader, criterion, optimizer, device, num_epochs=25):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for imgs, caps in tqdm(loader):
            imgs, caps = imgs.to(device), caps.to(device)
            outputs = model(imgs, caps[:, :-1])      # outputs shape: (batch, T+1, vocab)
            outputs = outputs[:, 1:, :]              # drop the image-aligned output
            targets = caps[:, 1:]                    # true next-word tokens

            # compute loss
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]),
                targets.reshape(-1)
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(loader):.4f}")


In [7]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

image_dir = "Flickr8k_dataset/Images"
captions_file = "captions.txt"

dataset = Flickr8kDataset(image_dir, captions_file, transform)
vocab = Vocabulary(freq_threshold=5)
vocab.build_vocabulary(dataset.df["caption"])

def collate_fn(batch):
    images, captions = zip(*batch)
    images = torch.stack(images, 0)
    targets = []
    for cap in captions:
        tokens = [vocab.stoi["<SOS>"]] + vocab.numericalize(cap) + [vocab.stoi["<EOS>"]]
        targets.append(torch.tensor(tokens))
    targets = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=vocab.stoi["<PAD>"])
    return images, targets

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ImageCaptioningModel(embed_size=256, hidden_size=512, vocab_size=len(vocab.stoi)).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"])
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)




In [8]:
train(model, dataloader, criterion, optimizer, device, num_epochs=25)


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [2:00:12<00:00,  5.70s/it]


Epoch 1/25, Loss: 3.6479


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [2:00:23<00:00,  5.71s/it]


Epoch 2/25, Loss: 2.9661


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [1:58:31<00:00,  5.62s/it]


Epoch 3/25, Loss: 2.7070


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [2:00:33<00:00,  5.72s/it]


Epoch 4/25, Loss: 2.5394


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [1:57:23<00:00,  5.57s/it]


Epoch 5/25, Loss: 2.4126


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [2:00:06<00:00,  5.70s/it]


Epoch 6/25, Loss: 2.3060


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [2:00:16<00:00,  5.70s/it]


Epoch 7/25, Loss: 2.2130


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [1:54:43<00:00,  5.44s/it]


Epoch 8/25, Loss: 2.1293


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [1:53:05<00:00,  5.36s/it]


Epoch 9/25, Loss: 2.0512


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [2:00:25<00:00,  5.71s/it]


Epoch 10/25, Loss: 1.9766


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [2:00:42<00:00,  5.73s/it]


Epoch 11/25, Loss: 1.9061


100%|████████████████████████████████████████████████████████████████████████████| 1265/1265 [2:00:39<00:00,  5.72s/it]


Epoch 12/25, Loss: 1.8403


 83%|██████████████████████████████████████████████████████████████▋             | 1044/1265 [1:39:50<21:08,  5.74s/it]


KeyboardInterrupt: 

In [9]:
torch.save(model.state_dict(), 'caption_model_epoch11.pth')


In [13]:
import pickle

with open('vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)


In [38]:
# Hyperparameters (must match training setup)
embed_size = 256
hidden_size = 512
vocab_size = len(vocab.stoi)

# Recreate and load model
model = ImageCaptioningModel(embed_size, hidden_size, vocab_size).to(device)
model.load_state_dict(torch.load("caption_model_epoch11.pth"))
model.eval()


ImageCaptioningModel(
  (encoder): EncoderCNN(
    (resnet): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=T

In [39]:
from PIL import Image
import torchvision.transforms as transforms
import torch

def generate_caption(image_path, model, vocab, transform, device, max_len=50):
    model.eval()
    
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)  # Shape: (1, 3, H, W)

    with torch.no_grad():
        features = model.encoder(image)
        sampled_ids = model.decoder.sample(features, max_len=max_len)
    
    # Convert word IDs to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.itos[word_id]
        if word == "<EOS>":
            break
        if word not in ("<SOS>", "<PAD>"):
            sampled_caption.append(word)
    
    return " ".join(sampled_caption)


In [40]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), 
                         (0.229, 0.224, 0.225))
])


In [46]:
image_path = "dog.jpg"  # Change to your image path
caption = generate_caption(image_path, model, vocab, transform, device)
print("Generated Caption:", caption)


Generated Caption: a brown dog is running through a field of green grass .
