<a href="https://colab.research.google.com/github/Maruf-16203091/Image-Caption-Generator/blob/main/image_caption_generate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import os
import nltk
from collections import Counter
import torchvision.models as models
import torch.nn as nn
import torch
import shutil
nltk.download('punkt')
nltk.download('punkt_tab')
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
import torch.optim as optim
shutil.rmtree('/root/nltk_data/tokenizers/punkt', ignore_errors=True)





[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [18]:
# 1. Mount drive
from google.colab import drive
drive.mount('/content/drive')

# 2. Define transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# 3. Load and process captions
captions_file = '/content/drive/My Drive/ML_Datasets/archive/captions.txt'
captions_dict = {}

with open(captions_file, 'r') as f:
    next(f)  # Skip header
    for line in f:
        line = line.strip()
        if not line:
            continue
        parts = line.split(',', 1)
        if len(parts) < 2:
            continue
        image_id, caption = parts[0].strip(), parts[1].replace('\t', ' ').strip().lower()
        captions_dict.setdefault(image_id, []).append(caption)

# 4. Build vocabulary
all_captions = []
for caps in captions_dict.values():
    all_captions.extend(caps)

tokenized_captions = [nltk.tokenize.word_tokenize(c) for c in all_captions]
word_freq = Counter()
for tokens in tokenized_captions:
    word_freq.update(tokens)

threshold = 2
vocab = [word for word, count in word_freq.items() if count >= threshold]
special_tokens = ['<pad>', '<start>', '<end>', '<unk>']
vocab = special_tokens + vocab

word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}
print(f"Vocabulary size: {len(vocab)}")

# 5. Define dataset class
class FlickrDataset(Dataset):
    def __init__(self, root_dir, captions_dict, word2idx, transform=None, max_len=20):
        self.root_dir = root_dir
        self.word2idx = word2idx
        self.transform = transform
        self.max_len = max_len

        self.items = []
        for img_id, captions in captions_dict.items():
            img_path = os.path.join(root_dir, img_id)
            if os.path.exists(img_path):
                for caption in captions:
                    self.items.append((img_id, caption))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        img_id, caption = self.items[idx]
        img_path = os.path.join(self.root_dir, img_id)

        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        tokens = nltk.tokenize.word_tokenize(caption.lower())
        caption_idx = [self.word2idx['<start>']] + \
                      [self.word2idx.get(token, self.word2idx['<unk>']) for token in tokens] + \
                      [self.word2idx['<end>']]

        if len(caption_idx) < self.max_len:
            caption_idx += [self.word2idx['<pad>']] * (self.max_len - len(caption_idx))
        else:
            caption_idx = caption_idx[:self.max_len]

        return image, torch.tensor(caption_idx)

# 6. Set the correct image directory
image_root_dir = '/content/drive/My Drive/ML_Datasets/archive/images'

# 7. Create dataset AFTER captions_dict and vocab are ready
dataset = FlickrDataset(
    root_dir=image_root_dir,
    captions_dict=captions_dict,
    word2idx=word2idx,
    transform=transform
)

# 8. Confirm dataset size
print(f"Dataset size: {len(dataset)}")  # should be > 0

# 9. Create dataloader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Vocabulary size: 5241
Dataset size: 40455


In [19]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()


        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)

        # Remove the last classification (fc) layer
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        # New fully connected layer for embedding
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)  # [batch, 2048, 1, 1]
        features = features.view(features.size(0), -1)  # [batch, 2048]
        features = self.linear(features)               # [batch, embed_size]
        features = self.bn(features)                   # [batch, embed_size]
        return features


In [20]:
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions[:, :-1])  # Exclude <end> token
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        lstm_out, _ = self.lstm(inputs)
        outputs = self.linear(lstm_out)
        return outputs

    def sample(self, features, max_len=20):
        sampled_ids = []
        inputs = features.unsqueeze(1)
        states = None

        for _ in range(max_len):
            lstm_out, states = self.lstm(inputs, states)
            outputs = self.linear(lstm_out.squeeze(1))
            _, predicted = outputs.max(1)
            sampled_ids.append(predicted.item())
            inputs = self.embed(predicted)
            inputs = inputs.unsqueeze(1)
            if predicted == word2idx['<end>']:
                break
        return sampled_ids


In [21]:
embed_size = 256
hidden_size = 512
vocab_size = len(vocab)
num_epochs = 10
learning_rate = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size).to(device)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 193MB/s]


In [22]:
from torch.nn.functional import softmax

num_epochs = 10
learning_rate = 1e-3
criterion = torch.nn.CrossEntropyLoss(ignore_index=word2idx['<pad>'])
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0

    for images, captions in dataloader:
        images, captions = images.to(device), captions.to(device)

        optimizer.zero_grad()
        features = encoder(images)
        outputs = decoder(features, captions)

        # Align shapes
        outputs = outputs[:, :-1, :].reshape(-1, vocab_size)
        targets = captions[:, 1:].reshape(-1)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # ----- Accuracy Calculation -----
        _, predicted = torch.max(outputs, dim=1)
        mask = targets != word2idx['<pad>']  # Ignore <pad> positions
        correct = (predicted == targets) & mask
        total_correct += correct.sum().item()
        total_tokens += mask.sum().item()

    accuracy = 100 * total_correct / total_tokens
    avg_loss = total_loss / len(dataloader)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")


Epoch [1/10], Loss: 4.0577, Accuracy: 27.39%
Epoch [2/10], Loss: 3.4304, Accuracy: 31.95%
Epoch [3/10], Loss: 3.1435, Accuracy: 34.18%
Epoch [4/10], Loss: 2.9102, Accuracy: 36.25%
Epoch [5/10], Loss: 2.6997, Accuracy: 38.73%
Epoch [6/10], Loss: 2.5087, Accuracy: 41.48%
Epoch [7/10], Loss: 2.3347, Accuracy: 44.31%
Epoch [8/10], Loss: 2.1832, Accuracy: 47.15%
Epoch [9/10], Loss: 2.0477, Accuracy: 49.81%
Epoch [10/10], Loss: 1.9274, Accuracy: 52.28%


In [30]:
def generate_caption(image_path):
    encoder.eval()
    decoder.eval()

    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        feature = encoder(image)
        sampled_ids = decoder.sample(feature)

    sampled_caption = []
    for word_id in sampled_ids:
        word = idx2word[word_id]
        if word == '<end>':
            break
        sampled_caption.append(word)

    sentence = ' '.join(sampled_caption)
    return sentence

# Example
print(generate_caption('/content/drive/My Drive/ML_Datasets/archive/images/72964268_d532bb8ec7.jpg'))


a in <unk>
