In [None]:
cd '/content/drive/MyDrive/ML_datasets/Image_captioning'

/content/drive/MyDrive/ML_datasets/Image_captioning


In [None]:
pwd

'/content/drive/MyDrive/ML_datasets/Image_captioning'

In [None]:
# Download the Flickr8k images and caption text into the current working
# directory (the Drive folder selected by the `cd` cell above), unpack them
# quietly, then delete the archives. Nothing here checks for an existing copy,
# so re-running this cell downloads everything again.
# NOTE(review): IMAGES_PATH expects the images under `Flicker8k_Dataset`
# (that exact spelling) — confirm the archive extracts to that folder name.
!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
!unzip -qq Flickr8k_Dataset.zip
!unzip -qq Flickr8k_text.zip
!rm Flickr8k_Dataset.zip Flickr8k_text.zip

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image


# Folder holding the Flickr8k .jpg files (absolute Colab/Drive path).
IMAGES_PATH = "/content/drive/MyDrive/ML_datasets/Image_captioning/Flicker8k_Dataset"
IMAGE_SIZE = (299, 299)   # (height, width) passed to transforms.Resize
VOCAB_SIZE = 10000        # max vocabulary entries, special tokens included
SEQ_LENGTH = 25           # max caption tokens kept (before <start>/<end> are added)
EMBED_DIM = 512           # image-feature and word-embedding width
LSTM_UNITS = 512          # LSTM hidden size
BATCH_SIZE = 64
EPOCHS = 30
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs the notebook relies on so the train/val split (np.random.shuffle),
# weight initialisation and DataLoader shuffling are reproducible across runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)


def load_captions_data(filename, images_path=None, max_tokens=None, min_tokens=5):
    """Parse a Flickr8k-style caption file into ``{image_path: [captions]}``.

    Each non-blank line is ``<image_name>#<n>`` and a caption, separated by a
    tab. Captions are wrapped as ``"<start> ... <end>"``. If ANY caption of an
    image has fewer than ``min_tokens`` or more than ``max_tokens``
    whitespace-separated tokens, the whole image is dropped. Only image names
    ending in ``jpg`` are kept.

    Args:
        filename: path of the caption file (read as UTF-8).
        images_path: directory joined onto every image name; defaults to
            IMAGES_PATH.
        max_tokens: maximum raw caption length in tokens; defaults to SEQ_LENGTH.
        min_tokens: minimum raw caption length in tokens.

    Returns:
        (caption_mapping, text_data): the mapping, and a flat list of every
        caption that survived filtering (captions of dropped images are
        excluded, so a vocabulary built from text_data matches the mapping).

    Raises:
        ValueError: if a non-blank line contains no tab separator.
    """
    if images_path is None:
        images_path = IMAGES_PATH
    if max_tokens is None:
        max_tokens = SEQ_LENGTH

    caption_mapping = {}
    images_to_skip = set()
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines

            # maxsplit=1: only the first tab separates name from caption.
            img_name, caption = line.split("\t", 1)
            img_name = img_name.split("#")[0]
            img_path = os.path.join(images_path, img_name)
            caption = caption.strip()
            tokens = caption.split()

            if len(tokens) < min_tokens or len(tokens) > max_tokens:
                # One bad caption disqualifies the image; captions already
                # collected for it are removed after the loop.
                images_to_skip.add(img_path)
                continue

            if img_path.endswith("jpg") and img_path not in images_to_skip:
                caption_mapping.setdefault(img_path, []).append(
                    "<start> " + caption + " <end>"
                )

    for img_path in images_to_skip:
        caption_mapping.pop(img_path, None)

    # Derived from the final mapping so dropped images do not leak captions.
    text_data = [cap for caps in caption_mapping.values() for cap in caps]
    return caption_mapping, text_data


In [2]:
def train_val_split(caption_data, train_size=0.8, seed=None):
    """Randomly split an ``{image: captions}`` mapping by image.

    Args:
        caption_data: mapping from image key to its captions (not modified).
        train_size: fraction of images placed in the training split, in [0, 1].
        seed: optional integer; when given, a dedicated NumPy Generator is used
            so the split is reproducible regardless of the global RNG state.
            When None, the global ``np.random`` state is used (as before).

    Returns:
        (train_data, val_data): two dicts with disjoint keys whose union is
        ``caption_data``'s keys; values are the original caption lists.

    Raises:
        ValueError: if ``train_size`` is outside [0, 1].
    """
    if not 0.0 <= train_size <= 1.0:
        raise ValueError(f"train_size must be in [0, 1], got {train_size}")

    all_images = list(caption_data.keys())
    if seed is None:
        np.random.shuffle(all_images)
    else:
        np.random.default_rng(seed).shuffle(all_images)

    train_len = int(len(all_images) * train_size)
    train_images, val_images = all_images[:train_len], all_images[train_len:]
    train_data = {img: caption_data[img] for img in train_images}
    val_data = {img: caption_data[img] for img in val_images}
    return train_data, val_data

In [5]:
from collections import Counter
class Vocabulary:
    """Word <-> index mapping with reserved special tokens.

    Indices 0-3 are reserved for ``<pad>``, ``<start>``, ``<end>`` and
    ``<unk>``. ``build_vocab`` hands out the following indices to corpus words
    in order of decreasing frequency, so that at most ``max_size`` indices
    exist in total.
    """

    def __init__(self, max_size=VOCAB_SIZE, min_freq=1):
        """
        Args:
            max_size: total number of indices, special tokens included.
            min_freq: minimum corpus count for a word to receive an index.
        """
        self.word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
        self.idx2word = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.max_size = max_size
        self.min_freq = min_freq

    def build_vocab(self, sentences):
        """Add the most frequent words of ``sentences`` to the vocabulary.

        Sentences are lower-cased and whitespace-split, matching ``encode``.
        Words already present keep their existing index. This matters for
        tokens like ``<start>``/``<end>``, which wrapped captions contain and
        which must stay at their reserved indices. Frequency ties keep
        first-occurrence order (``Counter.most_common`` is stable).
        """
        counter = Counter()
        for sent in sentences:
            counter.update(sent.lower().split())
        for word, freq in counter.most_common():
            # most_common() is sorted by count, so once either limit is hit
            # no later word can qualify.
            if len(self.word2idx) >= self.max_size or freq < self.min_freq:
                break
            if word in self.word2idx:
                continue
            idx = len(self.word2idx)  # indices stay contiguous: 0..len-1
            self.word2idx[word] = idx
            self.idx2word[idx] = word

    def encode(self, sentence):
        """Lower-case, whitespace-split and map each word to its index (``<unk>`` if unknown)."""
        return [self.word2idx.get(w, self.word2idx["<unk>"]) for w in sentence.lower().split()]

    def decode(self, indices):
        """Map indices back to words, joined by single spaces (``<unk>`` for unknown indices)."""
        return " ".join([self.idx2word.get(idx, "<unk>") for idx in indices])

In [None]:
class FlickrDataset(Dataset):
    """Dataset yielding one sample per (image, caption) pair.

    Args:
        image_paths: sequence of image file paths.
        captions: parallel sequence; ``captions[i]`` is the list of caption
            strings belonging to ``image_paths[i]``.
        vocab: object with an ``encode(str) -> list[int]`` method.
        transform: optional callable applied to the loaded RGB image.
    """

    def __init__(self, image_paths, captions, vocab, transform=None):
        self.image_paths = []
        self.captions = []
        # Flatten so every caption is its own sample; an image with k
        # captions therefore appears k times.
        for img, caps in zip(image_paths, captions):
            for cap in caps:
                self.image_paths.append(img)
                self.captions.append(cap)
        self.vocab = vocab
        self.transform = transform

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(image, caption_ids)``; caption_ids is a 1-D int64 tensor."""
        img_path = self.image_paths[idx]
        # Context manager closes the underlying file handle promptly.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        caption_ids = self.vocab.encode(self.captions[idx])
        # Explicit dtype: torch.tensor([]) would otherwise default to float32,
        # which nn.Embedding and pad_sequence with int ids cannot mix with.
        return image, torch.tensor(caption_ids, dtype=torch.long)

def collate_fn(batch):
    """Merge ``(image, caption_ids)`` samples into one padded batch.

    Returns:
        image_batch: images stacked along a new leading batch dimension.
        padded_captions: (batch, longest_caption) ids, right-padded with 0
            (the ``<pad>`` index).
        caption_lengths: list with each caption's unpadded length.
    """
    images, captions = zip(*batch)
    image_batch = torch.stack(images)
    caption_lengths = list(map(len, captions))
    padded_captions = nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=0)
    return image_batch, padded_captions, caption_lengths

In [None]:
# Preprocessing shared by training, validation and generate_caption.
# The EfficientNet-B0 backbone uses ImageNet-pretrained weights, which expect
# inputs normalised with the ImageNet channel mean/std; feeding raw [0, 1]
# tensors shifts every activation away from what the frozen backbone learned.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
class CNN_Encoder(nn.Module):
    """Frozen pretrained EfficientNet-B0 backbone plus a trainable projection.

    ``forward`` maps a batch of images to ``(batch, 1, embed_dim)`` features,
    i.e. a length-1 sequence the LSTM decoder consumes as its first step.
    """

    def __init__(self, embed_dim=EMBED_DIM):
        super().__init__()
        backbone = models.efficientnet_b0(pretrained=True)
        # Keep every child except the final classifier head.
        # NOTE(review): for torchvision's EfficientNet this presumably already
        # includes the model's own average pool, making self.pool a no-op on
        # a 1x1 map — verify before relying on the spatial size here.
        self.cnn = nn.Sequential(*list(backbone.children())[:-1])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        feature_dim = backbone.classifier[1].in_features
        self.fc = nn.Linear(feature_dim, embed_dim)
        self.relu = nn.ReLU()
        # Freeze the backbone: only self.fc keeps trainable parameters.
        self.cnn.requires_grad_(False)

    def forward(self, images):
        pooled = self.pool(self.cnn(images))
        flat = pooled.view(pooled.size(0), -1)
        projected = self.relu(self.fc(flat))
        return projected.unsqueeze(1)

In [None]:
class LSTM_Decoder(nn.Module):
    """Single-layer LSTM caption decoder conditioned on an image feature.

    The image feature is fed as the first time step, followed by the embedded
    caption tokens (teacher forcing), so output step ``t`` is scored against
    ``captions[:, t]`` by the training loss.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        super().__init__()
        # Creation order (embed, lstm, fc) fixes parameter order and the
        # seeded initialisation sequence; keep it.
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions):
        """
        Args:
            features: (batch, 1, embed_dim) image features, as CNN_Encoder returns.
            captions: (batch, seq_len) token ids.

        Returns:
            (batch, seq_len, vocab_size) logits. The last caption token is
            embedded but never fed to the LSTM.
        """
        token_vectors = self.embed(captions)
        lstm_input = torch.cat([features, token_vectors[:, :-1]], dim=1)
        hidden_seq, _final_state = self.lstm(lstm_input)
        return self.fc(hidden_seq)

In [None]:
captions_mapping, text_data = load_captions_data("Flickr8k.token.txt")
train_data, val_data = train_val_split(captions_mapping)

# Build the vocabulary from TRAINING captions only, so validation text does
# not influence which words the model can emit (avoids val -> vocab leakage).
# text_data covers every split and is therefore not used here.
vocab = Vocabulary(max_size=VOCAB_SIZE)
vocab.build_vocab(caption for caps in train_data.values() for caption in caps)

train_dataset = FlickrDataset(list(train_data.keys()), list(train_data.values()), vocab, transform)
val_dataset = FlickrDataset(list(val_data.keys()), list(val_data.values()), vocab, transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
import nltk
from nltk.translate.bleu_score import corpus_bleu

# NOTE(review): 'punkt' is an NLTK tokenizer model, but nothing visible here
# calls an NLTK tokenizer (corpus_bleu below receives pre-split word lists),
# so this download looks unnecessary — confirm before removing.
nltk.download('punkt')

def evaluate_bleu(encoder, decoder, dataloader, vocab, max_len=SEQ_LENGTH):
    """Greedy-decode every sample in ``dataloader`` and return corpus BLEU.

    Decoding reproduces the input order ``LSTM_Decoder.forward`` sees during
    training: the image feature first, then ``<start>``, then each previously
    predicted token. The LSTM state is carried forward, so every step is
    conditioned on the most recent prediction. (Calling ``decoder.forward``
    on the partial sequence would silently drop that last token.)

    Args:
        encoder: image encoder; ``encoder(images)`` returns ``(batch, 1, E)``
            features that the decoder's LSTM accepts as a length-1 sequence.
        decoder: an ``LSTM_Decoder`` (its ``embed``, ``lstm`` and ``fc`` are used).
        dataloader: yields ``(images, captions_padded, lengths)`` batches.
        vocab: provides ``word2idx`` / ``idx2word``.
        max_len: maximum number of tokens generated per caption.

    Returns:
        float: NLTK ``corpus_bleu`` over all samples, or 0.0 if the loader is empty.

    Note: each loader item carries a single caption, so every hypothesis is
    scored against that one reference, not all captions of its image.
    """
    encoder.eval()
    decoder.eval()

    start_idx = vocab.word2idx["<start>"]
    end_idx = vocab.word2idx["<end>"]
    # Filter references by word rather than by hard-coded index so this stays
    # correct however the special tokens were indexed.
    special_words = {"<pad>", "<start>", "<end>", "<unk>"}

    references = []
    hypotheses = []

    with torch.no_grad():
        for images, captions, lengths in dataloader:
            images = images.to(DEVICE)
            features = encoder(images)
            batch_size = features.size(0)
            device = features.device

            # Training step 0 consumes the image feature; its output (trained
            # to predict <start>) is discarded, only the LSTM state is kept.
            _, states = decoder.lstm(features)
            tokens = torch.full((batch_size, 1), start_idx, dtype=torch.long, device=device)
            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
            steps = []
            for _ in range(max_len):
                out, states = decoder.lstm(decoder.embed(tokens), states)
                tokens = decoder.fc(out).argmax(dim=2)  # (batch, 1)
                steps.append(tokens)
                finished |= tokens.squeeze(1) == end_idx
                if finished.all():
                    break
            if steps:
                generated = torch.cat(steps, dim=1).tolist()
            else:
                generated = [[] for _ in range(batch_size)]

            for cap_ids, seq in zip(captions.tolist(), generated):
                ref_words = [vocab.idx2word.get(idx, "<unk>") for idx in cap_ids]
                references.append([[w for w in ref_words if w not in special_words]])

                # Keep everything up to (not including) the first <end>; if no
                # <end> was produced, keep all max_len tokens.
                hyp = []
                for idx in seq:
                    if idx == end_idx:
                        break
                    hyp.append(vocab.idx2word.get(idx, "<unk>"))
                hypotheses.append(hyp)

    if not hypotheses:
        return 0.0
    return corpus_bleu(references, hypotheses)

In [None]:
encoder = CNN_Encoder().to(DEVICE)
decoder = LSTM_Decoder(embed_dim=EMBED_DIM, hidden_dim=LSTM_UNITS, vocab_size=VOCAB_SIZE).to(DEVICE)

# Index 0 is <pad>; padded positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# The encoder's backbone is frozen, but its projection head (encoder.fc) is
# freshly initialised and trainable. It must be optimised too; otherwise the
# decoder only ever sees a fixed random projection of the CNN features.
params = list(decoder.parameters()) + list(encoder.fc.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

for epoch in range(EPOCHS):
    # eval() keeps the frozen backbone in inference mode; it does not stop
    # gradients from reaching encoder.fc.
    encoder.eval()
    decoder.train()

    total_loss = 0.0
    for images, captions, lengths in train_loader:
        images, captions = images.to(DEVICE), captions.to(DEVICE)
        features = encoder(images)
        outputs = decoder(features, captions)
        # outputs[:, t] is scored against captions[:, t].
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / max(len(train_loader), 1)

    bleu = evaluate_bleu(encoder, decoder, val_loader, vocab)

    print(f"Epoch {epoch+1}/{EPOCHS},Loss: {avg_loss:.4f},Validation BLEU: {bleu:.4f}")


In [None]:
def generate_caption(image_path, encoder, decoder, vocab, max_len=SEQ_LENGTH):
    """Greedily caption a single image file.

    The decoder is stepped with its LSTM state carried forward: image feature
    first, then ``<start>``, then each newly predicted token. This is the same
    input order ``LSTM_Decoder.forward`` uses during training. (Calling
    ``decoder.forward`` on the partial caption would drop the newest token.)

    Args:
        image_path: path of the image to caption.
        encoder: image encoder returning ``(1, 1, E)`` features for one image.
        decoder: an ``LSTM_Decoder`` (its ``embed``, ``lstm`` and ``fc`` are used).
        vocab: provides ``word2idx`` and ``decode``.
        max_len: maximum number of generated tokens.

    Returns:
        str: the generated words joined by spaces, without ``<start>``/``<end>``.
    """
    encoder.eval()
    decoder.eval()
    with Image.open(image_path) as img:
        image = transform(img.convert("RGB")).unsqueeze(0).to(DEVICE)

    start_idx = vocab.word2idx["<start>"]
    end_idx = vocab.word2idx["<end>"]
    word_ids = []
    with torch.no_grad():
        features = encoder(image)
        # Prime the LSTM with the image; this step's output (trained to
        # predict <start>) is discarded.
        _, states = decoder.lstm(features)
        token = torch.tensor([[start_idx]], dtype=torch.long, device=features.device)
        for _ in range(max_len):
            out, states = decoder.lstm(decoder.embed(token), states)
            token = decoder.fc(out).argmax(dim=2)  # (1, 1)
            next_id = token.item()
            if next_id == end_idx:
                break
            word_ids.append(next_id)
    # If max_len is reached without <end>, every generated word is kept.
    return vocab.decode(word_ids)