In [51]:
import torch
import torchvision
from torch import nn
from torchvision import transforms
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
import math


import pandas as pd
import matplotlib.pyplot as plt
import random
from config import getConfig

from PIL import Image
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter 


from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace




# Pick the accelerator that is actually available. ``is_available`` must be
# *called*: the bare function object is always truthy, which previously forced
# "mps" even on machines without Metal. The device is kept as a plain string
# and passed to ``.to(device)``; ``torch.device`` itself is never reassigned.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
config = getConfig()
config

{'tokenizer_file': 'tokeniser_1.json', 'batchSize': 8}

In [52]:
# Read the caption spreadsheet (captions.xls) and inspect its raw columns.
results = pd.read_excel(io="./flickr8k_images/captions.xls")
results.columns



Index(['captions', 'Unnamed: 1'], dtype='object')

In [None]:
# Clean the raw sheet: keep rows that have a caption (NaN cells arrive as
# floats), name the columns, strip caption whitespace, drop the first row.
mask = results[results.columns[1]].apply(lambda x: not isinstance(x, float))
results.columns = ["image", "caption"]
# .copy() makes the filtered frame independent, so the column assignment
# below updates it directly instead of raising SettingWithCopyWarning.
results = results[mask].copy()
results["caption"] = results["caption"].apply(lambda x: x.strip())

# The first remaining row is presumably the sheet's own header line (TODO
# confirm). reset_index gives contiguous 0..n-1 labels so label-based lookups
# such as results[col][i] cannot hit rows removed above.
results = results.iloc[1:].reset_index(drop=True)
results.head()

In [None]:
# Testing if Data is Correctly Downloaded
def testImage():
    """Plot a random captioned image as a sanity check on the downloaded data.

    Picks a random row of the module-level ``results`` DataFrame, loads the
    referenced file from ``./flickr8k_images/flickr8k_images/``, applies the
    resize + normalise pipeline, and shows it with its caption as the title.

    Returns:
        The transformed (normalised) image tensor, shaped (C, 224, 224).
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Positional sampling: ``results`` was filtered and sliced, so its index
    # labels may have gaps; draw within its real length instead of 0..10000.
    row = results.iloc[random.randrange(len(results))]
    image_path = "./flickr8k_images/flickr8k_images/" + row.iloc[0]

    imageTransforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=mean, std=std)
    ])

    image = imageTransforms(plt.imread(image_path))

    # Undo normalisation and reorder (C, H, W) -> (H, W, C) for display;
    # transpose(0, 2) would give (W, H, C), showing the image mirrored.
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    display_image = (image * std_t + mean_t).clamp(0, 1).permute(1, 2, 0)

    plt.title(row.iloc[1])
    plt.imshow(display_image.numpy())
    plt.show()

    return image

image = testImage()
image.shape

In [None]:
# Create text tokeniser
# Build Vocabulary


def get_tokeniser(ds):
    """Load the word-level caption tokenizer, training and caching it if absent.

    If the file named by ``config["tokenizer_file"]`` exists, the tokenizer is
    loaded from it and ``ds`` is not read. Otherwise a WordLevel tokenizer with
    whitespace pre-tokenisation is trained on ``ds["caption"]`` (entries whose
    type is exactly ``float``, i.e. NaN cells, are skipped), keeping words seen
    at least twice plus the special tokens [UNK], [PAD], [SOS], [EOS], and is
    saved to that path.

    Args:
        ds: DataFrame with a ``caption`` column.

    Returns:
        A ``tokenizers.Tokenizer``.
    """
    tokenizer_path = Path(config["tokenizer_file"])

    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    def caption_stream(frame):
        # Lazily yield captions, skipping the float placeholders for missing rows.
        for text in frame["caption"]:
            if type(text) is not float:
                yield text

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(
        special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
        min_frequency=2,
    )
    tokenizer.train_from_iterator(caption_stream(ds), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer
    
tokenizer = get_tokeniser(results)

# Sanity check: token ids for a sample sentence, then the vocabulary size.
print(tokenizer.encode("My name is Ashrya Shravan").ids, tokenizer.get_vocab_size(), sep="\n")

In [None]:
# Peek at the cleaned caption table.
results.head(n=5)

In [57]:
# Create Datasets and Data loaders for images


class ImageCaptionDataset(Dataset):
    """Pairs each caption row with its transformed image and padded token ids.

    Args:
        ds: DataFrame with an ``image`` column (file name inside ``image_dir``)
            and a ``caption`` column (text).
        transforms: callable applied to the loaded RGB PIL image.
        tokenizer_tgt: caption tokenizer defining [SOS], [EOS] and [PAD].
        seq_len: fixed length of both the decoder input and the label.
            Captions longer than ``seq_len - 1`` tokens are truncated.
        image_dir: directory containing the image files.

    Each item is ``(encoder_input, decoder_input, label)`` where
    ``decoder_input = [SOS] + tokens + [PAD]*k`` and
    ``label = tokens + [EOS] + [PAD]*k``.
    """

    def __init__(self, ds: pd.DataFrame, transforms: "torchvision.transforms.Compose", tokenizer_tgt: "Tokenizer", seq_len: int = 100, image_dir: str = "./flickr8k_images/flickr8k_images/") -> None:
        super().__init__()

        self.seq_len = seq_len
        self.ds = ds
        self.transforms = transforms
        self.tokenizer_tgt = tokenizer_tgt
        self.image_dir = image_dir
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def _build_sequences(self, token_ids):
        """Return ``(decoder_input, label)``, each exactly ``seq_len`` long."""
        # One slot is reserved for [SOS] / [EOS]; truncate rather than letting
        # the padding count go negative and tripping the length asserts.
        tokens = torch.tensor(token_ids[: self.seq_len - 1], dtype=torch.int64)
        padding = self.pad_token.repeat(self.seq_len - 1 - tokens.size(0))

        decoder_input = torch.cat([self.sos_token, tokens, padding], dim=0)
        label = torch.cat([tokens, self.eos_token, padding], dim=0)

        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len
        return decoder_input, label

    def __getitem__(self, index):
        row = self.ds.iloc[index]

        # Load with PIL and force 3 channels so greyscale / RGBA files still
        # match the 3-channel Normalize in the transform pipeline.
        with Image.open(Path(self.image_dir) / row["image"]) as img:
            encoder_input = self.transforms(img.convert("RGB"))

        token_ids = self.tokenizer_tgt.encode(row["caption"]).ids
        decoder_input, label = self._build_sequences(token_ids)

        return encoder_input, decoder_input, label
    


        
# Image pipeline: to tensor, resize to 224x224, then normalise each channel
# with the usual ImageNet mean / std.
imageTransforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)





In [58]:


# Model Declaration and Image Passing to get Features from Image

class ImageCaptionCNNModel(nn.Module):
    """Image encoder: pretrained VGG16 convolutional layers + 7x7 average pool.

    This requires input of format [N, C, H, W] with all transforms applied
    (the conv layers also accept a single unbatched [C, H, W] image). The
    output is the VGG16 feature map pooled to a 7x7 spatial grid.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # The whole pretrained network stays registered as ``vgg16`` (its
        # classifier head is unused but remains in the state_dict); only the
        # convolutional ``features`` stack is wrapped and run in ``forward``.
        self.vgg16 = torchvision.models.vgg16(weights="DEFAULT")
        self.VGGLAYER = nn.Sequential(*self.vgg16.features.children())
        self.Pool = nn.AdaptiveAvgPool2d(output_size=(7, 7))

    def forward(self, image):
        """Return the pooled conv features of ``image``."""
        return self.Pool(self.VGGLAYER(image))
    


imageTransforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

Image_model = ImageCaptionCNNModel()

# Feature extraction only: eval() plus inference_mode() stop autograd from
# recording a graph through the (trainable, in this instance) VGG weights.
Image_model.eval()
with torch.inference_mode():
    res = Image_model(image)

res.size()



torch.Size([512, 7, 7])

In [60]:
# Define the Transformer

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    The table is a buffer of shape (max_len, 1, d_model): saved in the
    state_dict, never trained.

    Args:
        d_model: embedding width (must be even for the sin/cos interleave).
        max_len: longest sequence that can be encoded.
        batch_first: if True, ``forward`` expects (batch, seq, d_model);
            otherwise (seq, batch, d_model), the original layout.
    """

    def __init__(self, d_model, max_len=5000, batch_first=False):
        super(PositionalEncoding, self).__init__()
        self.batch_first = batch_first
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Shape kept at (max_len, 1, d_model) so existing checkpoints still load.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of each sequence position to ``x``."""
        if self.batch_first:
            # (seq, 1, d) -> (1, seq, d) so it broadcasts over the batch axis.
            return x + self.pe[: x.size(1)].transpose(0, 1)
        return x + self.pe[: x.size(0), :]


class ImageCaptioningModel(nn.Module):
    """Frozen VGG16 image encoder feeding an ``nn.Transformer`` caption decoder.

    The CNN's pooled feature grid becomes the transformer source sequence (one
    token per spatial cell); the caption token ids are the target sequence.

    Args:
        d_model: transformer width. The VGG16 trunk emits 512 channels, so
            this must stay 512 unless a projection is added before the
            transformer.
        vocab_size: number of output classes; should equal the tokenizer's
            vocabulary size.
    """

    def __init__(self, d_model: int = 512, vocab_size: int = 5439, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.vocab_size = vocab_size
        self.d_model = d_model

        self.cnn_layer = ImageCaptionCNNModel()
        for param in self.cnn_layer.parameters():
            param.requires_grad = False

        self.embedding = nn.Embedding(self.vocab_size, self.d_model)
        self.positional_encoding = PositionalEncoding(self.d_model, batch_first=True)

        self.transformer_layer = nn.Transformer(d_model=self.d_model, batch_first=True)

        self.projection_layer = nn.Sequential(
            nn.Linear(self.d_model, self.vocab_size)
        )

    @staticmethod
    def _causal_mask(size, device):
        """Additive mask with -inf above the diagonal: position i sees only <= i."""
        return torch.triu(torch.full((size, size), float("-inf"), device=device), diagonal=1)

    def forward(self, image, token):
        """Return (N, T, vocab_size) logits for decoder input ids ``token`` (N, T).

        Thanks to the causal mask, the logits at position t depend only on
        ``token[:, :t + 1]`` and the image.
        """
        features = self.cnn_layer(image)  # (N, C, H, W), (N, 512, 7, 7) for VGG16
        # One source token per spatial cell: (N, C, H, W) -> (N, H*W, C).
        # A raw .view(N, -1, 512) would instead cut the channel-major buffer
        # into chunks that mix channels and locations.
        memory = features.flatten(2).transpose(1, 2)

        embeddings = self.embedding(token) * math.sqrt(self.d_model)
        embeddings = self.positional_encoding(embeddings)

        # Training feeds the whole shifted caption at once; without this mask
        # every position could attend to the very token it has to predict.
        tgt_mask = self._causal_mask(token.size(1), token.device)

        decoded = self.transformer_layer(memory, embeddings, tgt_mask=tgt_mask)

        return self.projection_layer(decoded)

# Smoke-test construction with the default sizes; the training cell builds
# its own model and rebinds ``model``.
model = ImageCaptioningModel(d_model=512, vocab_size=5439)


In [None]:
# Train The model

BATCH_SIZE = config["batchSize"]
EPOCHS = 10
global_step = 0
writer = SummaryWriter("ImageCaptioning_2")

tokenizer = get_tokeniser(results)
# Tie the output layer to the tokenizer so logits and labels share one vocabulary.
VOCAB_SIZE = tokenizer.get_vocab_size()

imageDataset = ImageCaptionDataset(results, imageTransforms, tokenizer)

# Shuffle every epoch; otherwise rows are visited in spreadsheet order.
imagedataLoader = DataLoader(imageDataset, batch_size=BATCH_SIZE, shuffle=True)

model = ImageCaptioningModel(vocab_size=VOCAB_SIZE).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
# The CNN trunk is frozen inside the model, so only trainable params go to Adam.
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.0001, eps=1e-9)

resDir = Path('results')
resDir.mkdir(parents=True, exist_ok=True)  # torch.save does not create directories

for epoch in range(EPOCHS):

    model.train()

    for idx, batch in enumerate(imagedataLoader):
        images, decoder_inputs, labels = (t.to(device) for t in batch)

        optimizer.zero_grad(set_to_none=True)

        proj_out = model(images, decoder_inputs)
        loss = loss_fn(proj_out.reshape(-1, VOCAB_SIZE), labels.reshape(-1))

        loss.backward()
        optimizer.step()

        writer.add_scalar('train loss', loss.item(), global_step)
        global_step += 1

        if idx % 1000 == 0:
            print(f"Completed {idx} batches at loss {loss.item():.4f}")

    # Flush once per epoch rather than after every step.
    writer.flush()

    finalPath = resDir / f"naive{epoch}.pth"
    torch.save(model.state_dict(), finalPath)
    print(f"Completed {epoch} epoch with loss {loss.item():.4f} and saved model")

writer.close()




In [29]:

# Re-save the current model weights to the most recent checkpoint path.
torch.save(obj=model.state_dict(), f=finalPath)

In [None]:

resDir = Path('results')
finalPath = resDir / 'naive9.pth'
# Size the output layer from the tokenizer, matching how the model was trained.
captioner = ImageCaptioningModel(vocab_size=tokenizer.get_vocab_size()).to(device)

# map_location lets a checkpoint written on one device (e.g. MPS or CUDA)
# load on another. NOTE(review): torch.load unpickles the file; only load
# trusted checkpoints (on torch >= 1.13 also pass weights_only=True).
captioner.load_state_dict(torch.load(finalPath, map_location=device))

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image



def preprocess_image(image_path, imageTransforms):
    """Load ``image_path`` as RGB, transform it, and add a leading batch axis.

    Returns:
        ``imageTransforms(rgb_image)`` with an extra first dimension of size 1.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    transformed = imageTransforms(rgb_image)
    return transformed.unsqueeze(dim=0)

def generate_caption(model, tokenizer, image, max_length=100, device='cpu'):
    """Greedily decode a caption for one preprocessed image.

    Args:
        model: callable ``model(image, tokens)`` returning (1, T, vocab)
            logits, with an ``eval()`` method (e.g. ImageCaptioningModel).
        tokenizer: provides ``token_to_id`` / ``id_to_token`` and the
            [SOS] / [EOS] special tokens.
        image: batched image tensor, e.g. from ``preprocess_image``.
        max_length: maximum number of tokens to generate.
        device: device the image and token tensors are moved to.

    Returns:
        The generated words joined by single spaces, without [SOS] / [EOS].
    """
    sos_id = tokenizer.token_to_id('[SOS]')
    eos_id = tokenizer.token_to_id('[EOS]')

    model.eval()
    caption_tokens = [sos_id]
    with torch.no_grad():
        image = image.to(device)

        for _ in range(max_length):
            input_tokens = torch.tensor(caption_tokens, dtype=torch.int64, device=device).unsqueeze(0)
            next_token_logits = model(image, input_tokens)[0, -1, :]
            next_token_id = next_token_logits.argmax(dim=-1).item()

            if next_token_id == eos_id:
                break
            caption_tokens.append(next_token_id)

    # Drop [SOS]. [EOS] is never appended, so a caption cut off by max_length
    # keeps its final word instead of losing it to a blind [1:-1] slice.
    return ' '.join(tokenizer.id_to_token(token_id) for token_id in caption_tokens[1:])


# Caption one image from disk with the loaded ``captioner``.
image_path = './flickr30k_images/flickr30k_images/36979.jpg'
image = preprocess_image(image_path=image_path, imageTransforms=imageTransforms)

caption = generate_caption(model=captioner, tokenizer=tokenizer, image=image, device=device)
print("Generated Caption:", caption)


