In [185]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
from skimage import io
from torchvision.transforms import v2
import spacy
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
import torchvision.models as models
import torch.optim as optim

In [186]:
spacy_eng = spacy.load("en_core_web_sm")

In [187]:
class Vocabulary:
    def __init__(self, threshold=3):
        self.stoi = {'<PAD>' : 0, '<SOS>' : 1, '<EOS>' : 2, '<UNK>' : 3}
        self.itos = {0 : '<PAD>', 1 : '<SOS>', 2 : '<EOS>', 3 :'<UNK>'}
        self.threshold = threshold


    def __len__(self):
        return len(self.stoi)
    
    @staticmethod
    def tokenizer_eng(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]    # returns lower case tokens separated by space

    def build_vocabulary(self, sentences):
        frequencies = {}
        idx = 4
        for sentence in sentences:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1
                if frequencies[word] == self.threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1
    
    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)
        return [self.stoi[token] if token in self.stoi else self.stoi['<UNK>']
                for token in tokenized_text]

In [188]:
class CustomDataset(Dataset):
    def __init__(self, csv_path, root_dir, transform):
        self.df = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.transform = transform
        self.imgs = self.df['filename']
        self.captions = self.df['caption']

        self.vocab = Vocabulary()
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        image_path = os.path.join(self.root_dir, self.imgs[index])
        caption = self.captions[index]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        numericalized_caption = [self.vocab.stoi['<SOS>']]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi['<EOS>'])

        return image, torch.tensor(numericalized_caption)

In [189]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size=3, train_CNN = False):
        super(EncoderCNN, self).__init__()
        self.train_CNN = train_CNN
        self.inception = models.inception_v3(pretrained = True, aux_logits = True)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)


    def forward(self, x):
        features = self.inception(x)[0]

        for name, param in self.inception.named_parameters():
            if 'fc.weight' in name or 'fc.bias' in name:
                param.requires_grad = True
            else:
                param.requires_grad = self.train_CNN
        return self.dropout(self.relu(features))

In [190]:
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions))
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs        

In [191]:
class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
    
    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        result_caption = []

        with torch.no_grad():
            x = self.encoder(image).unsqueeze(0)
            states = None

            for _ in range(max_length):
                hiddens, states = self.decoder.lstm(x, states)
                output = self.decoder.linear(hiddens.squeeze(0))
                predicted = output.argmax(0)
                result_caption.append(predicted.item())
                x = self.decoder.embed(predicted).unsqueeze(0)

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break

        return [vocabulary.itos[idx] for idx in result_caption]


In [192]:
class Collate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
    
    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        captions = [item[1] for item in batch]
        captions = pad_sequence(captions,batch_first=False, padding_value=self.pad_idx)

        return imgs, captions

In [193]:
def get_loader(root_dir, csv_path, transform, batch_size = 32, shuffle = True):
    dataset = CustomDataset(csv_path=csv_path, root_dir=root_dir, transform=transform)
    pad_idx = dataset.vocab.stoi['<PAD>']
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                            collate_fn=Collate(pad_idx=pad_idx))
    return dataloader, dataset

In [194]:
transform = v2.Compose([
    v2.Resize(size=(356, 356)),  # Or Resize(antialias=True)
    v2.RandomCrop((299,299)),
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
    ])

In [195]:
train_loader, train_set = get_loader(root_dir='A:\\Work\\Deep Learning Assignment\\Term Project\\Automatic-Image-Captioning\\Dataset\\train', 
                                         csv_path='A:\\Work\\Deep Learning Assignment\\Term Project\\Automatic-Image-Captioning\\Dataset\\train.csv', 
                                         transform=transform)

In [196]:
for x, y in train_loader:
    print(x.shape, y.shape)
    break

torch.Size([32, 3, 299, 299]) torch.Size([93, 32])


In [197]:
test_loader, test_set = get_loader(root_dir='A:\\Work\\Deep Learning Assignment\\Term Project\\Automatic-Image-Captioning\\Dataset\\test', 
                                         csv_path='A:\\Work\\Deep Learning Assignment\\Term Project\\Automatic-Image-Captioning\\Dataset\\test.csv', 
                                         transform=transform)

In [198]:
for x, y in test_loader:
    print(x.shape, y.shape)
    break

torch.Size([32, 3, 299, 299]) torch.Size([121, 32])


In [199]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [200]:
def Train():
    #Hyperparameters
    embed_size = 100
    hidden_size = 128
    vocab_size = len(train_set.vocab)
    num_layers = 1
    learning_rate = 3e-4
    num_epochs = 10

    #model load and run
    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=train_set.vocab.stoi['<PAD>'])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(num_epochs):
        print("Epoch : ", epoch)
        Loss = []
        for idx, (images, captions) in enumerate(train_loader):
            images = images.to(device)
            captions = captions.to(device)
            outputs = model(images, captions[:-1])
            loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
            Loss.append(loss.item())

            optimizer.zero_grad()
            loss.backward(loss)
            optimizer.step()
        print('Loss : ', sum(Loss)/ len(Loss))
    model.eval()  # Set the model to evaluation mode
    total_correct = 0
    total_samples = 0

    return model

In [204]:
def print_examples(model, device, dataset):
    transform = v2.Compose([
    v2.Resize(size=(356, 356)),  # Or Resize(antialias=True)
    v2.RandomCrop((299,299)),
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
    ])

    model.eval()
    for i in range(1, 929):
        test_img1 = transform(Image.open("A:\Work\Deep Learning Assignment\Term Project\Automatic-Image-Captioning\Dataset\\test\\test_"+ str(i) +".jpg").convert("RGB")).unsqueeze(0)
        print("Example 1 CORRECT: A large building with bars on the windows in front of it. There is people walking in front of the building. There is a street in front of the building with many cars on it. ")
        print(
        "Example 1 OUTPUT: "
        + " ".join(model.caption_image(test_img1.to(device), dataset.vocab))
    )
    model.train()

In [202]:
model  = Train()



Epoch :  0
Loss :  6.028761714530391
Epoch :  1
Loss :  4.924346913172546
Epoch :  2
Loss :  4.638622262624389
Epoch :  3
Loss :  4.4351592143820655
Epoch :  4
Loss :  4.296809873101432
Epoch :  5
Loss :  4.197027612664846
Epoch :  6
Loss :  4.117155781005348
Epoch :  7
Loss :  4.048384096369397
Epoch :  8
Loss :  3.9889526034200657
Epoch :  9
Loss :  3.9373812129377654


In [205]:
print_examples(model, device, test_set)

Example 1 CORRECT: A large building with bars on the windows in front of it. There is people walking in front of the building. There is a street in front of the building with many cars on it. 
Example 1 OUTPUT: <SOS> in suit person in of train of train of . a suit person in of train of train of . a suit person in of train of train of . a suit person in of train of train of . a suit person in of train of train of
Example 1 CORRECT: A large building with bars on the windows in front of it. There is people walking in front of the building. There is a street in front of the building with many cars on it. 
Example 1 OUTPUT: <SOS> in suit person in of train of train of . a suit person in of train of train of . a suit person in of train of train of . a suit person in of train of train of . a suit person in of train of train of
Example 1 CORRECT: A large building with bars on the windows in front of it. There is people walking in front of the building. There is a street in front of the buildin