<a href="https://colab.research.google.com/github/Derrc/Image-Caption-Generation/blob/main/image_captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!unzip /content/gdrive/MyDrive/ML/Flicker8k/Images.zip -d /content/gdrive/MyDrive/ML/Flicker8k/

unzip:  cannot find or open /content/gdrive/MyDrive/ML/Flicker8k/Images.zip, /content/gdrive/MyDrive/ML/Flicker8k/Images.zip.zip or /content/gdrive/MyDrive/ML/Flicker8k/Images.zip.ZIP.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
from torchvision.models import resnet34
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pack_padded_sequence
from collections import defaultdict
import os
from PIL import Image
import matplotlib.pyplot as plt

from google.colab import drive
# Mount Google Drive so the dataset can be read and checkpoints can be saved.
drive.mount('/content/gdrive')
# Root Drive folder; the dataset and checkpoint paths below are built from it.
CDIR = '/content/gdrive/MyDrive/ML'

Mounted at /content/gdrive


# **Data Processing**

In [2]:
# read from caption and image files
IMAGE_PATH = CDIR + '/Flickr8k/Images/'
LABEL_PATH = CDIR + '/Flickr8k/captions.txt'
START_TOKEN, END_TOKEN, PAD_TOKEN = '<start>', '<end>', '<pad>'
with open(LABEL_PATH) as f:
    lines = f.read().split('\n')
    # first and last lines are not images
    # each remaining line is split once on the first comma: [image_file, caption]
    lines = [line.split(',', maxsplit=1) for line in lines][1:-1]

# dictionary of image_file -> captions (each wrapped in start/end tokens)
captions = defaultdict(list)
for image, text in lines:
    captions[image].append(f'{START_TOKEN} {text} {END_TOKEN}')

# parse caption vocabulary for tokens; captions are split on single spaces,
# the same way get_matrix_and_lengths splits them when encoding
vocab = {PAD_TOKEN}
for image_captions in captions.values():
    for sentence in image_captions:
        vocab.update(sentence.split(' '))

# total number of tokens (vocab size)
num_tokens = len(vocab)
# Sort so every token gets the same id in every kernel session. Iteration order
# of a set of strings varies between Python processes, which would silently
# scramble the token <-> row mapping of the embedding / output layers whenever
# a saved decoder checkpoint is reloaded after a restart.
tokens = sorted(vocab)
tokens_to_id = {token: i for i, token in enumerate(tokens)}
start_id = tokens_to_id[START_TOKEN]
end_id = tokens_to_id[END_TOKEN]

print(f'Vocabulary Size: {num_tokens}')

Vocabulary Size: 9865


In [3]:
class Flickr8k(Dataset):
    """Dataset of ``(image, captions)`` pairs built from an image_file -> captions mapping.

    Args:
        captions: mapping from image file name to the list of caption strings
            for that image. Keys and values are snapshotted at construction.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, captions, transform=None):
        self.data = captions
        self.images = list(self.data.keys())
        self.annotations = list(self.data.values())
        self.transform = transform

    def __len__(self):
        # Use the snapshot that __getitem__ indexes into, so the reported length
        # stays valid even if the source mapping is modified afterwards.
        return len(self.images)

    def __getitem__(self, index):
        """Return the (optionally transformed) image at ``index`` and its captions."""
        image_file = self.images[index]
        annotations = self.annotations[index]
        # Close the file handle promptly and force three channels so a
        # grayscale/RGBA/palette file cannot break a 3-channel transform.
        with Image.open(IMAGE_PATH + image_file) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, annotations

In [4]:
# Resize images for uniformity in training
TARGET_SIZE = (256, 256)
BATCH_SIZE = 32
# Fixed seed for the train/test partition. Training is resumed from checkpoints
# across sessions, so an unseeded split would move earlier test images into the
# training set (and vice versa) after every kernel restart.
SPLIT_SEED = 0

transform = transforms.Compose([
    transforms.Resize(TARGET_SIZE),
    transforms.ToTensor(),
    # the same mean/std must be undone when displaying images
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = Flickr8k(captions, transform=transform)
# 90% train / 10% test, reproducible through the seeded generator
train_test_split = random_split(
    dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
traindata, testdata = train_test_split[0], train_test_split[1]
trainloader = DataLoader(traindata, batch_size=BATCH_SIZE, shuffle=True)
# batch_size=1 and shuffle=True: each next(iter(testloader)) yields one random test image
testloader = DataLoader(testdata, batch_size=1, shuffle=True)


# **Decoder and Encoder Networks**

In [5]:
# ResNet-34 CNN Encoder
class Encoder(nn.Module):
    """CNN encoder built on a pretrained ResNet-34 with its last two child modules removed.

    Args:
        output_dim: height and width of the pooled feature map.

    ``forward`` returns features laid out as
    ``[b_size, output_dim, output_dim, channels]``.
    """

    def __init__(self, output_dim=14):
        super().__init__()
        backbone = resnet34(pretrained=True)
        # keep every child module of the ResNet except the last two
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])
        # adaptive pooling fixes the spatial size regardless of the input image size
        self.resize = nn.AdaptiveAvgPool2d((output_dim, output_dim))

        self.fine_tune()

    def forward(self, x):
        pooled = self.resize(self.resnet(x))
        # move the channel axis last: [b, C, H, W] -> [b, H, W, C]
        return pooled.permute(0, 2, 3, 1)

    def fine_tune(self):
        """Freeze the first five child modules of the backbone.

        The original note describes this as disabling learning
        "up to first three res blocks".
        """
        frozen_modules = list(self.resnet.children())[:5]
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False

# Soft-Attention Network
class Attention(nn.Module):
    """Additive soft attention over encoder feature locations.

    Args:
        encoder_dim: size of each encoder feature vector.
        decoder_dim: size of the decoder hidden state.
        attention_dim: size of the shared projection space.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # projects features: [b_size, image_size, encoder_dim] -> attention_dim
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        # projects hidden state: [b_size, decoder_dim] -> attention_dim
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # collapses each location's attention vector to one score
        self.att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, features, hidden):
        """Attend over ``features`` given the decoder state ``hidden``.

        Returns:
            tuple ``(context, alpha)``: the attention-weighted sum of the
            features, ``[b_size, encoder_dim]``, and the attention weights,
            ``[b_size, image_size]``, which sum to 1 over the locations.
        """
        projected_features = self.encoder_att(features)
        # add a location axis so the hidden projection broadcasts over all locations
        projected_hidden = self.decoder_att(hidden).unsqueeze(1)
        scores = self.att(self.relu(projected_features + projected_hidden)).squeeze(2)
        # [b_size, image_size]
        alpha = self.softmax(scores)
        # weight every location's feature vector and sum over locations
        context = (features * alpha.unsqueeze(2)).sum(dim=1)

        return context, alpha


class Decoder(nn.Module):
    """LSTM-cell caption decoder with soft attention over encoder features.

    Args:
        decoder_dim: size of the LSTM hidden/cell state.
        attention_dim: size of the attention projection space.
        num_tokens: vocabulary size (embedding rows and output logits).
        embed_size: size of the token embeddings.
        device: device recorded on the module as ``self.device``.
        encoder_dim: size of each encoder feature vector.
    """

    def __init__(self, decoder_dim, attention_dim, num_tokens, embed_size, device, encoder_dim=512):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.attention_dim = attention_dim
        self.num_tokens = num_tokens
        self.device = device

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        # map the mean encoder feature to the initial hidden and cell states
        self.init_h0 = nn.Linear(encoder_dim, decoder_dim)
        self.init_c0 = nn.Linear(encoder_dim, decoder_dim)

        self.embedding = nn.Embedding(num_tokens, embed_size)
        # LSTM input is [token embedding ; gated attention context]
        self.lstm = nn.LSTMCell(embed_size + encoder_dim, decoder_dim)
        self.dropout = nn.Dropout(p=0.4)
        # produces the sigmoid gate applied to the attention context
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, num_tokens)

    def initialize(self, features):
        """Return the initial ``(h0, c0)`` computed from the mean feature over locations."""
        # [b_size, image_size, encoder_dim] -> [b_size, encoder_dim]
        features = features.mean(dim=1)
        h0 = self.init_h0(features)
        c0 = self.init_c0(features)

        return h0, c0

    # captions: [b_size, max_length]
    def forward(self, features, captions, caption_lengths):
        """Teacher-forced decoding of a batch of padded captions.

        Args:
            features: encoder output; reshaped to ``[b_size, image_size, encoder_dim]``.
            captions: token ids, ``[b_size, max_length]``.
            caption_lengths: 1-D tensor with the unpadded length of each caption.

        Returns:
            tuple ``(logits, alphas, captions, decode_lengths)`` where the batch
            has been re-ordered by descending caption length. ``logits`` is
            ``[b_size, max(decode_lengths), num_tokens]`` and ``alphas`` is
            ``[b_size, max(decode_lengths), image_size]``; positions past a
            caption's decode length stay zero. ``decode_lengths`` is a list of
            ``caption_length - 1`` values.
        """
        batch_size = features.shape[0]
        # allocate outputs where the inputs live, instead of on CPU followed by a copy
        device = features.device

        # [b_size, image_size, encoder_dim]
        features = features.reshape(batch_size, -1, self.encoder_dim)

        # sort captions and features in descending order by caption length so
        # that at step t the still-active captions are the first batch_t rows
        caption_lengths, sort_indices = caption_lengths.sort(descending=True)
        captions = captions[sort_indices]
        features = features[sort_indices]

        h, c = self.initialize(features)
        # [b_size, max_length, embed_size]
        embedding = self.embedding(captions)

        # exclude end token (it is never fed as an input)
        decode_lengths = (caption_lengths - 1).tolist()
        max_length = max(decode_lengths)
        # storage tensors
        logits = torch.zeros(batch_size, max_length, self.num_tokens, device=device)
        alphas = torch.zeros(batch_size, max_length, features.shape[1], device=device)

        for t in range(max_length):
            # number of captions that still have a token to predict at step t
            batch_t = sum(length > t for length in decode_lengths)
            # [batch_t, encoder_dim]
            features_weighted, alpha = self.attention(features[:batch_t], h[:batch_t])
            # pass weighted features through gate (from paper)
            gate = self.sigmoid(self.f_beta(h[:batch_t]))
            features_weighted = features_weighted * gate

            # cat: [batch_t, embed_size], [batch_t, encoder_dim]
            lstm_input = torch.cat((embedding[:batch_t, t, :], features_weighted), dim=1)
            h, c = self.lstm(lstm_input, (h[:batch_t], c[:batch_t]))

            logit = self.fc(self.dropout(h))
            logits[:batch_t, t, :] = logit
            alphas[:batch_t, t, :] = alpha

        return logits, alphas, captions, decode_lengths

    # naive greedy caption generation (BEAM search implemented below)
    @torch.no_grad()  # inference only: do not record an autograd graph
    def generate_caption(self, features, seed_phrase='<start>', max_length=25):
        """Greedily decode a caption for a single image.

        Args:
            features: encoder output for one image; reshaped to
                ``[1, image_size, encoder_dim]``.
            seed_phrase: token whose id (looked up in ``tokens_to_id``) starts decoding.
            max_length: maximum number of decoding steps.

        Returns:
            The generated tokens joined by spaces. The seed token and the end
            token are not included.
        """
        device = features.device
        features = features.reshape(1, -1, self.encoder_dim)
        embedding = self.embedding(torch.tensor([tokens_to_id[seed_phrase]], device=device))
        h, c = self.initialize(features)

        output = []
        for _ in range(max_length):
            features_weighted, alpha = self.attention(features, h)
            gate = self.sigmoid(self.f_beta(h))
            features_weighted = features_weighted * gate

            lstm_input = torch.cat((embedding, features_weighted), dim=1)
            h, c = self.lstm(lstm_input, (h, c))

            logits = self.fc(h).squeeze(0)
            # greedy: softmax is monotonic, so the argmax of the logits is the
            # most probable token; convert to a plain int for indexing
            next_token = int(torch.argmax(logits))
            # break if end token reached
            if next_token == end_id:
                break

            output.append(tokens[next_token])
            embedding = self.embedding(torch.tensor([next_token], device=device))

        return ' '.join(output)

# **Training**

In [None]:
ENCODER_PATH = CDIR + '/Image-Captioning/encoder.pth'
DECODER_PATH = CDIR + '/Image-Captioning/decoder.pth'
# make sure the checkpoint folder exists so torch.save does not fail on first use
os.makedirs(os.path.dirname(ENCODER_PATH), exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# weight of the attention regularization term added to the loss
ALPHA_COEF = 1

encoder = Encoder().to(DEVICE)
# Decoder(decoder_dim=256, attention_dim=256, num_tokens, embed_size=300, device)
decoder = Decoder(256, 256, num_tokens, 300, DEVICE).to(DEVICE)
encoder_optim = torch.optim.Adam(encoder.parameters(), lr=1e-4)
decoder_optim = torch.optim.Adam(decoder.parameters(), lr=1e-3)

criterion = nn.CrossEntropyLoss().to(DEVICE)

# Resume from existing checkpoints. map_location lets weights saved on a GPU
# runtime load on a CPU-only runtime (and vice versa). Only model weights are
# stored, so the optimizers start from a fresh state.
if os.path.exists(ENCODER_PATH):
    encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=DEVICE))
if os.path.exists(DECODER_PATH):
    decoder.load_state_dict(torch.load(DECODER_PATH, map_location=DEVICE))


In [8]:
# length of the longest sentence in captions (for padding purposes)
def max_caption_length(captions):
    """Return the largest number of space-separated tokens in any sentence.

    ``captions`` is a nested iterable: an iterable of groups, each an iterable
    of sentence strings. Sentences are split on single spaces.
    """
    return max(
        len(sentence.split(' '))
        for group in captions
        for sentence in group
    )
 
# integer token-id encoding of the captions, padded with the <pad> id to the longest sentence
def get_matrix_and_lengths(captions):
    """Encode nested caption strings as padded token-id arrays.

    Args:
        captions: sequence of groups of sentence strings; every group is
            expected to have as many sentences as ``captions[0]``.

    Returns:
        tuple ``(matrix, lengths)`` of int64 numpy arrays. ``matrix`` has shape
        ``[len(captions), len(captions[0]), max_length]`` (annotated as
        ``[5, b_size, max_length]`` in the original) and holds token ids,
        padded with the id of ``PAD_TOKEN``. ``lengths`` has shape
        ``[len(captions), len(captions[0])]`` and holds each sentence's token count.
    """
    width = max_caption_length(captions)
    n_groups, n_sentences = len(captions), len(captions[0])
    matrix = np.full((n_groups, n_sentences, width), tokens_to_id[PAD_TOKEN], dtype=np.int64)
    lengths = np.zeros((n_groups, n_sentences), dtype=np.int64)
    for i, group in enumerate(captions):
        for j, sentence in enumerate(group):
            words = sentence.split(' ')
            lengths[i, j] = len(words)
            # fill the leading positions; the remainder keeps the pad id
            matrix[i, j, :len(words)] = [tokens_to_id[word] for word in words]

    return matrix, lengths

def plot(total_loss):
    """Show a line chart of the recorded loss values, titled 'Loss'."""
    figure = plt.figure(figsize=(20, 5))
    # first cell of a 1x3 grid, as in plt.subplot(131)
    axes = figure.add_subplot(131)
    axes.plot(total_loss)
    axes.set_title('Loss')
    plt.show()

def imshow(image, caption, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized image tensor with ``caption`` as the plot title.

    Args:
        image: image tensor; a leading batch axis of size 1 is squeezed away and
            the result is permuted from ``[C, H, W]`` to ``[H, W, C]``.
        caption: text shown as the title.
        mean, std: per-channel statistics used by the normalization transform.
            The defaults match the ``transforms.Normalize`` call in this
            notebook. They are undone here so the picture shows real colours
            rather than values matplotlib has to clip.
    """
    plt.title(f'{caption}')
    pixels = image.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
    # invert Normalize: x_norm = (x - mean) / std  ->  x = x_norm * std + mean
    pixels = pixels * np.asarray(std) + np.asarray(mean)
    # keep values inside the [0, 1] range expected for float RGB data
    plt.imshow(np.clip(pixels, 0.0, 1.0))

def display(encoder, decoder):
    """Caption one image drawn from ``testloader`` and show it.

    Both models are switched to eval mode for the forward pass and restored to
    whatever mode they were in beforehand, even if captioning raises.
    """
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()
    try:
        # inference only: skip autograd bookkeeping
        with torch.no_grad():
            image = next(iter(testloader))[0].to(DEVICE)
            features = encoder(image)
            caption = decoder.generate_caption(features)

        imshow(image, caption)
    finally:
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)


def train(epochs):
    """Train the encoder and decoder jointly for ``epochs`` passes over ``trainloader``.

    Each image batch is used once per caption slot: the encoder and decoder are
    run and optimized separately for every slice of the encoded captions. Every
    ``iters`` batches the mean of the losses recorded since the previous report
    is printed, the loss curve and a sample caption are shown, and both models
    are checkpointed to ``ENCODER_PATH`` / ``DECODER_PATH``.
    """
    total_loss = []
    # steps before evaluating/plotting
    iters = 5
    # index in total_loss of the first loss not yet included in a printed mean
    window_start = 0
    encoder.train()
    decoder.train()
    for e in range(epochs):
        for i, data in enumerate(trainloader):
            images, captions = data
            images = images.to(DEVICE)
            matrix, lengths = get_matrix_and_lengths(captions)
            encoded_captions = torch.tensor(matrix, dtype=torch.int64).to(DEVICE)
            caption_lengths = torch.tensor(lengths, dtype=torch.int64).to(DEVICE)

            # mini-mini-batch of each caption
            for c in range(len(encoded_captions)):
                caption = encoded_captions[c]
                caption_length = caption_lengths[c]

                # re-encode for every caption slot: backward() frees the graph
                features = encoder(images)
                logits, alphas, sorted_caption, decode_lengths = decoder(features, caption, caption_length)

                # get next tokens and exclude padded timesteps
                next_tokens = sorted_caption[:, 1:]
                next_tokens = pack_padded_sequence(next_tokens, decode_lengths, batch_first=True)[0]
                logits = pack_padded_sequence(logits, decode_lengths, batch_first=True)[0]

                loss = criterion(logits, next_tokens)
                # alpha regularization as shown in paper
                loss += ALPHA_COEF * ((1 - alphas.sum(dim=1)).pow(2)).mean()

                encoder_optim.zero_grad()
                decoder_optim.zero_grad()

                loss.backward()

                encoder_optim.step()
                decoder_optim.step()

                # store a plain float so the history holds no tensors/arrays
                total_loss.append(loss.item())

            if (i+1) % iters == 0:
                # average every loss recorded since the last report; the
                # original indexed a single element (total_loss[-iters])
                recent_losses = total_loss[window_start:]
                window_start = len(total_loss)
                print(f'[{e+1}, {i+1}] Loss: {np.mean(recent_losses):.3f}')
                plot(total_loss)
                display(encoder, decoder)

                torch.save(encoder.state_dict(), ENCODER_PATH)
                torch.save(decoder.state_dict(), DECODER_PATH)


In [None]:
# number of passes over the training loader; train() also saves checkpoints as it runs
epochs = 20
train(epochs)

# **Beam Search**

In [13]:
# caption given image
@torch.no_grad()  # inference only: do not record an autograd graph
def beam_search(encoder, decoder, image, k=3, max_length=35):
    """Generate a caption for ``image`` with beam search of width ``k``.

    Args:
        encoder: module mapping the image batch to features whose last axis is
            the feature size.
        decoder: module exposing ``initialize``, ``embedding``, ``attention``,
            ``f_beta``, ``lstm`` and ``fc`` as used below.
        image: a batch holding a single image.
        k: beam width.
        max_length: maximum number of decoding steps.

    Returns:
        The tokens of the highest-scoring finished sequence joined by spaces.
        The string includes the ``<start>`` token and, for finished sequences,
        the ``<end>`` token. If no beam reaches ``<end>`` within ``max_length``
        steps, the best unfinished beam is returned instead.
    """
    features = encoder(image)
    device = features.device
    # [1, num_pixels, encoder_dim]
    features = features.reshape(1, -1, features.shape[3])
    # initialize with batch of k (considering top k candidates for first step)
    features = features.expand(k, features.shape[1], features.shape[2])

    # storage tensors: [k, 1]
    k_prev_words = torch.full((k, 1), start_id, dtype=torch.int64, device=device)
    k_seq = k_prev_words.clone()
    top_k_scores = torch.zeros(k, 1, device=device)

    # finished sequences and their summed log-probabilities
    seqs_done = []
    seqs_done_scores = []

    h, c = decoder.initialize(features)
    for step in range(max_length):
        # [k, embed_size]
        embedding = decoder.embedding(k_prev_words.squeeze(1))

        features_weighted, alpha = decoder.attention(features, h)
        gate = torch.sigmoid(decoder.f_beta(h))
        # [k, encoder_dim]
        features_weighted = features_weighted * gate

        lstm_input = torch.cat((embedding, features_weighted), dim=1)
        h, c = decoder.lstm(lstm_input, (h, c))
        # [k, vocab_size]
        logits = decoder.fc(h)
        scores = torch.log_softmax(logits, dim=1)
        vocab_size = scores.shape[1]

        # scores = sum of log probs of each candidate extension
        scores = top_k_scores.expand_as(scores) + scores

        if step == 0:
            # All k beams are identical copies of <start> here, so every row of
            # `scores` is the same. Taking the top k over the flattened matrix
            # would pick the single best token k times and collapse the search
            # to greedy decoding; expand only the first row instead.
            top_k_scores, top_k_tokens = scores[0].topk(k, 0, True, True)
        else:
            top_k_scores, top_k_tokens = scores.flatten().topk(k, 0, True, True)

        # flat index -> (beam it extends, token id); integer floor division
        # avoids a round trip through floating point
        prev_k_indices = torch.div(top_k_tokens, vocab_size, rounding_mode='floor')
        next_k_indices = top_k_tokens % vocab_size

        # add to previous sequences
        k_seq = torch.cat((k_seq[prev_k_indices], next_k_indices.unsqueeze(1)), dim=1)

        # split candidates into those that just emitted <end> and the rest
        finished = next_k_indices == end_id
        if finished.any():
            seqs_done.extend(k_seq[finished].tolist())
            seqs_done_scores.extend(top_k_scores[finished].tolist())
            k -= int(finished.sum())

        # break if all sequences are terminated
        if k == 0:
            break

        # update variables to continue from chosen prev sequences
        alive = ~finished
        prev_seq = prev_k_indices[alive]
        k_seq = k_seq[alive]
        h = h[prev_seq]
        c = c[prev_seq]
        features = features[prev_seq]
        top_k_scores = top_k_scores[alive].unsqueeze(1)
        k_prev_words = next_k_indices[alive].unsqueeze(1)

    if not seqs_done_scores:
        # no beam produced <end>: fall back to the unfinished beams rather than
        # calling argmax on an empty list
        seqs_done = k_seq.tolist()
        seqs_done_scores = top_k_scores.flatten().tolist()

    best_seq = seqs_done[int(np.argmax(seqs_done_scores))]

    return ' '.join(tokens[token_id] for token_id in best_seq)

# **Evaluate Captions From Test Dataset**

In [None]:
# switch both models to evaluation mode before captioning
encoder.eval()
decoder.eval()
# testloader uses batch_size=1 with shuffling, so this draws one random test image
image = next(iter(testloader))[0].to(DEVICE)
# beam_search runs the encoder itself, so no separate forward pass is needed;
# no_grad avoids building an autograd graph during inference
with torch.no_grad():
    caption = beam_search(encoder, decoder, image)

imshow(image, caption)