### Deep Reinforcement Learning-based Image Captioning with Embedding Reward
Pranshu Gupta, Deep Learning @ Georgia Institute of Technology

In [116]:
# As usual, a bit of setup
from __future__ import print_function
import time, os, json
import numpy as np
import matplotlib.pyplot as plt
import nltk
import random

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F

from utils.coco_utils import load_coco_data, sample_coco_minibatch, decode_captions
from utils.image_utils import image_from_url

from torchsummary import summary

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Working on: ", device)

def rel_error(x, y):
    """ returns relative error """
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

max_seq_len = 17

def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()


import metrics

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Working on:  cuda:0


### Load MS-COCO data
We will use the Microsoft COCO dataset for captioning.

In [30]:
# Load COCO data from disk; this returns a dictionary
# We'll work with dimensionality-reduced features for this notebook, but feel
# free to experiment with the original features by changing the flag below.
data = load_coco_data(pca_features=True)

data["train_captions_lens"] = np.zeros(data["train_captions"].shape[0])
data["val_captions_lens"] = np.zeros(data["val_captions"].shape[0])
for i in range(data["train_captions"].shape[0]):
    data["train_captions_lens"][i] = np.nonzero(data["train_captions"][i] == 2)[0][0] + 1
for i in range(data["val_captions"].shape[0]):
    data["val_captions_lens"][i] = np.nonzero(data["val_captions"][i] == 2)[0][0] + 1


# Print out all the keys and values from the data dictionary
for k, v in data.items():
    if type(v) == np.ndarray:
        print(k, type(v), v.shape, v.dtype)
    else:
        print(k, type(v), len(v))

train_captions <class 'numpy.ndarray'> (400135, 17) int32
train_image_idxs <class 'numpy.ndarray'> (400135,) int32
val_captions <class 'numpy.ndarray'> (195954, 17) int32
val_image_idxs <class 'numpy.ndarray'> (195954,) int32
train_features <class 'numpy.ndarray'> (82783, 512) float32
val_features <class 'numpy.ndarray'> (40504, 512) float32
idx_to_word <class 'list'> 1004
word_to_idx <class 'dict'> 1004
train_urls <class 'numpy.ndarray'> (82783,) <U63
val_urls <class 'numpy.ndarray'> (40504,) <U63
train_captions_lens <class 'numpy.ndarray'> (400135,) float64
val_captions_lens <class 'numpy.ndarray'> (195954,) float64


In [31]:
small_data = load_coco_data(max_train=50000)

### Policy Network

In [33]:
class PolicyNetwork(nn.Module):
    def __init__(self, word_to_idx, input_dim=512, wordvec_dim=512, hidden_dim=512, dtype=np.float32):
        super(PolicyNetwork, self).__init__()
        
        self.word_to_idx = word_to_idx
        self.idx_to_word = {i: w for w, i in word_to_idx.items()}
        
        vocab_size = len(word_to_idx)
        
        self.caption_embedding = nn.Embedding(vocab_size, wordvec_dim)
        
        self.cnn2linear = nn.Linear(input_dim, hidden_dim)
        self.lstm = nn.LSTM(wordvec_dim, hidden_dim, batch_first=True)
        self.linear2vocab = nn.Linear(hidden_dim, vocab_size)
        
    def forward(self, features, captions):
        input_captions = self.caption_embedding(captions)
        hidden_init = self.cnn2linear(features)
        cell_init = torch.zeros_like(hidden_init)
        output, _ = self.lstm(input_captions, (hidden_init, cell_init))
        output = self.linear2vocab(output)
        return output

### Reward Network

In [34]:
class RewardNetworkRNN(nn.Module):
    def __init__(self, word_to_idx, input_dim=512, wordvec_dim=512, hidden_dim=512, dtype=np.float32):
        super(RewardNetworkRNN, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.word_to_idx = word_to_idx
        self.idx_to_word = {i: w for w, i in word_to_idx.items()}
        vocab_size = len(word_to_idx)
        
        self.hidden_cell = torch.zeros(1, 1, self.hidden_dim).to(device)
        
        self.caption_embedding = nn.Embedding(vocab_size, wordvec_dim)
        self.gru = nn.GRU(wordvec_dim, hidden_dim)
    
    def forward(self, captions):
        input_captions = self.caption_embedding(captions)
        output, self.hidden_cell = self.gru(input_captions.view(len(input_captions) ,1, -1), self.hidden_cell)
        return output
    
class RewardNetwork(nn.Module):
    def __init__(self, word_to_idx):
        super(RewardNetwork, self).__init__()
        self.rewrnn = RewardNetworkRNN(word_to_idx)
        self.visual_embed = nn.Linear(512, 512)
        self.semantic_embed = nn.Linear(512, 512)
        
    def forward(self, features, captions):
        for t in range(captions.shape[1]):
            rrnn = self.rewrnn(captions[:, t])
        rrnn = rrnn.squeeze(0).squeeze(1)
        se = self.semantic_embed(rrnn)
        ve = self.visual_embed(features)
        return ve, se

### Value Network

In [35]:
class ValueNetworkRNN(nn.Module):
    def __init__(self, word_to_idx, input_dim=512, wordvec_dim=512, hidden_dim=512, dtype=np.float32):
        super(ValueNetworkRNN, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.word_to_idx = word_to_idx
        self.idx_to_word = {i: w for w, i in word_to_idx.items()}
        vocab_size = len(word_to_idx)
        
        self.hidden_cell = (torch.zeros(1, 1, self.hidden_dim).to(device), torch.zeros(1, 1, self.hidden_dim).to(device))
        
        self.caption_embedding = nn.Embedding(vocab_size, wordvec_dim)
        self.lstm = nn.LSTM(wordvec_dim, hidden_dim)
        
    def forward(self, captions):
        input_captions = self.caption_embedding(captions)
        output, self.hidden_cell = self.lstm(input_captions.view(len(input_captions) ,1, -1), self.hidden_cell)
        return output
    
class ValueNetwork(nn.Module):
    def __init__(self, word_to_idx):
        super(ValueNetwork, self).__init__()
        self.valrnn = ValueNetworkRNN(word_to_idx)
        self.linear1 = nn.Linear(1024, 512)
        self.linear2 = nn.Linear(512, 1)
    
    def forward(self, features, captions):
        for t in range(captions.shape[1]):
            vrnn = self.valrnn(captions[:, t])
        vrnn = vrnn.squeeze(0).squeeze(1)
        state = torch.cat((features, vrnn), dim=1)
        output = self.linear1(state)
        output = self.linear2(output)
        return output

### Generating Captions

In [67]:
policyNet = PolicyNetwork(data["word_to_idx"]).to(device)
policyNet.load_state_dict(torch.load('policyNetwork.pt'))
policyNet.train(mode=False)

valueNet = ValueNetwork(data["word_to_idx"]).to(device)
valueNet.load_state_dict(torch.load('valueNetwork.pt'))
valueNet.train(mode=False)

ValueNetwork(
  (valrnn): ValueNetworkRNN(
    (caption_embedding): Embedding(1004, 512)
    (lstm): LSTM(512, 512)
  )
  (linear1): Linear(in_features=1024, out_features=512, bias=True)
  (linear2): Linear(in_features=512, out_features=1, bias=True)
)

#### Greedy Caption Generator

In [68]:
def GenerateCaptions(features, captions, model):
    features = torch.tensor(features, device=device).float().unsqueeze(0)
    gen_caps = torch.tensor(captions[:, 0:1], device=device).long()
    for t in range(max_seq_len-1):
        output = model(features, gen_caps)
        gen_caps = torch.cat((gen_caps, output[:, -1:, :].argmax(axis=2)), axis=1)
    return gen_caps

#### Beam Search Caption Generator

In [69]:
def GenerateCaptionsWithBeamSearch(features, captions, model, beamSize=5):
    features = torch.tensor(features, device=device).float().unsqueeze(0)
    gen_caps = torch.tensor(captions[:, 0:1], device=device).long()
    candidates = [(gen_caps, 0)]
    for t in range(max_seq_len-1):
        next_candidates = []
        for c in range(len(candidates)):
            output = model(features, candidates[c][0])
            probs, words = torch.topk(output[:, -1:, :], beamSize)
            for i in range(beamSize):
                cap = torch.cat((candidates[c][0], words[:, :, i]), axis=1)
                score = candidates[c][1] - torch.log(probs[0, 0, i]).item()
                next_candidates.append((cap, score))
        ordered_candidates = sorted(next_candidates, key=lambda tup:tup[1])
        candidates = ordered_candidates[:beamSize]
    return candidates 

#### Lookahead Inference with Policy and Value Network

In [70]:
def GenerateCaptionsWithBeamSearchValueScoring(features, captions, model, beamSize=5):
    features = torch.tensor(features, device=device).float().unsqueeze(0)
    gen_caps = torch.tensor(captions[:, 0:1], device=device).long()
    candidates = [(gen_caps, 0)]
    for t in range(max_seq_len-1):
        next_candidates = []
        for c in range(len(candidates)):
            output = model(features, candidates[c][0])
            probs, words = torch.topk(output[:, -1:, :], beamSize)
            for i in range(beamSize):
                cap = torch.cat((candidates[c][0], words[:, :, i]), axis=1)
                value = valueNet(features.squeeze(0), cap).detach()
                score = candidates[c][1] - 0.6*value.item() -0.4*torch.log(probs[0, 0, i]).item()
                next_candidates.append((cap, score))
        ordered_candidates = sorted(next_candidates, key=lambda tup:tup[1])
        candidates = ordered_candidates[:beamSize]
    return candidates   

In [None]:
with torch.no_grad():
    max_seq_len = 17
    captions, features, urls = sample_coco_minibatch(small_data, batch_size=1000, split='val')
    for i in range(1000):
        gen_caps = []
        gen_caps.append(GenerateCaptions(features[i:i+1], captions[i:i+1], policyNet)[0])
        gen_caps.append(GenerateCaptionsWithBeamSearch(features[i:i+1], captions[i:i+1], policyNet)[0][0][0])
        gen_caps.append(GenerateCaptionsWithBeamSearchValueScoring(features[i:i+1], captions[i:i+1], policyNet)[0][0][0])
        decoded_tru_caps = decode_captions(captions[i], data["idx_to_word"])

        f = open("truth3.txt", "a")
        f.write(decoded_tru_caps + "\n")
        f.close()
        
        decoded_gen_caps = decode_captions(gen_caps[0], data["idx_to_word"])
        f = open("greedy3.txt", "a")
        f.write(decoded_gen_caps + "\n")
        f.close()
        
        decoded_gen_caps = decode_captions(gen_caps[1], data["idx_to_word"])
        f = open("beam3.txt", "a")
        f.write(decoded_gen_caps + "\n")
        f.close()
        
        decoded_gen_caps = decode_captions(gen_caps[2], data["idx_to_word"])
        f = open("policyvalue3.txt", "a")
        f.write(decoded_gen_caps + "\n")
        f.close()

### Caption Evaluation

In [87]:
def BLEU_score(gt_caption, sample_caption, w):
    """
    gt_caption: string, ground-truth caption
    sample_caption: string, your model's predicted caption
    Returns unigram BLEU score.
    """
    reference = [x for x in gt_caption.split(' ') 
                 if ('<END>' not in x and '<START>' not in x and '<UNK>' not in x)]
    hypothesis = [x for x in sample_caption.split(' ') 
                  if ('<END>' not in x and '<START>' not in x and '<UNK>' not in x)]
    BLEUscore = nltk.translate.bleu_score.sentence_bleu([reference], hypothesis, weights = [w])
    return BLEUscore

def evaluate_model(model):
    """
    model: CaptioningRNN model
    Prints unigram BLEU score averaged over 1000 training and val examples.
    """
    BLEUscores = {}
    for split in ['train', 'val']:
        minibatch = sample_coco_minibatch(data, split=split, batch_size=1000)
        gt_captions, features, urls = minibatch
        gt_captions = decode_captions(gt_captions, data['idx_to_word'])

        sample_captions = model.sample(features)
        sample_captions = decode_captions(sample_captions, data['idx_to_word'])

        total_score = 0.0
        for gt_caption, sample_caption, url in zip(gt_captions, sample_captions, urls):
            total_score += BLEU_score(gt_caption, sample_caption)

        BLEUscores[split] = total_score / len(sample_captions)

    for split in BLEUscores:
        print('Average BLEU score for %s: %f' % (split, BLEUscores[split]))

In [129]:
caps0 = []
caps1 = []
caps2 = []
caps3 = []
f = open("truth2.txt", "r")
for x in f:
    x = " ".join([w for w in x.split(' ') if ('<END>' not in w and '<START>' not in w and '<UNK>' not in w)])
    caps0.append(x)
f = open("greedy2.txt", "r")
for x in f:
    x = " ".join([w for w in x.split(' ') if ('<END>' not in w and '<START>' not in w and '<UNK>' not in w)])
    caps1.append(x)
f = open("beam2.txt", "r")
for x in f:
    x = " ".join([w for w in x.split(' ') if ('<END>' not in w and '<START>' not in w and '<UNK>' not in w)])
    caps2.append(x)
f = open("policyvalue2.txt", "r")
for x in f:
    x = " ".join([w for w in x.split(' ') if ('<END>' not in w and '<START>' not in w and '<UNK>' not in w)])
    caps3.append(x)

In [93]:
b1, b2, b3 = 0, 0, 0
for w in range(1, 5):
    for i in range(len(caps0)):
        b1 += BLEU_score(caps0[i], caps1[i], w)
        b2 += BLEU_score(caps0[i], caps2[i], w)
        b3 += BLEU_score(caps0[i], caps3[i], w)
    b1 /= len(caps0)
    b2 /= len(caps0)
    b3 /= len(caps0)
    print("Greedy BLEU-" + str(w), ":", b1)
    print("Beam BLEU-" + str(w), ":", b2)
    print("Agent BLEU-" + str(w), ":", b3)
    print()

Greedy BLEU-1 : 0.3374543171912208
Beam BLEU-1 : 0.29998119046207933
Agent BLEU-1 : 0.30057253835441977

Greedy BLEU-2 : 0.18381039209700356
Beam BLEU-2 : 0.13227059725207552
Agent BLEU-2 : 0.1331405488671185

Greedy BLEU-3 : 0.12767973218661097
Beam BLEU-3 : 0.0724795981070803
Agent BLEU-3 : 0.07345311185936992

Greedy BLEU-4 : 0.10370808190929426
Beam BLEU-4 : 0.04722721512979743
Agent BLEU-4 : 0.04818825165316483



In [131]:
ref, hypo = metrics.load_textfiles(caps0, caps3)
print(metrics.score(ref, hypo))

The number of references is 1000
{'testlen': 10255, 'reflen': 9324, 'guess': [10255, 9255, 8255, 7255], 'correct': [3285, 1002, 451, 270]}
ratio: 1.0998498498497318
{'Bleu_1': 0.32033154558748705, 'Bleu_2': 0.18622822496901026, 'Bleu_3': 0.12374191249218547, 'Bleu_4': 0.09163665014876797, 'METEOR': 0.14102364160475023, 'ROUGE_L': 0.30709802012803344, 'CIDEr': 0.8480455366206312}


### Training the Policy Network

In [44]:
pretrained = False

policyNetwork = PolicyNetwork(data["word_to_idx"]).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(policyNetwork.parameters(), lr=0.0001)

if pretrained:
    policyNetwork.load_state_dict(torch.load('models/policyNetwork.pt'))  

In [None]:
batch_size = 100
bestLoss = 0.3
#0.006700546946376562

for epoch in range(250000, 350000):
    captions, features, _ = sample_coco_minibatch(small_data, batch_size=batch_size, split='train')
    features = torch.tensor(features, device=device).float().unsqueeze(0)
    captions_in = torch.tensor(captions[:, :-1], device=device).long()
    captions_ou = torch.tensor(captions[:, 1:], device=device).long()
    output = policyNetwork(features, captions_in)
    
    loss = 0
    for i in range(batch_size):
        caplen = np.nonzero(captions[i] == 2)[0][0] + 1
        loss += (caplen/batch_size)*criterion(output[i][:caplen], captions_ou[i][:caplen])
    
    if loss.item() < bestLoss:
        bestLoss = loss.item()
        torch.save(policyNetwork.state_dict(), "policyNetwork.pt")
        print("epoch:", epoch, "loss:", loss.item())
        
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

epoch: 256846 loss: 0.2919604778289795
epoch: 271081 loss: 0.2918578088283539
epoch: 276987 loss: 0.2405262291431427
epoch: 324988 loss: 0.2366735190153122
epoch: 339368 loss: 0.20999650657176971


### Training the Reward Network

In [57]:
rewardNetwork = RewardNetwork(data["word_to_idx"]).to(device)
optimizer = optim.Adam(rewardNetwork.parameters(), lr=0.001)

# https://cs230-stanford.github.io/pytorch-nlp.html#writing-a-custom-loss-function
def VisualSemanticEmbeddingLoss(visuals, semantics):
    beta = 0.2
    N, D = visuals.shape
    
    visloss = torch.mm(visuals, semantics.t())
    visloss = visloss - torch.diag(visloss).unsqueeze(1)
    visloss = visloss + (beta/N)*(torch.ones((N, N)).to(device) - torch.eye(N).to(device))
    visloss = F.relu(visloss)
    visloss = torch.sum(visloss)/N
    
    semloss = torch.mm(semantics, visuals.t())
    semloss = semloss - torch.diag(semloss).unsqueeze(1)
    semloss = semloss + (beta/N)*(torch.ones((N, N)).to(device) - torch.eye(N).to(device))
    semloss = F.relu(semloss)
    semloss = torch.sum(semloss)/N
    
    return visloss + semloss        

In [58]:
batch_size = 50
bestLoss = 10000

for epoch in range(50000):
    captions, features, _ = sample_coco_minibatch(small_data, batch_size=batch_size, split='train')
    features = torch.tensor(features, device=device).float()
    captions = torch.tensor(captions, device=device).long()
    ve, se = rewardNetwork(features, captions)
    loss = VisualSemanticEmbeddingLoss(ve, se)
    
    if loss.item() < bestLoss:
        bestLoss = loss.item()
        torch.save(rewardNetwork.state_dict(), "rewardNetwork.pt")
        print("epoch:", epoch, "loss:", loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    rewardNetwork.rewrnn.hidden_cell.detach_()

epoch: 0 loss: 285.8268737792969
epoch: 78 loss: 279.0987243652344
epoch: 109 loss: 268.0894775390625
epoch: 140 loss: 258.7628173828125
epoch: 146 loss: 249.2835235595703
epoch: 184 loss: 242.68533325195312
epoch: 198 loss: 234.36572265625
epoch: 271 loss: 219.52809143066406
epoch: 276 loss: 201.01954650878906
epoch: 499 loss: 190.47842407226562
epoch: 521 loss: 184.58734130859375
epoch: 608 loss: 181.97393798828125
epoch: 614 loss: 161.46424865722656
epoch: 663 loss: 159.47470092773438
epoch: 702 loss: 128.3231964111328
epoch: 847 loss: 103.04296112060547
epoch: 1151 loss: 100.62371063232422
epoch: 1182 loss: 99.61503601074219
epoch: 1204 loss: 96.04625701904297
epoch: 1273 loss: 95.47994995117188
epoch: 1316 loss: 90.17144012451172
epoch: 1413 loss: 80.14276885986328
epoch: 1482 loss: 80.06565856933594
epoch: 1505 loss: 75.33641815185547
epoch: 1600 loss: 73.59349060058594
epoch: 1637 loss: 68.0627670288086
epoch: 1653 loss: 67.85617065429688
epoch: 1690 loss: 65.30216217041016
epoc

KeyboardInterrupt: 

In [59]:
def GetRewards(features, captions, model):
    visEmbeds, semEmbeds = model(features, captions)
    visEmbeds = F.normalize(visEmbeds, p=2, dim=1) 
    semEmbeds = F.normalize(semEmbeds, p=2, dim=1) 
    rewards = torch.sum(visEmbeds*semEmbeds, axis=1).unsqueeze(1)
    return rewards

In [62]:
rewardNet = RewardNetwork(data["word_to_idx"]).to(device)
rewardNet.load_state_dict(torch.load('rewardNetwork.pt'))
for param in rewardNet.parameters():
    param.require_grad = False
print(rewardNet)

policyNet = PolicyNetwork(data["word_to_idx"]).to(device)
policyNet.load_state_dict(torch.load('policyNetwork.pt'))
for param in policyNet.parameters():
    param.require_grad = False
print(policyNet)

valueNetwork = ValueNetwork(data["word_to_idx"]).to(device)
criterion = nn.MSELoss().to(device)
optimizer = optim.Adam(valueNetwork.parameters(), lr=0.0001)
valueNetwork.train(mode=True)

RewardNetwork(
  (rewrnn): RewardNetworkRNN(
    (caption_embedding): Embedding(1004, 512)
    (gru): GRU(512, 512)
  )
  (visual_embed): Linear(in_features=512, out_features=512, bias=True)
  (semantic_embed): Linear(in_features=512, out_features=512, bias=True)
)
PolicyNetwork(
  (caption_embedding): Embedding(1004, 512)
  (cnn2linear): Linear(in_features=512, out_features=512, bias=True)
  (lstm): LSTM(512, 512, batch_first=True)
  (linear2vocab): Linear(in_features=512, out_features=1004, bias=True)
)


ValueNetwork(
  (valrnn): ValueNetworkRNN(
    (caption_embedding): Embedding(1004, 512)
    (lstm): LSTM(512, 512)
  )
  (linear1): Linear(in_features=1024, out_features=512, bias=True)
  (linear2): Linear(in_features=512, out_features=1, bias=True)
)

In [63]:
batch_size = 50
bestLoss = 10000
max_seq_len = 17

for epoch in range(50000):
    captions, features, _ = sample_coco_minibatch(small_data, batch_size=batch_size, split='train')
    features = torch.tensor(features, device=device).float()
    
    # Generate captions using the policy network
    captions = GenerateCaptions(features, captions, policyNet)
    
    # Compute the reward of the generated caption using reward network
    rewards = GetRewards(features, captions, rewardNet)
    
    # Compute the value of a random state in the generation process
#     print(features.shape, captions[:, :random.randint(1, 17)].shape)
    values = valueNetwork(features, captions[:, :random.randint(1, 17)])
    
    # Compute the loss for the value and the reward
    loss = criterion(values, rewards)
    
    if loss.item() < bestLoss:
        bestLoss = loss.item()
        torch.save(valueNetwork.state_dict(), "valueNetwork.pt")
        print("epoch:", epoch, "loss:", loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    valueNetwork.valrnn.hidden_cell[0].detach_()
    valueNetwork.valrnn.hidden_cell[1].detach_()
    rewardNet.rewrnn.hidden_cell.detach_()

  


epoch: 0 loss: 0.4791761636734009
epoch: 1 loss: 0.3826758861541748
epoch: 2 loss: 0.30105552077293396
epoch: 3 loss: 0.2344961166381836
epoch: 5 loss: 0.21635407209396362
epoch: 6 loss: 0.2095080465078354
epoch: 8 loss: 0.17506301403045654
epoch: 9 loss: 0.16135434806346893
epoch: 11 loss: 0.10012535750865936
epoch: 13 loss: 0.07867062091827393
epoch: 21 loss: 0.07740631699562073
epoch: 22 loss: 0.07632629573345184
epoch: 23 loss: 0.05278187245130539
epoch: 29 loss: 0.0399894043803215
epoch: 34 loss: 0.03727782517671585
epoch: 36 loss: 0.033582452684640884
epoch: 39 loss: 0.03300255537033081
epoch: 44 loss: 0.027580423280596733
epoch: 45 loss: 0.025160321965813637
epoch: 48 loss: 0.019739700481295586
epoch: 50 loss: 0.01948259025812149
epoch: 52 loss: 0.013901742175221443
epoch: 57 loss: 0.013447004370391369
epoch: 59 loss: 0.012788197956979275
epoch: 69 loss: 0.010633053258061409
epoch: 71 loss: 0.008541139774024487
epoch: 76 loss: 0.007344955112785101
epoch: 87 loss: 0.0071632359176

KeyboardInterrupt: 

## Reinforcement Learning
Advantage Actor Critic Model for Reinforcement Learning

In [64]:
class AdvantageActorCriticNetwork(nn.Module):
    def __init__(self, valueNet, policyNet):
        super(AdvantageActorCriticNetwork, self).__init__()

        self.valueNet = valueNet #RewardNetwork(data["word_to_idx"]).to(device)
        self.policyNet = policyNet #PolicyNetwork(data["word_to_idx"]).to(device)

    def forward(self, features, captions):
        # Get value from value network
        values = self.valueNet(features, captions)
        # Get action probabilities from policy network
        probs = self.policyNet(features.unsqueeze(0), captions)[:, -1:, :]        
        return values, probs 

In [65]:
rewardNet = RewardNetwork(data["word_to_idx"]).to(device)
policyNet = PolicyNetwork(data["word_to_idx"]).to(device)
valueNet = ValueNetwork(data["word_to_idx"]).to(device)

rewardNet.load_state_dict(torch.load('rewardNetwork.pt'))
policyNet.load_state_dict(torch.load('policyNetwork.pt'))
valueNet.load_state_dict(torch.load('valueNetwork.pt'))

a2cNetwork = AdvantageActorCriticNetwork(valueNet, policyNet)
optimizer = optim.Adam(a2cNetwork.parameters(), lr=0.0001)

### Curriculum Learning

In [66]:
curriculum = [2, 4, 6, 8, 10, 12, 14, 16]
episodes = 50

small_data = load_coco_data(max_train=50000)

for level in curriculum:
    
    for epoch in range(1000):        
        episodicAvgLoss = 0
        
        captions, features, _ = sample_coco_minibatch(small_data, batch_size=episodes, split='train')
        features = torch.tensor(features, device=device).float()
        captions = torch.tensor(captions, device=device).long()
        
        for episode in range(episodes):
            log_probs = []
            values = []
            rewards = []
            caplen = np.nonzero(captions[episode] == 2)[0][0] + 1
            
            if (caplen - level > 1):
                captions_in = captions[episode:episode+1, :caplen-level]
                features_in = features[episode:episode+1]

                for step in range(level):
                    value, probs = a2cNetwork(features_in, captions_in)
                    probs = F.softmax(probs, dim=2)
                    
                    dist = probs.cpu().detach().numpy()[0,0]
                    action = np.random.choice(probs.shape[-1], p=dist)
                    
                    gen_cap = torch.from_numpy(np.array([action])).unsqueeze(0).to(device)
                    captions_in = torch.cat((captions_in, gen_cap), axis=1)
                    
                    log_prob = torch.log(probs[0, 0, action])
                    
                    reward = GetRewards(features_in, captions_in, rewardNet)
                    reward = reward.cpu().detach().numpy()[0, 0]
                    
                    rewards.append(reward)
                    values.append(value)
                    log_probs.append(log_prob)
                    
            values = torch.FloatTensor(values).to(device)
            rewards = torch.FloatTensor(rewards).to(device)
            log_probs = torch.stack(log_probs).to(device)
            
            advantage = values - rewards 
            actorLoss = (-log_probs * advantage).mean()
            criticLoss = 0.5 * advantage.pow(2).mean()
            
            loss = actorLoss + criticLoss
            episodicAvgLoss += loss.item()/episodes
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(epoch, ":", episodicAvgLoss)

0 : 0.011274048415052675
1 : 0.009052214500998157
2 : 0.006383603059908865
3 : 0.0073232187853773225
4 : 0.0034985836075793475
5 : 0.0041928653282229805
6 : 0.0018523795278451878
7 : -0.0031698185374989434
8 : 0.007850690340856092
9 : 0.014547815381083636
10 : -0.00046384327535633953
11 : 0.0036360891032381913
12 : 0.007600265605142341
13 : -0.002000644241570627
14 : -0.0028894532122649254
15 : 0.014178393029305885
16 : -0.002927510830195387
17 : -0.0004923985352070305
18 : 0.003888461572641971
19 : -7.64965608823334e-06
20 : 0.008021804855125084
21 : 0.006459024113355553
22 : -0.007037104693663421
23 : 0.005563439957350054


KeyboardInterrupt: 