In [102]:
import os
import nltk
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torchvision.models as models
import pickle
import pandas as pd
from sklearn.model_selection import train_test_split
import torch.optim as optim
import tqdm as tqdm 
from nltk.translate.bleu_score import corpus_bleu
import string

# Download NLTK Tokeniser
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Edward\AppData\Roaming\nltk_data...


In [2]:
nltk.download('punkt')
print(nltk.data.path)

['C:\\Users\\Edward/nltk_data', 'c:\\Users\\Edward\\Desktop\\Projects\\ML\\Automatic Image Captioning\\venv\\nltk_data', 'c:\\Users\\Edward\\Desktop\\Projects\\ML\\Automatic Image Captioning\\venv\\share\\nltk_data', 'c:\\Users\\Edward\\Desktop\\Projects\\ML\\Automatic Image Captioning\\venv\\lib\\nltk_data', 'C:\\Users\\Edward\\AppData\\Roaming\\nltk_data', 'C:\\nltk_data', 'D:\\nltk_data', 'E:\\nltk_data']


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Edward\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Edward\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [4]:
punkt_path = 'C:/Users/Edward/AppData/Roaming/nltk_data/tokenizers/punkt'
print(os.path.exists(punkt_path))
print(os.listdir(punkt_path) if os.path.exists(punkt_path) else "Punkt not found")

True
['.DS_Store', 'czech.pickle', 'danish.pickle', 'dutch.pickle', 'english.pickle', 'estonian.pickle', 'finnish.pickle', 'french.pickle', 'german.pickle', 'greek.pickle', 'italian.pickle', 'malayalam.pickle', 'norwegian.pickle', 'polish.pickle', 'portuguese.pickle', 'PY3', 'README', 'russian.pickle', 'slovene.pickle', 'spanish.pickle', 'swedish.pickle', 'turkish.pickle']


In [17]:
df = pd.read_csv('captions_csv.csv', sep=',')
df.head()

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...


In [64]:
captionsDf = df['caption']
imageNamesDf = df['image']

In [20]:
captionsDf.head()

0    A child in a pink dress is climbing up a set o...
1                A girl going into a wooden building .
2     A little girl climbing into a wooden playhouse .
3    A little girl climbing the stairs to her playh...
4    A little girl in a pink dress going into a woo...
Name: caption, dtype: object

In [65]:
imageNamesDf.head()

0    1000268201_693b08cb0e.jpg
1    1000268201_693b08cb0e.jpg
2    1000268201_693b08cb0e.jpg
3    1000268201_693b08cb0e.jpg
4    1000268201_693b08cb0e.jpg
Name: image, dtype: object

In [66]:
len(imageNamesDf) == len(captionsDf)

True

In [41]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40455 entries, 0 to 40454
Data columns (total 2 columns):
 #   Column   Non-Null Count  Dtype 
---  ------   --------------  ----- 
 0   image    40455 non-null  object
 1   caption  40455 non-null  object
dtypes: object(2)
memory usage: 632.2+ KB


In [42]:
df.shape

(40455, 2)

# Vocabulary Class

In [43]:
df.describe()

Unnamed: 0,image,caption
count,40455,40455
unique,8091,40201
top,997722733_0cb5439472.jpg,Two dogs playing in the snow .
freq,5,7


In [44]:
df.isnull().sum()

image      0
caption    0
dtype: int64

In [None]:
class Vocabulary:
    def __init__(self, freqThreshold=5):
        #print("Vocabulary __init__ called. Ensuring itos has integer keys.") # <--- ADD THIS LINE

        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freqThreshold = freqThreshold

    def __len__(self):
        return len(self.itos)
    
    def buildVocabulary(self, sentenceList):
        frequencies = Counter()
        idx = 4 # Start new word indices from 4
        for sentence in sentenceList:
            for word in nltk.tokenize.word_tokenize(sentence.lower()):
                frequencies[word] += 1
                if frequencies[word] >= self.freqThreshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    
                    idx += 1

    def numericalise(self, text):
        tokenisedText = nltk.tokenize.word_tokenize(text.lower())
        # The .get() method here handles unknown words by defaulting to <UNK>'s index
        return [self.stoi.get(token, self.stoi['<UNK>']) for token in tokenisedText]

# Custom Dataset Wrapper

In [62]:
class Flickr8kCustom(Dataset):
    def __init__(self, rootDir, captions, imageNames, transform=None, freqThreshold=5):
        self.rootDir = rootDir
        self.transform = transform
        self.vocab = Vocabulary(freqThreshold)
        
        # Load captions
        self.imgs = []
        self.captions = []

        for i in range(len(captions)):
            self.imgs.append(imageNames[i])
            self.captions.append(captions[i])
        
        #print(f'first 5 images: {self.imgs[:5]}')
        # Build vocabulary
        self.vocab.buildVocabulary(self.captions)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        imgPath = os.path.join(self.rootDir, self.imgs[idx])
        img = Image.open(imgPath).convert('RGB')
        if self.transform:
            img = self.transform(img)
        caption = self.captions[idx]  # Get one caption (randomly select or iterate)
        numericalisedCaption = [self.vocab.stoi["<SOS>"]]
        numericalisedCaption += self.vocab.numericalise(caption)  # Use first caption
        numericalisedCaption.append(self.vocab.stoi["<EOS>"])
        return img, torch.tensor(numericalisedCaption)

# Image Transformations

In [9]:
nltk.data.path.append('C:/Users/Edward/AppData/Roaming/nltk_data/tokenizers/punkt')  # Update to your path

In [67]:
# Image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load datasets for train, validation, and test
trainDataset = Flickr8kCustom(
    rootDir='Images/',
    captions=captionsDf,
    imageNames=imageNamesDf,
    transform=transform
)

valDataset = Flickr8kCustom(
    rootDir='Images/',
    captions=captionsDf,
    imageNames=imageNamesDf,
    transform=transform
)


Vocabulary __init__ called. Ensuring itos has integer keys.
Vocabulary __init__ called. Ensuring itos has integer keys.


In [31]:
# Save vocabulary
with open('vocab.pkl', 'wb') as f:
    pickle.dump(trainDataset.vocab, f)

In [32]:
print(f"Training dataset size: {len(trainDataset)}")
print(f"Validation dataset size: {len(valDataset)}")

Training dataset size: 40455
Validation dataset size: 40455


In [None]:
img, caption = trainDataset[0]
print(f"Image shape: {img.shape}")
print(f"Caption indices: {caption}")
#print(len(trainDataset.vocab.itos))
print(f"Caption: {' '.join([trainDataset.vocab.itos[idx.item()] for idx in caption])}")

Image shape: torch.Size([3, 224, 224])
Caption indices: tensor([  1,   4,  28,   8,   4, 190, 148,  17,  32,  67,   4, 347,  11, 703,
          8,  24,   3, 492,   5,   2])
3005
Caption: <SOS> a child in a pink dress is climbing up a set of stairs in an <UNK> way . <EOS>


# Create Model

In [70]:
class EncoderCNN(nn.Module):
    def __init__(self, embedSize):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        #resnet = models.resnet18(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embedSize)
        self.bn = nn.BatchNorm1d(embedSize, momentum=0.01)

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features

In [47]:
class DecoderRNN(nn.Module):
    def __init__(self, embedSize, hiddenSize, vocabSize, numLayers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocabSize, embedSize)
        self.lstm = nn.LSTM(embedSize, hiddenSize, numLayers, batch_first=True)
        self.linear = nn.Linear(hiddenSize, vocabSize)

    def forward(self, features, captions):
        embeddings = self.embed(captions[:, :-1])
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs

In [48]:
class CNNtoRNN(nn.Module):
    def __init__(self, embedSize, hiddenSize, vocabSize, numLayers=1):
        super(CNNtoRNN, self).__init__()
        self.encoder = EncoderCNN(embedSize)
        self.decoder = DecoderRNN(embedSize, hiddenSize, vocabSize, numLayers)
    
    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs
    
    def captionImages(self, image, vocabulary, maxLength=50):
        resultCaption = []
        with torch.no_grad():
            x = self.encoder(image).unsqueeze(1)
            states = None
            for _ in range(maxLength):
                hiddens, states = self.decoder.lstm(x, states)
                output = self.decoder.linear(hiddens.squeeze(1))
                predicted = output.argmax(1)
                resultCaption.append(predicted.item())
                x = self.decoder.embed(predicted).unsqueeze(1)
                if predicted.item() == vocabulary.stoi['<EOS>']:
                    break

        return [vocabulary.itos[idx] for idx in resultCaption]

In [50]:
def collateFn(batch):
    images, captions = zip(*batch)
    images = torch.stack(images, 0)
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]
    
    return images, targets, lengths

In [72]:
def trainModel(model, loss_fn, optimizer, vocab, device):
    numEpochs = 10
    batchSize = 32


    trainLoader = DataLoader(
        trainDataset,
        batch_size=batchSize,
        shuffle=True,
        collate_fn=collateFn
    )

    valLoader = DataLoader(
        valDataset,
        batch_size=batchSize,
        shuffle=False,
        collate_fn=collateFn
    )

    #Training Loop

    for epoch in range(numEpochs):
        model.train()
        totalTrainLoss = 0

        for images, captions, lengths in tqdm.tqdm(trainLoader, desc=f'Epoch {epoch+1}'):
            images, captions = images.to(device), captions.to(device)
            outputs = model(images, captions)
            loss = loss_fn(outputs.view(-1, len(vocab)), captions.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            totalTrainLoss += loss.item()

        # Validation 
        model.eval()
        totalValLoss = 0
        with torch.no_grad():
            for images, captions, lengths in valLoader:
                images, captions = images.to(device), captions.to(device)
                outputs = model(images, captions)
                loss = loss_fn(outputs.view(-1, len(vocab)), captions.view(-1))
                totalValLoss+= loss.item()

        
        print(
            f'Epoch [{epoch+1} / {numEpochs}], '
            f'TrainLoss: {totalTrainLoss / len(trainLoader):.4f}, '
            f'Val Loss: {totalValLoss/len(valLoader):.4f}'    
        )

In [61]:
with open('vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

In [69]:
# Initialise Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CNNtoRNN(256, 512, len(vocab), 1).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=vocab.stoi['<PAD>'])
learningRate = 3e-4
optimizer = optim.Adam(model.parameters(), lr=learningRate)



In [73]:
trainModel(model, loss_fn, optimizer, vocab, device)

Epoch 1: 100%|██████████| 1265/1265 [03:15<00:00,  6.48it/s]


Epoch [1 / 10], TrainLoss: 3.4279, Val Loss: 2.8673


Epoch 2: 100%|██████████| 1265/1265 [03:13<00:00,  6.52it/s]


Epoch [2 / 10], TrainLoss: 2.7247, Val Loss: 2.5249


Epoch 3: 100%|██████████| 1265/1265 [03:16<00:00,  6.45it/s]


Epoch [3 / 10], TrainLoss: 2.4786, Val Loss: 2.3275


Epoch 4: 100%|██████████| 1265/1265 [03:17<00:00,  6.39it/s]


Epoch [4 / 10], TrainLoss: 2.3190, Val Loss: 2.1863


Epoch 5: 100%|██████████| 1265/1265 [03:12<00:00,  6.57it/s]


Epoch [5 / 10], TrainLoss: 2.1961, Val Loss: 2.0681


Epoch 6: 100%|██████████| 1265/1265 [03:11<00:00,  6.59it/s]


Epoch [6 / 10], TrainLoss: 2.0928, Val Loss: 1.9670


Epoch 7: 100%|██████████| 1265/1265 [03:12<00:00,  6.57it/s]


Epoch [7 / 10], TrainLoss: 2.0019, Val Loss: 1.8807


Epoch 8: 100%|██████████| 1265/1265 [03:12<00:00,  6.58it/s]


Epoch [8 / 10], TrainLoss: 1.9193, Val Loss: 1.7983


Epoch 9: 100%|██████████| 1265/1265 [03:12<00:00,  6.58it/s]


Epoch [9 / 10], TrainLoss: 1.8420, Val Loss: 1.7271


Epoch 10: 100%|██████████| 1265/1265 [03:12<00:00,  6.59it/s]


Epoch [10 / 10], TrainLoss: 1.7705, Val Loss: 1.6418


In [74]:
torch.save(model.state_dict(), f'auto_caption_model_v1.pth')

In [75]:
testDataset = Flickr8kCustom(
    rootDir='Images/',
    captions=captionsDf,
    imageNames=imageNamesDf,
    transform=transform
)

testLoader = DataLoader(testDataset, batch_size=32, shuffle=False, collate_fn=collateFn)

Vocabulary __init__ called. Ensuring itos has integer keys.


In [None]:
def evaluateModel(model, dataLoader, vocab, device):
    model.eval()
    references = []
    hypotheses = []

    with torch.no_grad():
        for images, captions, length in dataLoader:
            images = images.to(device)
            for i in range(len(images)):
                image = images[i].unsqueeze(0).to(device)
                caption = model.captionImages(image, vocab)
                hypotheses.append(caption)
                ref = [vocab.itos[idx.item()] for idx in captions[i] if idx.item() not in [0, 1, 2, 3]]
                references.append([ref])
    
    blueScore = corpus_bleu(references, hypotheses)
    print(f'BLEU Score: {blueScore:.4f}')
    return blueScore

In [77]:
res = evaluateModel(model, testLoader, vocab, device)

BLEU Score: 0.0786


In [98]:
def generateCaption(imagePath, model, vocab, transform, device, maxLength=50):
    if not isinstance(imagePath, str):
        raise TypeError(f"imagePath must be a string, got {type(imagePath)}: {imagePath}")
    
    model.eval()
    image = Image.open(imagePath).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    caption = model.captionImages(image, vocab, maxLength)
    captionWords = [word for word in caption if word not in ["<SOS>", "<EOS>", "<PAD>", "<UNK>"]]
    caption = ' '.join(captionWords).strip()

    for p in string.punctuation:
        caption.replace(f" {p}", p)

    if caption:
        caption = caption[0].upper() + caption[1:]

    return caption

In [128]:
# Images/667626_18933d713e.jpg # girl in water
# Images/23445819_3a458716c1.jpg two dogs
# Images/132489044_3be606baf7.jpg man sleeping
# Images/61209225_8512e1dad5.jpg re-enactment
# Images/95728660_d47de66544.jpg man cycling in mountains
caption = generateCaption('Images/23445819_3a458716c1.jpg', model, vocab, transform, device)
print(f'Generated Caption: {caption}')

Generated Caption: Two dogs are playing in a field .


In [123]:
def wrapText(caption, font, draw, max_width):
    lines = []
    words = caption.split(' ')
    currentLine = []
    
    for word in words:
        # Test adding word to current line
        textLine = ' '.join(currentLine + [word])
        textBBox = draw.textbbox((0, 0), textLine, font=font)
        textWidth = textBBox[2] - textBBox[0]
        
        if textWidth <= max_width:
            currentLine.append(word)
        else:
            if currentLine:
                lines.append(' '.join(currentLine))
            currentLine = [word]
    
    if currentLine:
        lines.append(' '.join(currentLine))
    
    return lines

In [None]:
def addCaptionToImage(inputImagePath, outputImagePath, model, vocab, transform, device, max_length=50):

    # Ensure inputImagePath is a string
    if not isinstance(inputImagePath, str):
        raise TypeError(f"inputImagePath must be a string, got {type(inputImagePath)}: {inputImagePath}")
    
    # Debug: Print input path
    print(f"Input image path: {inputImagePath}")
    if not os.path.exists(inputImagePath):
        raise FileNotFoundError(f"Image not found: {inputImagePath}")
    
    # Generate caption
    caption = generateCaption(inputImagePath, model, vocab, transform, device, max_length)
    
    # Open the input image
    image = Image.open(inputImagePath).convert('RGB')
    
    # Load font
    try:
        font = ImageFont.truetype("arial.ttf", 20) 
    except IOError:
        font = ImageFont.load_default()
    
    
    draw = ImageDraw.Draw(image)
    
    
    maxTextWidth = image.width - 20  # 10px padding on each side
    captionLines = wrapText(caption, font, draw, maxTextWidth)
    
    
    textBBox = draw.textbbox((0, 0), "Sample", font=font)
    lineHeight = textBBox[3] - textBBox[1]
    captionHeight = len(captionLines) * lineHeight + (len(captionLines) - 1) * 5 + 20  # 5px spacing, 10px padding top/bottom
    
    # Create new image with space for caption
    newImage = Image.new('RGB', (image.width, image.height + captionHeight), color='white')
    newImage.paste(image, (0, captionHeight))
    
    
    draw = ImageDraw.Draw(newImage)
    
    
    yPosition = 10  # Top padding
    for line in captionLines:
        textBBox = draw.textbbox((0, 0), line, font=font)
        textWidth = textBBox[2] - textBBox[0]
        textX = (image.width - textWidth) // 2
        
        outlineColour = 'white'
        fillColour = 'black'
        offset = 1
        for dx, dy in [(-offset, -offset), (-offset, offset), (offset, -offset), (offset, offset)]:
            draw.text((textX + dx, yPosition + dy), line, font=font, fill=outlineColour)
        draw.text((textX, yPosition), line, font=font, fill=fillColour)
        yPosition += lineHeight + 5  # Line spacing
    
   
    newImage.save(outputImagePath)
    print(f'Saved captioned image to: {outputImagePath}')
    
    return newImage

In [131]:
inputImagePath = 'Images/667626_18933d713e.jpg'
outputImagePath = 'Output_Images/667626_18933d713e_captioned_output.jpg'

In [132]:
captionedImage = addCaptionToImage(inputImagePath, outputImagePath, model, vocab, transform, device)

Input image path: Images/667626_18933d713e.jpg
Saved captioned image to: Output_Images/667626_18933d713e_captioned_output.jpg
Caption: A young girl is in the water .
