**Ideas -**
1. Find/Use OCR for non-comp part<br>
2. Rotate all images thrice to create a larger training dataset.<br>
3. Convert image to 0-1 grayscale.<br>
4. Use an attention-based RNN for caption generation.<br>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from skimage import io, transform

import matplotlib.pyplot as plt # for plotting
import numpy as np
import os
torch.cuda.empty_cache()

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
%cd drive/MyDrive/

[Errno 2] No such file or directory: 'drive/MyDrive/'
/


In [None]:
!unzip train_data.zip

In [2]:
%cd train_data
!ls | wc -l
%cd ..

[Errno 2] No such file or directory: 'train_data'
/content
2
/


### Image Transforms

In [2]:
class Rescale(object):
    """Resize an image array to a requested size.

    Args:
        output_size (tuple or int): Requested size. A tuple is used directly
            as (height, width). An int is matched to the shorter image edge
            and the longer edge is scaled by the same factor, so the aspect
            ratio is kept.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, image):
        """Return `image` resized with `skimage.transform.resize`."""
        height, width = image.shape[:2]
        target = self.output_size

        if not isinstance(target, int):
            # Explicit (height, width) pair.
            out_h, out_w = target
        elif height > width:
            # Portrait: width is the shorter edge and becomes `target`.
            out_h, out_w = target * height / width, target
        else:
            # Landscape or square: height becomes `target`.
            out_h, out_w = target, target * width / height

        # The scaled edge may be fractional; truncate to whole pixels.
        return transform.resize(image, (int(out_h), int(out_w)))


class ToTensor(object):
    """Turn an H x W x C image array into a C x H x W torch tensor."""

    def __call__(self, image):
        # numpy images are channel-last (H x W x C) while torch expects
        # channel-first (C x H x W), so move the last axis to the front.
        channel_first = image.transpose((2, 0, 1))
        return torch.tensor(channel_first)


# Target (height, width) handed to Rescale; every image is resized to this.
# Alternative size that was tried:
# IMAGE_RESIZE = (256, 256)
IMAGE_RESIZE = (224, 224)
# Sequentially compose the transforms: resize first, then convert the
# H x W x C array into a C x H x W tensor.
img_transform = transforms.Compose([Rescale(IMAGE_RESIZE), ToTensor()])


### Captions Preprocessing

In [3]:
class CaptionsPreprocessing:
    """Load a captions TSV file, build the vocabulary and tokenize captions.

    Construction is eager: the file is read, every caption is wrapped in
    ``[START]`` / ``[END]`` markers and the vocabulary is built immediately.

    Args:
        captions_file_path (string): captions tsv file path. Each non-blank
            line holds an image id, a tab, and the caption text.
    """
    def __init__(self, captions_file_path):
        self.captions_file_path = captions_file_path

        # Longest caption length in tokens, including [START] and [END].
        # Updated as a side effect of read_raw_captions().
        self.maxLen = 0

        # Raw caption string keyed by image id (string)
        self.raw_captions_dict = self.read_raw_captions()

        # Captions wrapped in [START] ... [END]
        self.captions_dict = self.process_captions()

        # Word -> integer index
        self.vocab = self.generate_vocabulary()


    def read_raw_captions(self):
        """Read the TSV file and record the longest caption length.

        Blank lines are skipped. Only the first two tab-separated fields of a
        line are used. A later line with the same image id overwrites the
        earlier caption.

        Returns:
            Dictionary with the raw caption string keyed by image id (string)

        Raises:
            IndexError: if a non-blank line has no tab-separated caption.
        """

        captions_dict = {}
        with open(self.captions_file_path, 'r', encoding='utf-8') as f:
            for img_caption_line in f:
                img_caption_line = img_caption_line.strip()
                if not img_caption_line:
                    # A blank line (typically a trailing one) has no caption
                    # and would otherwise crash the indexing below.
                    continue
                img_captions = img_caption_line.split('\t')
                captions_dict[img_captions[0]] = img_captions[1]
                # +2 accounts for the [START] and [END] tokens added later.
                self.maxLen = max(self.maxLen, len(img_captions[1].split()) + 2)

        return captions_dict

    def process_captions(self):
        """Wrap every raw caption in ``[START]`` and ``[END]`` markers.

        Returns:
            Dictionary with the wrapped caption string keyed by image id.
        """
        return {
            img_id: '[START] ' + caption + ' [END]'
            for img_id, caption in self.raw_captions_dict.items()
        }

    def generate_vocabulary(self):
        """Build the word -> index mapping from the processed captions.

        ``[PAD]`` is index 0, words get increasing indices in order of first
        appearance, and ``[START]`` / ``[END]`` take the last two indices.

        Returns:
            Dictionary mapping each token (string) to its integer index.
        """
        vocab = {'[PAD]': 0}
        idx = 1
        for caption in self.captions_dict.values():
            for word in caption.split():
                # The markers are appended after all regular words.
                if word in ('[START]', '[END]'):
                    continue
                if word not in vocab:
                    vocab[word] = idx
                    idx += 1
        vocab['[START]'] = idx
        vocab['[END]'] = idx + 1

        return vocab

    def captions_transform(self, img_caption):
        """Convert a processed caption into a padded tensor of token indices.

        Args:
            img_caption: caption for a particular image (whitespace-separated
                tokens, all of which must be in the vocabulary)

        Returns:
            Tuple ``(tokens, length)``: a 1-D tensor right-padded with 0
            (``[PAD]``) up to ``maxLen``, and the unpadded token count.

        Raises:
            KeyError: if the caption contains a word missing from the vocabulary.
        """
        vocab = self.vocab

        tokens = [vocab[word] for word in img_caption.split()]
        length = len(tokens)
        # A non-positive multiplier yields no padding, as the original range did.
        tokens.extend([0] * (self.maxLen - length))
        return torch.tensor(tokens), length

# Set the captions tsv file path
CAPTIONS_FILE_PATH = 'Train_text.tsv'
# Constructing the object reads the file and builds the vocabulary right away.
captions_preprocessing_obj = CaptionsPreprocessing(CAPTIONS_FILE_PATH)

### Dataset Class

In [4]:
class ImageCaptionsDataset(Dataset):
    """Dataset yielding an image, three rotated copies of it, and its caption."""

    def __init__(self, img_dir, captions_dict, img_transform=None, captions_transform=None):
        """
        Args:
            img_dir (string): Directory with all the images. Stored but not
                used when loading: images are opened through the keys of
                `captions_dict`, which must therefore be readable paths.
            captions_dict: Dictionary with the caption keyed by image path (string)
            img_transform (callable, optional): Optional transform to be applied
                on the image sample and on each rotated copy.

            captions_transform: (callable, optional): Optional transform to be applied
                on the caption; must return a `(caption, length)` pair.
        """
        self.img_dir = img_dir
        self.captions_dict = captions_dict
        self.img_transform = img_transform
        self.captions_transform = captions_transform

        self.image_ids = list(captions_dict.keys())

    def __len__(self):
        """Number of captioned images."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load sample `idx`.

        Returns:
            Dict with keys 'image', 'image_90', 'image_180', 'image_270'
            (the image and its rotations by 90/180/270 degrees), 'captions'
            and 'lengths' (number of caption tokens).
        """
        img_name = self.image_ids[idx]
        image = io.imread(img_name)
        # Rotated copies serve as data augmentation; they share the caption.
        image_90 = transform.rotate(image, 90)
        image_180 = transform.rotate(image, 180)
        image_270 = transform.rotate(image, 270)
        captions = self.captions_dict[img_name]

        if self.img_transform:
            image = self.img_transform(image)
            image_90 = self.img_transform(image_90)
            image_180 = self.img_transform(image_180)
            image_270 = self.img_transform(image_270)

        if self.captions_transform:
            captions, length = self.captions_transform(captions)
        else:
            # Without a transform there is no tokenizer-supplied length; use
            # the whitespace token count so 'lengths' is always defined
            # (previously this path raised NameError).
            length = len(captions.split())

        sample = {'image': image, 'image_90': image_90, 'image_180': image_180, 'image_270': image_270, 'captions': captions, 'lengths': length}

        return sample

### Model Architecture

In [None]:
class EncoderCNN(nn.Module):
    """Hand-written convolutional image encoder.

    The layer stack is a 7x7 stem followed by four stages of 1x1 -> 3x3 -> 1x1
    bottleneck groups. The block counts (3, 4, 23, 3) and channel widths
    mirror the bottleneck stages of ResNet-101, but there are no residual
    connections, no normalisation layers and no strided convolutions; the
    only spatial reduction is the stem's max-pool.

    Layers are built in a fixed order inside one `nn.Sequential`, so the
    positional state-dict keys (`cnn.<index>.weight`) are stable.
    """

    # (number of bottleneck groups, bottleneck width, output channels) per stage
    _STAGES = (
        (3, 64, 256),      # batch 1
        (4, 128, 512),     # batch 2
        (23, 256, 1024),   # batch 3
        (3, 512, 2048),    # batch 4
    )

    def __init__(self):
        # The original called super(Encoder, self), which names the wrong
        # class and fails at construction time.
        super(EncoderCNN, self).__init__()

        # Stem
        layers = [
            nn.Conv2d(3, 64, kernel_size=7, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]

        # Bottleneck stages: the first group of a stage takes the previous
        # stage's output channels, the following ones take the stage's own.
        in_channels = 64
        for repeats, width, out_channels in self._STAGES:
            for _ in range(repeats):
                layers.extend(self._bottleneck(in_channels, width, out_channels))
                in_channels = out_channels

        self.cnn = nn.Sequential(*layers)

    @staticmethod
    def _bottleneck(in_channels, width, out_channels):
        """Return the six layers of one 1x1 -> 3x3 -> 1x1 group, each conv followed by ReLU."""
        return [
            nn.Conv2d(in_channels, width, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, out_channels, kernel_size=1),
            nn.ReLU(inplace=True),
        ]

    def forward(self, image_batch):
        """Encode a batch of images.

        Args:
            image_batch: tensor of shape (batch, 3, H, W).

        Returns:
            Feature map moved to channel-last layout: (batch, H', W', 2048).
        """
        # Forward propagation
        encoded_output = self.cnn(image_batch)
        # (batch, C, H', W') -> (batch, H', W', C)
        encoded_output = encoded_output.permute(0, 2, 3, 1)
        return encoded_output


In [5]:
class Encoder(nn.Module):
    """Image encoder built on a pretrained Inception v3 backbone.

    The backbone's final fully connected layer is replaced by a new linear
    layer producing `embed_size` features. That layer is always trainable;
    the rest of the backbone is trainable only when `train_CNN` is true.

    Args:
        embed_size (int): size of the output feature vector.
        train_CNN (bool): whether the pretrained backbone weights are fine-tuned.
    """

    def __init__(self, embed_size, train_CNN=False):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.train_CNN = train_CNN
        self.model = models.inception_v3(pretrained=True, aux_logits=False)
        self.model.fc = nn.Linear(self.model.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5, inplace=True)
        # Freeze now, so the very first training step already respects
        # `train_CNN` and optimizers built from requires_grad see the right set.
        self._apply_trainable_flags()

    def _apply_trainable_flags(self):
        """Mark the new fc head as trainable and the backbone according to `train_CNN`."""
        for name, param in self.model.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = self.train_CNN

    def forward(self, image):
        """Return dropout(relu(backbone(image))).

        Args:
            image: batch of images accepted by the Inception v3 backbone.
        """
        # Re-apply before running the backbone so a later change of
        # `self.train_CNN` takes effect. The original set the flags only
        # after the forward pass, so the first backward pass still produced
        # gradients for the supposedly frozen backbone.
        self._apply_trainable_flags()
        features = self.model(image)
        return self.dropout(self.relu(features))

In [6]:
class Decoder(nn.Module):
    """LSTM caption decoder that is fed the image features as its first input step.

    Args:
        embed_size (int): size of word embeddings (and of the image features).
        hidden_size (int): LSTM hidden state size.
        vocab_size (int): number of tokens in the vocabulary.
        num_layers (int): number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(p=0.5, inplace=True)

    def forward(self, features, captions, lengths):
        """Score the vocabulary at every non-padded time step.

        Args:
            features: (batch, embed_size) image features.
            captions: (batch, seq_len) token indices, rows ordered by
                decreasing length as `pack_padded_sequence` requires.
            lengths: number of time steps to keep for each row.

        Returns:
            (sum(lengths), vocab_size) tensor of scores in packed order.
        """
        # The encoded image becomes time step 0, ahead of the embedded caption tokens.
        image_step = features.unsqueeze(1)
        word_steps = self.embed(captions)
        lstm_input = torch.cat((image_step, word_steps), 1)

        # Packing drops the padded positions before they reach the LSTM.
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            lstm_input, lengths, batch_first=True
        )
        packed_output, _ = self.lstm(packed_input)

        # `.data` is the flat (sum(lengths), hidden_size) tensor of the packed sequence.
        scores = self.linear(packed_output.data)
        return self.dropout(scores)


In [69]:
class ImageCaptionsNet(nn.Module):
    """Encoder/decoder captioning network with greedy and beam-search inference.

    Args:
        embed_size (int): size of image features and word embeddings.
        hidden_size (int): LSTM hidden state size.
        vocab_size (int): number of tokens in the vocabulary.
        num_layers (int): number of stacked LSTM layers.
        train_CNN (bool): whether the encoder backbone is fine-tuned.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, train_CNN=False):
        super(ImageCaptionsNet, self).__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.encoder = Encoder(embed_size, train_CNN)
        self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers)


    def forward(self, image, captions, lengths):
        """Compute vocabulary scores for every non-padded caption position.

        The batch is ordered by decreasing caption length before decoding.
        Outputs follow that order, so callers that build targets themselves
        should pass an already sorted batch; such a batch is left untouched.

        Args:
            image: batch of images for the encoder.
            captions: (batch, seq_len) padded token indices.
            lengths: 1-D tensor with the unpadded length of each caption.

        Returns:
            Decoder scores in packed order, shape (sum(lengths), vocab_size).
        """
        features = self.encoder(image)
        sorted_lengths, seq = lengths.sort(descending=True)
        # Only permute when the batch is not already sorted: sort() does not
        # promise to keep ties in place, and the caller's targets rely on an
        # already-sorted batch staying in its order.
        if not torch.equal(sorted_lengths, lengths):
            # Features and captions must be permuted together. The original
            # permuted the (unused) images instead of the features, pairing
            # captions with the wrong image.
            features = features[seq]
            captions = captions[seq]
        outputs = self.decoder(features, captions, sorted_lengths)
        return outputs

    def caption_image_greedy(self, image, vocabulary, max_length=50):
        """Generate a caption by taking the highest scoring token at each step.

        Args:
            image: a batch containing a single image.
            vocabulary: mapping from token index to word.
            max_length (int): maximum number of generated tokens.

        Returns:
            List of words, ending with '[END]' if it was produced in time.
        """
        result_caption = []
        with torch.no_grad():
            # The image features are the first LSTM input.
            x = self.encoder(image).unsqueeze(1)
            states = None
            for _ in range(max_length):
                hiddens, states = self.decoder.lstm(x, states)
                outputs = self.decoder.linear(hiddens.squeeze(1))
                predicted = outputs.argmax(1)
                result_caption.append(predicted.item())
                # Feed the prediction back in as the next input.
                x = self.decoder.embed(predicted).unsqueeze(1)
                if vocabulary[predicted.item()] == '[END]':
                    break
        return [vocabulary[idx] for idx in result_caption]

    def caption_image_beam_search(self, image, vocabulary, beam_size=10, max_length=50):
        """Generate a caption with beam search over summed log-probabilities.

        Args:
            image: a batch containing a single image.
            vocabulary: mapping from token index to word.
            beam_size (int): number of unfinished hypotheses kept per step.
            max_length (int): maximum number of expansion steps.

        Returns:
            Words of the finished hypothesis with the highest total
            log-probability, or ['[START]', '[END]'] if none finished.
        """
        with torch.no_grad():
            # First step: the image features are the LSTM input.
            x = self.encoder(image).unsqueeze(1)
            hiddens, states = self.decoder.lstm(x, None)
            outputs = self.decoder.linear(hiddens.squeeze(1))
            prob_outputs = F.log_softmax(outputs[0], dim=0)
            _, idx = torch.topk(prob_outputs, beam_size)

            # Each beam entry: (token list, total log-probability, LSTM state)
            prev_beam = [([i], prob_outputs[i], states) for i in idx]
            resulting_captions = []

            for _ in range(max_length):
                # Expand every live hypothesis by its beam_size best next tokens.
                next_beam = []
                for word_list, prob, hidden_state in prev_beam:
                    last_word = self.decoder.embed(word_list[-1]).unsqueeze(0).unsqueeze(0)
                    outs, hidden_state = self.decoder.lstm(last_word, hidden_state)
                    prob_outputs = F.log_softmax(self.decoder.linear(outs.squeeze(1))[0], dim=0)
                    _, idx = torch.topk(prob_outputs, beam_size)
                    for i in idx:
                        next_beam.append((word_list + [i], prob + prob_outputs[i], hidden_state))

                # select top beam_size from beam_size * beam_size entries;
                # hypotheses ending in [END] are set aside as finished and do
                # not count towards the beam.
                next_beam.sort(reverse=True, key=lambda x: x[1])
                prev_beam = []
                counter = 0
                for word_list, prob, hidden_state in next_beam:
                    if vocabulary[word_list[-1].item()] == '[END]':
                        resulting_captions.append((word_list, prob))
                    else:
                        prev_beam.append((word_list, prob, hidden_state))
                        counter += 1
                    if counter == beam_size:
                        break
                if not prev_beam:
                    break

            resulting_captions.sort(reverse=True, key=lambda x: x[1])
            if not resulting_captions:
                return ['[START]', '[END]']
            caption = resulting_captions[0][0]
            return [vocabulary[idx.item()] for idx in caption]

In [None]:
IMAGE_DIR = 'train_data'

# Creating the Dataset
train_dataset = ImageCaptionsDataset(
    IMAGE_DIR, captions_preprocessing_obj.captions_dict, img_transform=img_transform,
    captions_transform=captions_preprocessing_obj.captions_transform
)

# Define your hyperparameters
NUMBER_OF_EPOCHS = 1
LEARNING_RATE = 1e-3
BATCH_SIZE = 8
NUM_WORKERS = 2 # Parallel threads for dataloading

# define model parameters
embed_size = 512
hidden_size = 512
vocab_size = len(captions_preprocessing_obj.vocab.keys())
num_layers = 1

net = ImageCaptionsNet(embed_size, hidden_size, vocab_size, num_layers, False)
net = net.cuda()

loss_function = nn.CrossEntropyLoss()
# optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=0.9, nesterov=True)
# Pass the learning rate explicitly so editing LEARNING_RATE actually has an
# effect (1e-3 is also Adam's default, so the current behaviour is unchanged).
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)


# Creating the DataLoader for batching purposes
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
net.train()
step = 0
# The original image and its three rotations are trained on one after the
# other, each with its own optimizer step and the same caption targets.
IMAGE_KEYS = ('image', 'image_90', 'image_180', 'image_270')
for epoch in range(NUMBER_OF_EPOCHS):
    for batch_idx, sample in enumerate(train_loader):
        step += 1
        captions_batch, lengths_batch = sample['captions'], sample['lengths']
        captions_batch = captions_batch.cuda()

        # pack_padded_sequence needs the batch ordered by decreasing length;
        # the same permutation is applied to captions and to every image batch.
        lengths_batch, sequence = lengths_batch.sort(descending=True)
        captions_batch = captions_batch[sequence]
        # Packed targets contain only the non-padded tokens, in packed order.
        target_labels = torch.nn.utils.rnn.pack_padded_sequence(captions_batch, lengths_batch, batch_first=True)[0]

        for image_key in IMAGE_KEYS:
            image_batch = sample[image_key].float().cuda()[sequence]

            output_captions = net(image_batch, captions_batch, lengths_batch)
            loss = loss_function(output_captions, target_labels)
            # clear gradients left over from the previous step
            net.zero_grad()
            loss.backward()
            optimizer.step()

        # `loss` here is the loss of the last variant (270 degree rotation).
        if step % 10 == 0:
            print("Iteration: " + str(step) + ", Loss: " + str(loss.item()))
    print("Iteration: " + str(epoch + 1))
torch.save(net.state_dict(), "inception.pth")

In [70]:
IMAGE_DIR = 'train_data'

# Dataset over the training images and their preprocessed captions
train_dataset = ImageCaptionsDataset(
    img_dir=IMAGE_DIR,
    captions_dict=captions_preprocessing_obj.captions_dict,
    img_transform=img_transform,
    captions_transform=captions_preprocessing_obj.captions_transform,
)

# Hyperparameters (same values as in the training cell)
NUMBER_OF_EPOCHS = 1
LEARNING_RATE = 1e-3
BATCH_SIZE = 8
NUM_WORKERS = 2  # Parallel workers for dataloading

# Model dimensions; these must match the checkpoint loaded below
embed_size = 512
hidden_size = 512
vocab_size = len(captions_preprocessing_obj.vocab)
num_layers = 1

# DataLoader for batching purposes
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Rebuild the network and restore the trained weights onto the CPU
net = ImageCaptionsNet(
    embed_size=embed_size,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    num_layers=num_layers,
    train_CNN=False,
)
net.load_state_dict(torch.load("inception.pth", map_location=torch.device('cpu')))
# Switch to inference mode (disables dropout)
net.eval()

ImageCaptionsNet(
  (encoder): Encoder(
    (model): Inception3(
      (Conv2d_1a_3x3): BasicConv2d(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Conv2d_2a_3x3): BasicConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Conv2d_2b_3x3): BasicConv2d(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (Conv2d_3b_1x1): BasicConv2d(
        (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=Tru

In [71]:
# Inverse vocabulary (token index -> word), used to turn predictions back into text
idx2word = {index: word for word, index in captions_preprocessing_obj.vocab.items()}

In [82]:
# iter = 0
# for idx, sample in enumerate(train_loader):
#     iter += 1
#     image = sample['image'][0]
#     image = image.float()
#     image = image.unsqueeze(0)
#     caption = sample['captions'][0].tolist()
#     print([idx2word[idx] for idx in caption])
#     features = net.encoder(image)
#     l = net.caption_image_greedy(image, idx2word)
#     print(l)
#     if iter == 10:
#         break

# Caption a single training image with the restored network.
image = io.imread('train_data/res69.jpg')
# Resize and convert to a C x H x W tensor, then cast to float32.
image = img_transform(image).float()
# Add a leading batch dimension of size 1.
image = image.unsqueeze(0)
# NOTE(review): `features` is not passed to the beam search below, which
# encodes the image itself; it looks like a debugging leftover.
features = net.encoder(image)
features = features.squeeze(0)
# Third positional argument is beam_size; 1 keeps a single hypothesis per step.
l = net.caption_image_beam_search(image, idx2word, 1)
print(l)


['[START]', 'a', 'man', 'in', 'a', 'red', 'jacket', '[END]']


In [None]:
!pwd