In [None]:
# Mount Google Drive at /content/drive; the training cell later writes its
# checkpoints under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the Kaggle command-line client used below to download the dataset.
!pip install kaggle

In [None]:
from google.colab import files
import random

In [None]:
# Open Colab's upload dialog; the shell commands below expect the uploaded
# file to be named kaggle.json and to sit in the current working directory.
files.upload()
# Copy the token into ~/.kaggle and restrict it to owner read/write.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# !wget -O dataset.zip "https://www.kaggle.com/datasets/adityajn105/flickr8k/download?datasetVersionNumber=1"

# Download the adityajn105/flickr8k dataset archive with the Kaggle CLI.
!kaggle datasets download -d "adityajn105/flickr8k"

In [None]:
!unzip flickr8k.zip;

In [None]:
import os  # when loading file paths
import pandas as pd  # for lookup in annotation file
import spacy  # for tokenizer
import torch
from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset
from PIL import Image  # Load img
import torchvision.transforms as transforms
import string
import torch
import torch.nn as nn
import statistics
import torchvision.models as models
import torch.optim as optim
import tqdm

In [None]:
# Preprocessing pipeline handed to the dataset: resize to 299x299, convert the
# PIL image to a tensor, then normalise each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# Load the spaCy English pipeline; Vocabulary.tokenizer_eng uses its tokenizer.
spacy_eng = spacy.load("en_core_web_sm")

In [None]:
class Vocabulary:
    """Two-way word/index lookup built from a list of captions.

    Indices 0-3 are reserved for the ``<PAD>``, ``<SOS>``, ``<EOS>`` and
    ``<UNK>`` special tokens; corpus words are numbered from 4 upwards.
    """

    def __init__(self, freq_threshold):
        # A word is admitted at the moment its running count equals this value.
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {token: index for index, token in self.itos.items()}

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` into lower-cased, purely alphabetic tokens.

        Punctuation characters are stripped from every token; tokens that end
        up shorter than two characters or that contain non-letters are dropped.
        Relies on the module-level ``spacy_eng`` pipeline.
        """
        strip_punct = str.maketrans('', '', string.punctuation)
        kept = []
        for tok in spacy_eng.tokenizer(text):
            word = tok.text.lower().translate(strip_punct)
            if len(word) > 1 and word.isalpha():
                kept.append(word)
        return kept

    def build_vocabulary(self, sentence_list):
        """Count words over ``sentence_list`` and register the frequent ones.

        Indices are handed out in the order in which words reach
        ``freq_threshold``, starting at 4. Prints the resulting table sizes.
        """
        counts = {}
        next_index = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                counts[word] = counts.get(word, 0) + 1

                # Equality (not >=) so each word is registered exactly once.
                if counts[word] == self.freq_threshold:
                    self.stoi[word] = next_index
                    self.itos[next_index] = word
                    next_index += 1

        print('lenth of vacbulary', len(self.stoi), len(self.itos))

    def numericalize(self, text):
        """Convert ``text`` to a list of indices, using ``<UNK>`` for unknown words."""
        indices = []
        for token in self.tokenizer_eng(text):
            if token in self.stoi:
                indices.append(self.stoi[token])
            else:
                indices.append(self.stoi["<UNK>"])
        return indices

In [None]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from an image folder and a captions CSV.

    The CSV must provide an ``image`` column (file name relative to
    ``root_dir``) and a ``caption`` column. A ``Vocabulary`` is built from all
    captions when the dataset is constructed.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.transform = transform

        self.df = pd.read_csv(captions_file)
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # The vocabulary is derived from every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of rows in the captions file."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_tensor)`` for row ``index``.

        The caption is numericalized and wrapped in ``<SOS>`` / ``<EOS>``.
        """
        caption = self.captions[index]
        image_path = os.path.join(self.root_dir, self.imgs[index])
        img = Image.open(image_path).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        stoi = self.vocab.stoi
        token_ids = [stoi["<SOS>"]]
        token_ids.extend(self.vocab.numericalize(caption))
        token_ids.append(stoi["<EOS>"])

        return img, torch.tensor(token_ids)

In [None]:
# Build the dataset from the extracted image folder and captions file.
# Construction also builds the vocabulary, which prints its size.
dataset = FlickrDataset('/content/Images', '/content/captions.txt', transform=transform)

lenth of vacbulary 2970 2970


In [None]:
# dataset[0]

In [None]:
class MyCollate:
    """Collate function that batches images and pads captions.

    Images are stacked along a new leading batch dimension. Captions are
    padded with ``pad_idx`` to the longest caption in the batch and returned
    time-major, i.e. shaped ``(max_len, batch)``.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        images = torch.stack([sample[0] for sample in batch], dim=0)
        captions = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=False,
            padding_value=self.pad_idx,
        )
        return images, captions


In [None]:
# Index used to pad captions to a common length inside each batch.
pad_idx = dataset.vocab.stoi["<PAD>"]

# Shuffled training loader; MyCollate stacks the images and pads the captions.
loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
    collate_fn=MyCollate(pad_idx=pad_idx),
)




In [None]:
import pickle


# Persist the index -> token mapping. Despite the .txt extension the file is a
# binary pickle (opened in "wb" mode and written with pickle.dump).
with open("vocab_itos.txt", "wb") as file:
    pickle.dump(dataset.vocab.itos, file)


In [None]:
# Display the full index -> token mapping of the vocabulary.
dataset.vocab.itos

In [None]:
class EncoderCNN(nn.Module):
    """Inception-v3 image encoder whose final ``fc`` layer projects to ``embed_size``."""

    def __init__(self, embed_size, train_CNN=False):
        super().__init__()
        # Stored for callers; nothing inside this class reads it.
        self.train_CNN = train_CNN

        backbone = models.inception_v3(pretrained=True, aux_logits=True)
        # Replace the classification head with a projection to the embedding size.
        backbone.fc = nn.Linear(backbone.fc.in_features, embed_size)
        self.inception = backbone

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, images):
        """Encode ``images`` and return ReLU + dropout applied to element 0 of the backbone output."""
        backbone_out = self.inception(images)

        # NOTE(review): with aux_logits=True torchvision's Inception appears to
        # return a (logits, aux_logits) pair only in training mode and a plain
        # tensor in eval mode. If so, [0] selects the main logits while training
        # but the first sample of the batch in eval mode - confirm against the
        # torchvision docs. caption_image currently depends on this behaviour,
        # so it is preserved here.
        primary = backbone_out[0]

        activated = self.relu(primary)
        return self.dropout(activated)


In [None]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image features are fed to the LSTM as the first time step, followed by
    the (dropout-regularised) embeddings of the caption tokens.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, features, captions):
        """Return vocabulary scores for every time step.

        ``features`` gains a leading time dimension and is concatenated in
        front of the caption embeddings along dim 0, so the output has one
        more time step than ``captions``.
        """
        token_embeddings = self.dropout(self.embed(captions))
        image_step = features.unsqueeze(0)
        lstm_input = torch.cat([image_step, token_embeddings], dim=0)
        lstm_out, _ = self.lstm(lstm_input)
        return self.linear(lstm_out)


In [None]:
class CNNtoRNN(nn.Module):
    """Image-captioning model: a CNN encoder feeding an LSTM decoder."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Encode ``images`` and return the decoder's scores for ``captions``."""
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for a single image.

        Args:
            image: image tensor accepted by the encoder. If the encoder returns
                features for several images, only the first is captioned.
            vocabulary: object exposing an ``itos`` index -> token mapping.
            max_length: maximum number of tokens to generate.

        Returns:
            List of token strings, ending with ``"<EOS>"`` if it was produced
            within ``max_length`` steps.
        """
        result_caption = []

        with torch.no_grad():
            features = self.encoderCNN(image)
            # Bring the features to an explicit (seq_len=1, batch=1, embed)
            # shape. This accepts both a 1-D feature vector and a
            # (batch, embed) matrix from the encoder, so decoding never falls
            # back on the LSTM's unbatched-input path.
            x = features.reshape(-1, features.shape[-1])[:1].unsqueeze(0)
            states = None

            for _ in range(max_length):
                hiddens, states = self.decoderRNN.lstm(x, states)
                output = self.decoderRNN.linear(hiddens.squeeze(0))  # (1, vocab)

                # argmax over the vocabulary axis only; keeps the batch dim so
                # the embedding fed back below stays 3-D like the LSTM state.
                predicted = output.argmax(dim=1)  # (1,)
                token_id = predicted.item()
                result_caption.append(token_id)

                if vocabulary.itos[token_id] == "<EOS>":
                    break

                x = self.decoderRNN.embed(predicted).unsqueeze(0)  # (1, 1, embed)

        return [vocabulary.itos[idx] for idx in result_caption]

In [None]:
# Display the Vocabulary object attached to the dataset.
dataset.vocab

<__main__.Vocabulary at 0x7fa4a2d430d0>

In [None]:
# Number of vocabulary entries, special tokens included.
len(dataset.vocab)

2970

In [None]:
# Hyperparameters. The first four are passed to train(); learning_rate and
# num_epochs are read by train() as notebook globals.
embed_size = 256  # size of the image feature / word embedding vectors
hidden_size = 256  # LSTM hidden state size
vocab_size = len(dataset.vocab)  # output dimension of the decoder
num_layers = 1  # number of stacked LSTM layers
learning_rate = 3e-4  # Adam learning rate
num_epochs = 100  # passes over the training loader

In [None]:
def train(train_loader,embed_size,hidden_size,vocab_size,num_layers):
    """Train a ``CNNtoRNN`` captioning model and checkpoint it to Google Drive.

    Besides its arguments, this function reads the notebook globals
    ``dataset`` (for the ``<PAD>`` index), ``learning_rate`` and ``num_epochs``.

    Args:
        train_loader: iterable of ``(imgs, captions)`` batches, with captions
            laid out time-major as produced by ``MyCollate``.
        embed_size: size of the image feature / word embedding vectors.
        hidden_size: LSTM hidden state size.
        vocab_size: number of entries in the vocabulary.
        num_layers: number of stacked LSTM layers.

    Side effects:
        Writes ``best.pth`` (lowest mean epoch loss so far) and ``last.pth``
        (weights after the final epoch) under
        ``/content/drive/MyDrive/image_caption``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_dir = '/content/drive/MyDrive/image_caption'
    best_path = os.path.join(checkpoint_dir, 'best.pth')
    last_path = os.path.join(checkpoint_dir, 'last.pth')
    # torch.save does not create missing directories; create the folder up
    # front instead of failing after a full epoch of training.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # initialize model, loss etc
    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Freeze the Inception backbone except for its fc layers; the decoder and
    # the encoder's projection stay trainable.
    for name, param in model.encoderCNN.inception.named_parameters():
        param.requires_grad = "fc.weight" in name or "fc.bias" in name

    model.train()

    best_loss = float('inf')

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0

        for idx, (imgs, captions) in tqdm.tqdm(
            enumerate(train_loader), total=len(train_loader), leave=False
        ):
            imgs = imgs.to(device)
            captions = captions.to(device)

            # The decoder prepends the image features as the first time step,
            # so feeding captions[:-1] yields exactly len(captions) outputs.
            outputs = model(imgs, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            optimizer.zero_grad()
            # Plain backward(): passing the loss as the `gradient` argument
            # would scale every gradient by the loss value.
            loss.backward()
            optimizer.step()

            # .item() detaches the value so no autograd graph is kept alive.
            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        mean_loss = running_loss / num_batches
        print(f'Epoch [{epoch + 1}/{num_epochs}] Loss: {mean_loss:.4f}')

        # Track the best model by mean epoch loss rather than by the loss of
        # whichever batch happened to come last.
        if mean_loss < best_loss:
            best_loss = mean_loss
            torch.save(model.state_dict(), best_path)

    torch.save(model.state_dict(), last_path)


In [None]:
# Run training with the loader and hyperparameters defined above.
train(loader,embed_size,hidden_size,vocab_size,num_layers)

# Load and Test

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Re-create the model with the training hyperparameters so the saved
# state dict can be loaded into it.
model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

In [None]:
# The model was already moved to `device` when it was created; as the last
# expression of the cell this call also displays the module structure.
model.to(device)

In [None]:
# map_location remaps the stored tensors onto the device resolved for this
# session, so a checkpoint saved on a GPU runtime also loads on a CPU-only one.
model.load_state_dict(
    torch.load('/content/drive/MyDrive/image_caption/last.pth', map_location=device)
)
# model.eval()

<All keys matched successfully>

In [None]:
def print_examples(model, device, dataset,img_path):
    """Caption the image at ``img_path`` with ``model`` and print the result.

    Args:
        model: captioning model exposing ``caption_image(image, vocabulary)``.
        device: device the image tensor is moved to before captioning.
        dataset: object whose ``vocab`` attribute is passed to ``caption_image``.
        img_path: path of the image file to caption.

    Side effects:
        Switches ``model`` to eval mode (it is not switched back) and prints
        the input path, the preprocessed tensor shape and the caption.
    """
    # Same preprocessing as the training pipeline.
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    model.eval()

    # The context manager closes the underlying file once the pixels are read.
    with Image.open(img_path) as raw_img:
        test_img1 = transform(raw_img.convert("RGB")).unsqueeze(0)

    # Report the actual input instead of a hard-coded reference caption that
    # has no relation to img_path.
    print("Example 1 INPUT:", img_path, tuple(test_img1.shape))
    print(
        "Example 1 OUTPUT: "
        + " ".join(model.caption_image(test_img1.to(device), dataset.vocab))
    )


In [None]:
# Image file to caption in the example cell below.
img_path='/content/1024138940_f1fefbdce1.jpg'

In [None]:
# Scratch check: argmax() on a 1-D tensor yields the index of its largest entry.
a=torch.tensor([-3.9508, 39.9415, -2.7044]).argmax()
a

tensor(1)

In [None]:
# Caption the selected image with the loaded model.
print_examples(model,device,dataset,img_path)