In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torchvision.models as models
import torch.nn as nn
import torch

# **Encoder**

In [None]:
import torchvision.models as models
import torch.nn as nn
import torch

class EncoderCNN(nn.Module):
    """Image encoder: a frozen pretrained ResNet-50 followed by a trainable linear projection.

    Args:
        encoded_image_size: Unused by this implementation; accepted only so
            existing call sites keep working.
        embed_size: Length of the feature vector returned per image.
            Defaults to 256, the value that used to be hard-coded.
    """

    def __init__(self, encoded_image_size=14, embed_size=256):
        super(EncoderCNN, self).__init__()
        # NOTE(review): recent torchvision releases prefer `weights=` over
        # `pretrained=`; kept as-is for compatibility - confirm installed version.
        resnet = models.resnet50(pretrained=True)
        modules = list(resnet.children())[:-1]  # Remove final FC layer
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)  # Project features to embedding size
        # forward() runs the backbone under no_grad, i.e. it is meant to be
        # frozen. Keep it in eval mode as well, otherwise its BatchNorm layers
        # would normalise with batch statistics and overwrite the pretrained
        # running statistics on every training step.
        self.resnet.eval()

    def train(self, mode=True):
        """Switch train/eval mode for the projection while keeping the backbone in eval mode."""
        super(EncoderCNN, self).train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Return one `embed_size`-dimensional feature vector per image in the batch."""
        with torch.no_grad():  # no gradients flow into the backbone
            features = self.resnet(images)
        # Flatten everything after the batch dimension so it matches the
        # in_features of the projection layer.
        features = features.view(features.size(0), -1)
        features = self.embed(features)
        return features

# **Decoder** 

In [None]:
class DecoderRNN(nn.Module):
    """LSTM decoder that turns an image feature vector plus caption tokens into vocabulary scores.

    Args:
        embed_size: Dimension of the token embeddings; the image feature
            passed to ``forward`` must have the same size because both are
            concatenated along the time axis.
        hidden_size: Hidden state size of the LSTM.
        vocab_size: Number of tokens in the vocabulary.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        # Layers are created in the same order as before so seeded
        # initialisation and state_dict key order are unaffected.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, features, captions):
        """Score every time step of the sequence [image feature, caption tokens without the last column].

        Returns a tensor with one row of vocabulary scores per input step, so
        its time dimension equals the number of columns in ``captions``.
        """
        # The last caption column is never fed as an input.
        token_inputs = captions[:, :-1]
        # The image feature acts as the very first "token" of the sequence.
        image_step = features.unsqueeze(1)
        sequence = torch.cat([image_step, self.embedding(token_inputs)], dim=1)
        lstm_out, _final_state = self.lstm(sequence)
        return self.linear(lstm_out)

In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch, pad_idx=0):
    """Combine (image, caption) samples into one batch, padding captions to equal length.

    Args:
        batch: Sequence of ``(image_tensor, caption_tensor)`` pairs. All image
            tensors must share one shape; captions may differ in length.
        pad_idx: Value used to fill shorter captions. Defaults to 0, the index
            the ``Vocabulary`` class reserves for ``"<pad>"``. It is a
            parameter rather than a lookup in the global ``vocab`` because
            that global is rebound later in the notebook, which made the
            lookup fail.

    Returns:
        Tuple ``(images, captions)``: the stacked images and a
        ``(batch, longest_caption)`` tensor of padded captions.
    """
    images, captions = zip(*batch)
    images = torch.stack(images)
    captions = pad_sequence(captions, batch_first=True, padding_value=pad_idx)
    return images, captions

In [None]:
from collections import Counter

class Vocabulary:
    """Bidirectional word/index mapping with four reserved special tokens.

    Indices 0-3 are always ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``.
    Other words are appended by ``build_vocabulary`` once they have been seen
    at least ``freq_threshold`` times.
    """

    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}  # index-to-string mapping (used to decode).
        self.stoi = {v: k for k, v in self.itos.items()}  # string-to-index mapping (used to encode).

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent word of ``sentence_list`` to the vocabulary.

        Sentences are lower-cased and split on whitespace. A word is added the
        moment its count within this call reaches ``freq_threshold``; a
        threshold of 1 or less admits every word. Words already present,
        including the reserved tokens, keep their index, so the method can be
        called more than once without corrupting the mapping.
        """
        frequencies = Counter()

        for sentence in sentence_list:
            for word in sentence.lower().split():
                frequencies[word] += 1
                if word in self.stoi or frequencies[word] < self.freq_threshold:
                    continue
                # Indices are contiguous, so the next free one is the current size.
                idx = len(self.itos)
                self.stoi[word] = idx
                self.itos[idx] = word

    def __len__(self):
        """Number of tokens, reserved ones included."""
        return len(self.itos)

    def __getitem__(self, token):
        """Index of ``token``, or the ``<unk>`` index if it is not in the vocabulary."""
        return self.stoi.get(token, self.stoi["<unk>"])


In [None]:
from torchvision import transforms

# Per-channel statistics handed to Normalize.
# NOTE(review): presumably the ImageNet mean/std expected by the pretrained
# backbone - confirm against the torchvision model documentation.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: fixed-size resize, conversion to a
# tensor, then channel-wise normalisation.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)

In [None]:
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class CustomCaptionDataset(Dataset):
    """Dataset of (image, caption) pairs described by a CSV file.

    Args:
        csv_file: Path to a CSV with at least the columns ``filename`` and
            ``caption``.
        image_folder: Directory that contains the files named in ``filename``.
        vocab: Word-to-index lookup. Either a dict-like object offering
            ``get`` or any object whose ``__getitem__`` returns the
            ``<unk>`` index for unknown words (such as ``Vocabulary``).
            It must contain ``<start>``, ``<end>`` and ``<unk>``.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, csv_file, image_folder, vocab, transform=None):
        self.df = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.vocab = vocab
        self.transform = transform

    def __len__(self):
        """Number of rows in the CSV file."""
        return len(self.df)

    def _numericalize(self, caption):
        """Lower-case and whitespace-split ``caption`` and map it to ``<start>`` + word ids + ``<end>``."""
        words = caption.lower().split()
        lookup = getattr(self.vocab, 'get', None)
        if lookup is not None:
            unk_idx = self.vocab['<unk>']
            word_ids = [lookup(word, unk_idx) for word in words]
        else:
            # No .get(): rely on __getitem__ to handle unknown words itself.
            word_ids = [self.vocab[word] for word in words]
        return [self.vocab['<start>']] + word_ids + [self.vocab['<end>']]

    def __getitem__(self, idx):
        """Return ``(image, caption_tensor)`` for row ``idx``."""
        row = self.df.iloc[idx]
        img_name = row['filename']
        caption = row['caption']

        # Load the image; the context manager closes the file handle once
        # convert() has produced an independent in-memory copy.
        with Image.open(f"{self.image_folder}/{img_name}") as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)

        caption_tensor = torch.tensor(self._numericalize(caption))

        return image, caption_tensor


In [None]:
from torch.utils.data import DataLoader

# Locations of the training split.
TRAIN_CSV = "/kaggle/input/caption-data/custom_captions_dataset/train.csv"
TRAIN_IMAGE_DIR = "/kaggle/input/caption-data/custom_captions_dataset/train"

# Build the word/index tables from the training captions; a word must occur
# at least twice to get its own index.
vocab_builder = Vocabulary(freq_threshold=2)
captions_df = pd.read_csv(TRAIN_CSV)
training_captions = captions_df['caption'].tolist()
vocab_builder.build_vocabulary(training_captions)

# The dataset and the loss only need the string -> index dictionary.
vocab = vocab_builder.stoi

# Dataset over the training images and their captions.
train_dataset = CustomCaptionDataset(
    TRAIN_CSV,
    TRAIN_IMAGE_DIR,
    vocab,
    transform=transform,
)

# print(type(train_dataset))
# print(train_dataset[0])

# Shuffled mini-batches of 8; collate_fn pads the captions of each batch.
data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)
# print(data_loader)

# **Training Loop**

In [None]:
import torch.optim as optim
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = EncoderCNN().to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

# Padded positions must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=3e-4)

num_epochs = 1

for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0

    for imgs, caps in data_loader:
        imgs, caps = imgs.to(device), caps.to(device)

        features = encoder(imgs)

        # Words to predict: everything after <start>, shape = [B, T-1].
        target = caps[:, 1:]

        # Teacher forcing. The decoder prepends the image feature and drops
        # the last column of what it is given, so feeding it `target` yields
        # the input sequence [image, w1, ..., w_{n}] and T-1 outputs, where
        # output i has seen exactly the words before target i. Feeding the
        # full caption instead (as before) scored the output produced after
        # <start> against the *second* word, and did not match
        # generate_caption, which feeds the image and then each predicted word.
        outputs = decoder(features, target)

        loss = criterion(outputs.reshape(-1, outputs.shape[2]), target.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean over the epoch rather than the last batch only, and do
    # not fail with a NameError when the loader yielded nothing.
    if num_batches:
        print(f"Epoch {epoch + 1}, Loss: {running_loss / num_batches:.4f}")
    else:
        print(f"Epoch {epoch + 1}, Loss: n/a (data_loader produced no batches)")

# **Saving the Model**

In [None]:
# Bundle both networks' weights with the vocabulary so that captions can be
# decoded after reloading. The full Vocabulary object is stored (not just
# stoi), which means the file contains a pickled instance of that class.
checkpoint_payload = dict(
    encoder_state_dict=encoder.state_dict(),
    decoder_state_dict=decoder.state_dict(),
    vocab=vocab_builder,
)
torch.save(checkpoint_payload, 'caption_model.pth')

# **Caption Generation**

In [None]:
import os

# Load the checkpoint first so the decoder is sized from the vocabulary that
# was saved with the weights instead of whatever is still in memory.
# The file contains a pickled Vocabulary instance, which the weights-only
# loading mode (the default from torch 2.6) refuses to unpickle.
# SECURITY: weights_only=False executes pickle code - only load checkpoints
# that this notebook produced itself.
try:
    checkpoint = torch.load('caption_model.pth', map_location=device, weights_only=False)
except TypeError:
    # Older torch releases have no `weights_only` argument.
    checkpoint = torch.load('caption_model.pth', map_location=device)

vocab_builder = checkpoint['vocab']
# Keep `vocab` as the string -> index table, as in the data-preparation cell:
# collate_fn and the loss look up vocab['<pad>'], which fails if `vocab` is
# rebound to the index -> string table. Use vocab_builder.itos for decoding.
vocab = vocab_builder.stoi

# Recreate the model architectures
encoder = EncoderCNN().to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab_builder)).to(device)

encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])


def generate_caption(encoder, decoder, image_path, idx2word, transform, device, max_len=20):
    """Greedily decode a caption for the image stored at ``image_path``.

    Both models are switched to eval mode as a side effect. The image feature
    is fed to the decoder's LSTM first; afterwards each step feeds the
    embedding of the previously predicted token.

    Args:
        encoder: Module mapping an image batch to feature vectors.
        decoder: Module exposing ``lstm``, ``linear`` and ``embedding``.
        image_path: Path of the image file to caption.
        idx2word: Mapping from token index to word; unknown indices are
            rendered as ``"<unk>"``.
        transform: Callable turning the RGB image into a tensor.
        device: Device on which the image tensor is placed.
        max_len: Maximum number of decoding steps.

    Returns:
        The predicted words joined by spaces, cut off before the first
        ``<end>`` token.
    """
    # Close the file handle as soon as convert() has produced its own copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    encoder.eval()
    decoder.eval()

    caption = []
    with torch.no_grad():
        features = encoder(image)
        inputs = features.unsqueeze(1)
        states = None

        for _ in range(max_len):
            hiddens, states = decoder.lstm(inputs, states)
            outputs = decoder.linear(hiddens.squeeze(1))
            predicted = outputs.argmax(1)
            word = idx2word.get(predicted.item(), "<unk>")
            if word == '<end>':
                # Everything after <end> would be discarded anyway, so stop
                # decoding instead of running the remaining steps.
                break
            caption.append(word)
            inputs = decoder.embedding(predicted).unsqueeze(1)

    return ' '.join(caption)

In [None]:
test_dir = '/kaggle/input/caption-data/custom_captions_dataset/test'

# os.listdir returns entries in arbitrary order; sort them so the printed
# output is the same on every run.
for img_name in sorted(os.listdir(test_dir)):
    img_path = os.path.join(test_dir, img_name)
    if not os.path.isfile(img_path):
        # Sub-directories cannot be opened as images.
        continue
    caption = generate_caption(encoder, decoder, img_path, vocab_builder.itos, transform, device)
    print(f"{img_name}: {caption}")


In [None]:
# /kaggle/input/caption-data/custom_captions_dataset/val.csv
# /kaggle/input/caption-data/custom_captions_dataset/train.csv
# /kaggle/input/caption-data/custom_captions_dataset/test.csv