In [119]:
import os
import numpy as np
import h5py
import json
import torch
from tqdm import tqdm
from collections import Counter
import random as random
from PIL import Image 
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision as tv
from nltk.translate.bleu_score import corpus_bleu
from torch.nn.utils.rnn import pack_padded_sequence

In [120]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [121]:
karpathy_json_path = 'data/dataset_flickr8k.json'
MAX_LENGTH = 50

In [122]:
#get the training splits across all data set 
with open(karpathy_json_path, 'r') as j:
        flickr8k_datasplit = json.load(j)

In [123]:
for key in flickr8k_datasplit:
    print(key)

images
dataset


In [124]:
image_folder = 'data/Flicker8k_Dataset'

In [125]:
def get_train_val_test_splits(image_folder, image_splits, max_caption_len, min_word_freq):
    '''
        method will return three lists containing the training, validation, and testing 
        file paths and captions. Method will also return the word map that will be used for captioning
        images. 
        Parameters: 
            image_folder: location of the images in the data set, either MSCOCO, Flickr8K, Flickr30k 
            
            image_splits: dictionary that contains the captions for each image and indicates whether 
            image part of training, validation, or testing set
            
            max_caption_len: threshold for maximum caption length 
            
            min_word_freq: threshold that determines whether a word will be in word map or not. 
            
            
        Output:
            train_img_caps: list of tuples containing the training image file path and caption
            val_img_caps: list of tuples containing the validation image file path and caption
            test_img_caps: list of tuples containing the testing image file path and caption
        '''
    #storing tuple of path to img and the caption
    word_freq = Counter()
    train_img_caps = []
    val_img_caps = []
    test_img_caps = [] 
    num_train_img, num_val_img, num_test_img = 0, 0, 0 
    for img in image_splits['images']:
        img_captions = []
        for word in img['sentences']:
            #check if the caption length is not to long
            if len(word['tokens']) <= max_caption_len:
                img_captions.append(word['tokens'])
            # Update word frequency
            word_freq.update(word['tokens'])

        #if caption is of length zero move to next image 
        if not len(img_captions): 
            continue 

        img_file_path = os.path.join(image_folder, img['filename'])
        #save corresponding files and captions 
        if img['split'] == 'train':
            train_img_caps.append((img_file_path, img_captions))
            num_train_img+=1
        elif img['split'] == 'val':
            val_img_caps.append((img_file_path, img_captions))
            num_val_img+=1 
        elif img['split'] == 'test':
            test_img_caps.append((img_file_path, img_captions))
            num_test_img+=1
    
    #create a limited vocabulary and don't include any word that hasn't appeared 
    #min_word_freq times
    words = [w for w in word_freq if word_freq[w] >= min_word_freq]
    min_words = [w for w in word_freq if word_freq[w] < min_word_freq]
    word_map = {word: i+1 for i, word in enumerate(words)}
    #specify start token, end token, unknown token, and padding token 
    word_map['<START>'] = len(word_map) + 1 
    word_map['<END>'] = len(word_map) + 1
    word_map['<UNK>'] = len(word_map) + 1
    word_map['<PAD>'] = 0
    
    print("Number of training images: {0}".format(num_train_img))
    print("Number of validation images: {0}".format(num_val_img))
    print("Number of testing images: {0}".format(num_test_img))
    return train_img_caps, val_img_caps, test_img_caps, word_map

In [126]:
train_data, val_data, test_data, word_map = get_train_val_test_splits(image_folder, flickr8k_datasplit, 50, 5)

Number of training images: 6000
Number of validation images: 1000
Number of testing images: 1000


In [129]:
def create_dataset(data, split, word_map, base_file_name, captions_per_image):
    output_folder = 'data/'
    encoded_captions = []
    encoded_captions_length = []
    start_token = word_map['<START>']
    end_token = word_map['<END>']
    unknown_token = word_map['<UNK>']
    padding_token = word_map['<PAD>']
    training_data_file = os.path.join(output_folder, base_file_name + '_' + split + '_images.hdf5')
    encoded_captions_file = os.path.join(output_folder, base_file_name + '_' + split + '_encoded_captions.json')
    encoded_captions_length_file = os.path.join(output_folder, base_file_name + '_' + split + '_encoded_caption_lengths.json')
    
    print("Creating %s data set" % split)
    with h5py.File(os.path.join(output_folder, base_file_name + '_' + split + '_images' + '.hdf5'), 'a') as h:
        images = h.create_dataset('images', (len(data), 3, 256, 256), dtype='uint8')
        for image_idx ,(image_path, image_captions) in enumerate(data):
            
            #want to ensure that there are at least certain number of captions per image 
            #if current image has less than that threshold, then augement the captions
            num_captions = len(image_captions)
            if num_captions < captions_per_image: 
                chosen_captions = [random.choice(image_captions) for _ in range(captions_pe_image - num_captions)]
                chosen_captions += image_captions
            else:
                chosen_captions = random.sample(image_captions, k = captions_per_image)
            
            #for the chosen captions, encode them
            
            for i, caption in enumerate(chosen_captions):
#                 import pdb; pdb.set_trace()
                encoded_caption = [word_map.get(w,unknown_token) for w in caption]
                assert len(caption) == len(encoded_caption)
                padding_for_caption = [padding_token for _ in range(MAX_LENGTH- len(caption))]
                encoded_caption = [start_token] + encoded_caption + [end_token] + padding_for_caption
                
                encoded_captions.append(encoded_caption)
                

                assert len(encoded_caption) == MAX_LENGTH + 2 
                encoded_captions_length.append(len(caption) + 2)
            
            #resize all images to be 256 x 256 
            image = Image.open(image_path)
            image_resize = image.resize((256, 256))
            image_array = np.asarray(image_resize).transpose(2,0,1) #ensures that 3x256x256
            images[image_idx] = image_array
            
            
            assert len(image_array.shape) == 3
            
        h.attrs['cpi'] = captions_per_image 
        
        print("Saving the encoded captions")
        #save the encoded captions and the encoded caption lengths to a json file 
        with open(encoded_captions_file, 'w') as j:
            json.dump(encoded_captions, j)

        with open(encoded_captions_length_file, 'w') as j:
            json.dump(encoded_captions_length, j)
        
        print("Done creating the dataset for split ")

In [149]:
create_dataset(test_data, 'test', word_map, 'flickr8k', 5)

Creating test data set
Saving the encoded captions
Done creating the dataset for split 


In [145]:
class MyDataset(Dataset):
    def __init__(self, folder, name, split, transform=None):
        '''
            Create a data set class that will be used when passing into the data loader. 
        '''
        self.split = split
        
        self.file = h5py.File(os.path.join(folder, name + '_' + self.split + '_images.hdf5'))
        self.images = self.file['images']
        
        self.cpi = self.file.attrs['cpi']
        
        # load captions
        with open(os.path.join(folder, name + '_' + self.split + '_encoded_captions.json'), 'r') as f:
            self.captions = json.load(f)
            
        # load captions' lenghts
        with open(os.path.join(folder, name + '_' + self.split + '_encoded_caption_lengths.json'), 'r') as f:
            self.lengths = json.load(f)
        
                        
    def __getitem__(self, idx):
        image = torch.FloatTensor(self.images[idx // self.cpi] / 255.0)
        
        #TODO: not using standard formulation of mean=0, std=1
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        image = normalize(image)
        
        caption, caption_length = torch.LongTensor(self.captions[idx]), torch.LongTensor([self.lengths[idx]])
        
        if self.split == 'train':
            return image, caption, caption_length
        else:
            start = self.cpi * (idx // self.cpi)
            end = start + self.cpi
            
            captions = torch.LongTensor(self.captions[start:end])
            
            
            return image, caption, caption_length, captions
        
    def __len__(self):
        return len(self.captions)
        

In [150]:

train_set = MyDataset('data', 'flickr8k', 'train')
val_set = MyDataset('data', 'flickr8k', 'val')
test_set = MyDataset('data', 'flickr8k', 'test')

train_loader = torch.utils.data.DataLoader(train_set, batch_size=32,shuffle=True, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=True, pin_memory=True)

In [153]:
for i, (images, captions, lengths, ac) in enumerate(val_loader):
    print(images.shape, captions.shape, lengths.shape, ac.shape)

torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52])

torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52]) torch.Size([32, 1]) torch.Size([32, 5, 52])
torch.Size([32, 3, 256, 256]) torch.Size([32, 52])

In [154]:
class Encoder(nn.Module):
    def __init__(self, dim_size=14):
        super(Encoder, self).__init__()
        resnet = tv.models.resnet101(pretrained=True)

        modules = list(resnet.children())[:-2]

        self.resnet = nn.Sequential(*modules)

        self.pool = nn.AdaptiveAvgPool2d((dim_size, dim_size))
        
        #TODO: decided not to fine tune blocks 2-4
        
    def forward(self, images):
        return self.pool(self.resnet(images)).permute(0, 2, 3, 1)

    
class Attention(nn.Module):
    def __init__(self, dim_encoder, dim_decoder, dim_attention):
        super(Attention, self).__init__()
        
        self.attention_encoder = nn.Linear(dim_encoder, dim_attention)
        self.attention_decoder = nn.Linear(dim_decoder, dim_attention)
        self.both = nn.Linear(dim_attention, 1)
    
    def forward(self, out_encoder, hidden_decoder):
        attention_encoder = self.attention_encoder(out_encoder)
        attention_decoder = self.attention_decoder(hidden_decoder)
        
        weights = self.both(torch.relu(attention_encoder + attention_decoder.unsqueeze(1))).squeeze(2)
        weights = torch.softmax(weights, dim=1)
        
        out = torch.sum((out_encoder * weights.unsqueeze(2)), dim=1)
        
        return out, weights 

class Decoder(nn.Module):
    def __init__(self, dim_attention, dim_embed, dim_decoder, vocab_size, dim_encoder=2048):
        super(Decoder, self).__init__()
        
        self.dim_encoder = dim_encoder
        self.dim_attention = dim_attention
        self.dim_embed = dim_embed
        self.vocab_size = vocab_size
        
        self.attention = Attention(dim_encoder, dim_decoder, dim_attention)
        self.embed = nn.Embedding(vocab_size, dim_embed)
        self.drop = nn.Dropout(p=0.5)
        
        self.decode_lstm = nn.LSTMCell(dim_embed + dim_encoder, dim_decoder, bias=True)
        self.h_init = nn.Linear(dim_encoder, dim_decoder)
        self.c_init = nn.Linear(dim_encoder, dim_decoder)
        self.f = nn.Linear(dim_decoder, dim_encoder)
        
        self.fc1 = nn.Linear(dim_decoder, vocab_size)
        
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.fc1.bias.data.fill_(0)
        self.fc1.weight.data.uniform_(-0.1, 0.1)
        
    def init_hidden(self, out_encoder):
        out = out_encoder.mean(dim=1)
        
        return self.h_init(out), self.c_init(out)
    
    def forward(self, out_encoder, captions, lengths):
        batch_size = out_encoder.size(0)
        dim_encoder = out_encoder.size(-1)
        vocab_size = self.vocab_size
        
        out_encoder = out_encoder.view(batch_size, -1, dim_encoder)
        pixels = out_encoder.size(1)
        
        lengths, ind = lengths.squeeze(1).sort(dim=0, descending=True)
        out_encoder = out_encoder[ind]
        captions = captions[ind]
        
        embed = self.embed(captions)
        
        # init hidden state
        h, c = self.init_hidden(out_encoder)
        
        lengths = (lengths-1).tolist()
        
        predict = torch.zeros(batch_size, max(lengths), vocab_size).to(device)
        weights = torch.zeros(batch_size, max(lengths), pixels).to(device)
        
        for time_step in range(max(lengths)):
            batch_t = sum([i > time_step for i in lengths])
            
            weighted_encoder, alpha = self.attention(out_encoder[:batch_t], h[:batch_t])
            
            sig = torch.sigmoid(self.f(h[:batch_t]))
            weighted_encoder = sig * weighted_encoder
            
            h, c = self.decode_lstm(torch.cat([embed[:batch_t, time_step, :], weighted_encoder], dim=1),
                                    (h[:batch_t], c[:batch_t]))
            
            output = self.fc1(self.drop(h))
            predict[:batch_t, time_step, :] = output
            weights[:batch_t, time_step, :] = alpha
            
        return predict, captions, lengths, weights, ind

In [155]:
encoder = Encoder().to(device)
decoder = Decoder(dim_attention=512, dim_embed=512, dim_decoder=512, vocab_size=len(word_map)).to(device)

Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /tmp/xdg-cache/torch/checkpoints/resnet101-5d3b4d8f.pth
100%|██████████| 170M/170M [00:02<00:00, 66.1MB/s] 


In [156]:
encoder_optimizer = torch.optim.Adam(params=encoder.parameters(), lr=1e-4)
decoder_optimizer = torch.optim.Adam(params=decoder.parameters(), lr=4e-4)

criterion = nn.CrossEntropyLoss().to(device)

In [157]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    decoder.train()
    encoder.train()
    
    loss_sum = 0
    loss_num = 0
    
    for i, (images, captions, lengths) in enumerate(train_loader):
        images = images.to(device)
        captions = captions.to(device)
        lengths = lengths.to(device)
        
        images = encoder(images)
        predict, captions, lengths, weights, ind = decoder(images, captions, lengths)
        
        targets = captions[:,1:]
        
        predict,_ = pack_padded_sequence(predict, lengths, batch_first=True)
        targets,_ = pack_padded_sequence(targets, lengths, batch_first=True)
        
        loss = criterion(predict, targets)
        
        loss += ((1. - weights.sum(dim=1)) ** 2).mean()
        
        decoder_optimizer.zero_grad()
        encoder_optimizer.zero_grad()
        
        loss.backward()
        
        loss_sum += loss.item()*sum(lengths)
        loss_num += sum(lengths)
        
        if (i % 100) == 0:
            print("Epoch % d [%d/%d], Loss: %f" % (epoch, i, len(train_loader), loss_sum/loss_num))

In [158]:
def validate(val_loader, encoder, decoder, criterion):
    decoder.eval()
    encoder.eval()
    
    refs = []
    hypos = []
    
    loss_sum = 0
    loss_num = 0
    
    with torch.no_grad():
        for i, (images, captions, lengths, all_captions) in enumerate(val_loader):
            images = images.to(device)
            captions = captions.to(device)
            lengths = lengths.to(device)
            
            images = encoder(images)
            predict, captions, lengths, weights, ind = decoder(images, captions, lengths)
            
            targets = captions[:,1:]
            
            scores = predict.clone()
            
            predict,_ = pack_padded_sequence(predict, lengths, batch_first=True)
            targets,_ = pack_padded_sequence(targets, lengths, batch_first=True)
            
            loss = criterion(predict, targets)
            
            loss += ((1. - weights.sum(dim=1)) ** 2).mean()
                     
            loss_sum += loss.item()*sum(lengths)
            loss_num += sum(lengths)
                     
            if (i % 100) == 0:
                print("Validate [%d/%d], Loss: %f" % (i, len(val_laoder, loss_sum/loss_num)))
            
            all_captions = all_captions[ind]
            for j in range(all_captions.shape[0]):
                image_captions = all_captions[j].tolist()
                image_captions = list(map(lambda c : [w for w in c if w not in {word_map['<START>'], 
                                                                                word_map['<PAD>']}],
                                         image_captions))
                refs.append(image_captions)
            
            _, pred = torch.max(scores, dim=2).tolist()
            temp = []
            for j, p in enumerate(pred):
                temp.append(pred[j][:lengths[j]])
            pred = temp
            hypos.extend(pred)
            
            assert len(refs) == len(hypos)
            
        bleu4 = corpus_bleu(refs, hypos)
        
        print("Validate. BLEU-4: %f, Loss: %f" % (bleu4, loss_sum/loss_num))
        
    return bleu4

In [182]:
for epoch in range(120):
    train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch)
    bleu4 = validate(val_loader, encoder, decoder, criterion)

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 14 and 19 in dimension 1 at ../aten/src/TH/generic/THTensor.cpp:689