### Deep Reinforcement Learning-based Image Captioning with Embedding Reward
Pranshu Gupta, Deep Learning @ Georgia Institute of Technology

In [154]:
# As usual, a bit of setup
from __future__ import print_function
import time, os, json
import numpy as np
import matplotlib.pyplot as plt
import nltk

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn

from cs231n.coco_utils import load_coco_data, sample_coco_minibatch, decode_captions
from cs231n.image_utils import image_from_url

from torchsummary import summary

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Working on: ", device)

def rel_error(x, y):
    """ returns relative error """
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Working on:  cuda:0


### Load MS-COCO data
We will use the Microsoft COCO dataset for captioning.

In [140]:
# Load COCO data from disk; this returns a dictionary
# We'll work with dimensionality-reduced features for this notebook, but feel
# free to experiment with the original features by changing the flag below.
data = load_coco_data(pca_features=True)

data["train_captions_lens"] = np.zeros(data["train_captions"].shape[0])
data["val_captions_lens"] = np.zeros(data["val_captions"].shape[0])
for i in range(data["train_captions"].shape[0]):
    data["train_captions_lens"][i] = np.nonzero(data["train_captions"][i] == 2)[0][0] + 1
for i in range(data["val_captions"].shape[0]):
    data["val_captions_lens"][i] = np.nonzero(data["val_captions"][i] == 2)[0][0] + 1


# Print out all the keys and values from the data dictionary
for k, v in data.items():
    if type(v) == np.ndarray:
        print(k, type(v), v.shape, v.dtype)
    else:
        print(k, type(v), len(v))

train_captions <class 'numpy.ndarray'> (400135, 17) int32
train_image_idxs <class 'numpy.ndarray'> (400135,) int32
val_captions <class 'numpy.ndarray'> (195954, 17) int32
val_image_idxs <class 'numpy.ndarray'> (195954,) int32
train_features <class 'numpy.ndarray'> (82783, 512) float32
val_features <class 'numpy.ndarray'> (40504, 512) float32
idx_to_word <class 'list'> 1004
word_to_idx <class 'dict'> 1004
train_urls <class 'numpy.ndarray'> (82783,) <U63
val_urls <class 'numpy.ndarray'> (40504,) <U63
train_captions_lens <class 'numpy.ndarray'> (400135,) float64
val_captions_lens <class 'numpy.ndarray'> (195954,) float64


### Caption Evaluation

In [8]:
def BLEU_score(gt_caption, sample_caption):
    """
    gt_caption: string, ground-truth caption
    sample_caption: string, your model's predicted caption
    Returns unigram BLEU score.
    """
    reference = [x for x in gt_caption.split(' ') 
                 if ('<END>' not in x and '<START>' not in x and '<UNK>' not in x)]
    hypothesis = [x for x in sample_caption.split(' ') 
                  if ('<END>' not in x and '<START>' not in x and '<UNK>' not in x)]
    BLEUscore = nltk.translate.bleu_score.sentence_bleu([reference], hypothesis, weights = [1])
    return BLEUscore

def evaluate_model(model):
    """
    model: CaptioningRNN model
    Prints unigram BLEU score averaged over 1000 training and val examples.
    """
    BLEUscores = {}
    for split in ['train', 'val']:
        minibatch = sample_coco_minibatch(data, split=split, batch_size=1000)
        gt_captions, features, urls = minibatch
        gt_captions = decode_captions(gt_captions, data['idx_to_word'])

        sample_captions = model.sample(features)
        sample_captions = decode_captions(sample_captions, data['idx_to_word'])

        total_score = 0.0
        for gt_caption, sample_caption, url in zip(gt_captions, sample_captions, urls):
            total_score += BLEU_score(gt_caption, sample_caption)

        BLEUscores[split] = total_score / len(sample_captions)

    for split in BLEUscores:
        print('Average BLEU score for %s: %f' % (split, BLEUscores[split]))

### Policy Network

In [141]:
class PolicyNetwork(nn.Module):
    def __init__(self, word_to_idx, input_dim=512, wordvec_dim=512, hidden_dim=512, dtype=np.float32):
        super(PolicyNetwork, self).__init__()
        
        self.word_to_idx = word_to_idx
        self.idx_to_word = {i: w for w, i in word_to_idx.items()}
        
        vocab_size = len(word_to_idx)
        
        self.null = word_to_idx['<NULL>']
        self.start = word_to_idx.get('<START>', None)
        self.end = word_to_idx.get('<END>', None)
        
        self.caption_embedding = nn.Embedding(vocab_size, wordvec_dim)
        
        self.cnn2linear = nn.Linear(input_dim, hidden_dim)
        self.lstm = nn.LSTM(wordvec_dim, hidden_dim, batch_first=True)
        self.linear2vocab = nn.Linear(hidden_dim, vocab_size)
        self.probs = nn.Softmax(dim=1)
        
    def forward(self, features, captions):
        input_captions = self.caption_embedding(captions)
        hidden_init = self.cnn2linear(features)
        cell_init = torch.zeros_like(hidden_init)
        output, _ = self.lstm(input_captions, (hidden_init, cell_init))
        output = self.linear2vocab(output)
        return output

### Training the Policy Network

In [174]:
policyNetwork = PolicyNetwork(data["word_to_idx"]).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(policyNetwork.parameters(), lr=0.0001)

In [179]:
small_data = load_coco_data(max_train=5000)

In [180]:
batch_size = 50
for epoch in range(10000):
    captions, features, _ = sample_coco_minibatch(small_data, batch_size=batch_size, split='train')
    features = torch.tensor(features, device=device).float().unsqueeze(0)
    captions_in = torch.tensor(captions[:, :-1], device=device).long()
    captions_ou = torch.tensor(captions[:, 1:], device=device).long()
    output = policyNetwork(features, captions_in)
    
    loss = 0
    for i in range(batch_size):
        caplen = np.nonzero(captions[i] == 2)[0][0] + 1
        loss += (caplen/batch_size)*criterion(output[i][:caplen], captions_ou[i][:caplen])
    
    print(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

27.79432487487793
27.849533081054688
27.58513069152832
27.21681022644043
27.63636589050293
25.90688133239746
26.322227478027344
25.373939514160156
27.0172119140625
26.972564697265625
26.703039169311523
27.454120635986328
25.543241500854492
25.674583435058594
26.441699981689453
29.13265037536621
25.3128719329834
27.321428298950195
25.9634952545166
30.09293556213379
27.7325496673584
27.94649314880371
24.76369857788086
26.091564178466797
28.10906982421875
24.65267562866211
26.953224182128906
24.66082000732422
26.02132797241211
27.454591751098633
27.32353973388672
23.738489151000977
28.0628604888916
28.438858032226562
25.434661865234375
26.316926956176758
25.262582778930664
26.811782836914062
26.032957077026367
25.807418823242188
27.193553924560547
25.387460708618164
26.97427749633789
27.020227432250977
25.15505027770996
28.873199462890625
26.605329513549805
25.473217010498047
23.93617820739746
24.721664428710938
24.98944854736328
23.934255599975586
26.70647621154785
27.844886779785156
24.

22.645687103271484
21.97199821472168
22.470836639404297
21.06716537475586
24.048362731933594
23.15460777282715
23.966880798339844
22.162729263305664
23.804338455200195
22.906126022338867
23.846351623535156
22.51207160949707
22.391454696655273
21.20732307434082
21.163246154785156
27.542184829711914
24.25726890563965
21.149694442749023
21.76241683959961
22.83191680908203
23.928850173950195
22.765165328979492
19.811464309692383
21.6955509185791
20.231046676635742
20.649961471557617
21.744834899902344
21.3497314453125
21.884807586669922
22.64280891418457
23.826995849609375
23.46689224243164
22.188884735107422
22.697906494140625
22.175413131713867
23.403059005737305
24.06847381591797
24.333877563476562
21.89696502685547
23.51311683654785
22.63047218322754
23.18398094177246
22.984363555908203
22.55386734008789
22.750141143798828
20.74024200439453
21.84124755859375
22.65003776550293
23.06637954711914
22.198423385620117
20.72493553161621
22.968273162841797
23.209575653076172
22.231731414794922

19.36012077331543
20.380081176757812
20.591659545898438
18.464773178100586
19.584346771240234
20.647756576538086
20.021907806396484
19.590078353881836
21.67384147644043
19.363222122192383
21.397781372070312
21.05316734313965
20.514493942260742
20.207273483276367
19.69734764099121
20.443838119506836
19.31246566772461
18.87904930114746
19.971071243286133
19.172292709350586
19.388654708862305
20.402530670166016
18.756546020507812
19.78031349182129
19.495223999023438
21.937179565429688
19.20139503479004
19.422136306762695
17.47979164123535
19.4014949798584
20.058124542236328
18.577594757080078
20.702543258666992
20.569438934326172
20.193504333496094
19.57501983642578
18.914730072021484
22.296972274780273
21.617692947387695
20.102937698364258
19.556209564208984
21.49709701538086
19.97968292236328
19.642227172851562
18.455245971679688
21.126127243041992
19.562767028808594
19.621843338012695
19.116348266601562
19.86553192138672
19.720985412597656
20.587221145629883
19.567201614379883
18.96635

16.52035903930664
16.756528854370117
17.512603759765625
16.47405242919922
16.82382583618164
17.400257110595703
17.498798370361328
16.167015075683594
18.561479568481445
18.53936195373535
16.96645164489746
16.050161361694336
16.977142333984375
17.69694709777832
16.97746467590332
17.37704849243164
17.269468307495117
17.077611923217773
16.294111251831055
17.67597007751465
16.637929916381836
16.739334106445312
17.607873916625977
17.00547981262207
16.68943214416504
17.795129776000977
17.446752548217773
15.599882125854492
17.901718139648438
17.679229736328125
16.582874298095703
18.56984519958496
17.390270233154297
17.444656372070312
17.502845764160156
17.148841857910156
17.1641845703125
17.189069747924805
18.53083610534668
17.129283905029297
16.333742141723633
16.129941940307617
19.091035842895508
17.662424087524414
17.67222023010254
16.310266494750977
16.795438766479492
16.719974517822266
17.831628799438477
17.912647247314453
16.220932006835938
17.45115852355957
16.337800979614258
18.1019878

13.982637405395508
15.119560241699219
14.367569923400879
14.995641708374023
16.315549850463867
14.590584754943848
14.793878555297852
15.694168090820312
14.548624992370605
15.646646499633789
16.883480072021484
14.826050758361816
15.637417793273926
16.8961124420166
15.781402587890625
13.327326774597168
15.323782920837402
15.60424518585205
14.650141716003418
14.027310371398926
14.382247924804688
15.413471221923828
15.304827690124512
15.875590324401855
15.18087387084961
15.06418514251709
15.058929443359375
13.179152488708496
16.391103744506836
14.680070877075195
15.06024169921875
15.208877563476562
14.330772399902344
14.491387367248535
14.52420711517334
14.994158744812012
14.481712341308594
14.2969388961792
15.974528312683105
14.030848503112793
14.665898323059082
16.41237449645996
13.35745620727539
14.134611129760742
15.97082805633545
16.595582962036133
15.670831680297852
15.759856224060059
15.673979759216309
13.627920150756836
15.131420135498047
14.183320045471191
14.567517280578613
16.05

13.679970741271973
11.871960639953613
12.928191184997559
14.440319061279297
12.423202514648438
13.872703552246094
12.564227104187012
13.36518383026123
13.750725746154785
13.456587791442871
12.764060020446777
13.569356918334961
12.04643440246582
12.204291343688965
12.873522758483887
13.466581344604492
13.788949966430664
13.161148071289062
11.986165046691895
12.434239387512207
13.589369773864746
14.111196517944336
11.450149536132812
12.5341215133667
12.805606842041016
13.482734680175781
14.744122505187988
11.309704780578613
14.079022407531738
12.632120132446289
12.70294189453125
13.655792236328125
13.443853378295898
13.115306854248047
15.128591537475586
12.057496070861816
12.921547889709473
14.18010425567627
14.41728401184082
12.936777114868164
13.624520301818848
11.687973022460938
13.860184669494629
12.767793655395508
14.44638442993164
13.496790885925293
13.675024032592773
13.453093528747559
13.939045906066895
13.475329399108887
13.708029747009277
14.599228858947754
12.442903518676758
1

10.601192474365234
11.283592224121094
11.608179092407227
11.197319030761719
11.126643180847168
12.748832702636719
11.50821304321289
11.08484172821045
11.154123306274414
11.152732849121094
10.241779327392578
11.836747169494629
12.238824844360352
12.335317611694336
11.195847511291504
11.208220481872559
11.038739204406738
12.57010555267334
11.581209182739258
10.279369354248047
11.984508514404297
10.267718315124512
11.239900588989258
11.109984397888184
12.498908996582031
10.280421257019043
10.678519248962402
11.661758422851562
12.168126106262207
10.453099250793457
10.821184158325195
11.435990333557129
11.55788803100586
11.651598930358887
10.575665473937988
11.444045066833496
10.781128883361816
11.938196182250977
11.453885078430176
10.948177337646484
10.909830093383789
12.656447410583496
12.021135330200195
11.322613716125488
10.935972213745117
11.928808212280273
10.89395523071289
11.050812721252441
10.095694541931152
11.484929084777832
12.058274269104004
10.733806610107422
12.18430900573730

9.957572937011719
9.70686149597168
10.074828147888184
10.079229354858398
8.308733940124512
9.920793533325195
10.061873435974121
9.304940223693848
9.524828910827637
10.220166206359863
9.925686836242676
10.082130432128906
9.209993362426758
9.975050926208496
8.960357666015625
9.1747465133667
8.36665153503418
9.272363662719727
8.405049324035645
9.023518562316895
9.589479446411133
10.150930404663086
10.641060829162598
9.44835090637207
10.08356761932373
9.83733081817627
9.94266414642334
9.80445384979248
10.20321273803711
10.13640022277832
8.58151912689209
9.284523963928223
9.629415512084961
9.025362014770508
10.174905776977539
10.177773475646973
10.096685409545898
8.438508033752441
9.383747100830078
10.233128547668457
10.394465446472168
9.439654350280762
9.485889434814453
8.675834655761719
9.967422485351562
9.601719856262207
10.007267951965332
9.55290699005127
9.789595603942871
10.561944961547852
9.968168258666992
9.826604843139648
9.173372268676758
9.913694381713867
10.532221794128418
9.608

9.438272476196289
7.981037139892578
7.8755106925964355
7.721837520599365
8.042274475097656
7.167961597442627
8.318634986877441
8.55938720703125
8.545491218566895
8.64275074005127
8.278905868530273
8.44921875
8.162833213806152
8.485028266906738
7.774521827697754
7.2935075759887695
7.979465961456299
8.144752502441406
8.108306884765625
7.388726234436035
8.301942825317383
8.72939395904541
7.714054107666016
9.22988224029541
8.546542167663574
7.462831497192383
8.095844268798828
8.504453659057617
8.176345825195312
7.9396820068359375
8.835857391357422
7.784792423248291
8.329707145690918
8.834925651550293
8.087481498718262
9.557615280151367
8.032142639160156
7.67951774597168
8.42294979095459
7.629621505737305
8.336410522460938
8.056780815124512
9.371013641357422
8.276070594787598
8.491046905517578
8.40596866607666
8.073868751525879
7.737828254699707
7.803447723388672
8.083803176879883
8.064772605895996
8.447308540344238
7.401754856109619
7.980551242828369
9.06160831451416
7.706363677978516
8.29

6.797475814819336
6.539464950561523
6.731550216674805
7.517207622528076
6.996216773986816
7.183610439300537
7.316282749176025
7.588812828063965
7.0385847091674805
6.838772773742676
6.579854965209961
6.826920509338379
7.496461868286133
7.951927185058594
6.790221691131592
7.421730041503906
6.915374755859375
6.426033020019531
7.908452987670898
6.688291072845459
7.092153549194336
6.889214992523193
8.073698043823242
6.349565029144287
7.977138042449951
7.051569938659668
6.8467912673950195
7.313314914703369
6.799627304077148
6.6270751953125
7.899942398071289
7.158297061920166
7.414275169372559
6.859832286834717
6.871603488922119
6.190738201141357
7.216708183288574
7.087217330932617
7.319047927856445
6.826152324676514
7.237839698791504
7.2679853439331055
6.7745137214660645
6.757894039154053
6.60120964050293
6.654378890991211
7.53122615814209
7.795788288116455
6.577693939208984
6.6534504890441895
6.830066680908203
6.661509990692139
5.942878246307373
6.769339561462402
7.102340221405029
7.1422624

5.646081924438477
6.066290855407715
5.891241073608398
6.390215873718262
5.2119598388671875
5.3352251052856445
5.821944236755371
5.290273666381836
6.032908916473389
5.658734321594238
5.532431602478027
6.122658729553223
6.532634735107422
6.960811138153076
4.9438605308532715
6.320159435272217
7.180797576904297
6.5845417976379395
5.508238315582275
6.417459487915039
5.739360332489014
5.742857456207275
5.77394962310791
5.521547794342041
5.662650108337402
5.462189197540283
5.827879905700684
6.650187015533447
5.622256278991699
5.425473690032959
5.714205741882324
5.1435441970825195
5.303990840911865
5.7022929191589355
6.537487983703613
6.398095607757568
5.813839912414551
5.487791538238525
6.784912586212158
5.205014228820801
6.36372709274292
6.125710964202881
5.928740501403809
5.550624847412109
5.613574028015137
5.830404281616211
6.186218738555908
5.840057373046875
5.927513599395752
5.7294840812683105
6.584293365478516
5.443726539611816
5.644827365875244
5.5316996574401855
5.564546585083008
5.32

4.880376815795898
5.10748815536499
4.703894138336182
4.609005928039551
4.241820812225342
4.666171073913574
5.52567195892334
4.682855606079102
4.740664482116699
5.412676811218262
4.855100154876709
5.193206310272217
4.033217430114746
4.214886665344238
4.539727687835693
4.964273452758789
4.932886600494385
4.802255630493164
5.369803428649902
5.054605960845947
5.00862979888916
5.331244468688965
5.130323886871338
4.942473888397217
4.807243824005127
4.557916164398193
4.349671840667725
4.912477493286133
4.939394950866699
4.7383856773376465
4.960498809814453
5.20286750793457
4.582770347595215
4.721676349639893
4.89276647567749
5.143474578857422
4.329163074493408
5.0998148918151855
4.918550491333008
4.984801769256592
4.584273338317871
4.487154006958008
4.7970123291015625
4.9026384353637695
4.570830821990967
4.140643119812012
5.243320465087891
4.478581428527832
5.03907585144043
4.5975751876831055
5.274941444396973
4.7415032386779785
4.645212650299072
5.916082382202148
4.788846015930176
4.47075700

4.133557319641113
4.151971340179443
4.431673526763916
4.026630878448486
4.3368916511535645
4.155606746673584
3.987581968307495
4.203391075134277
3.9891903400421143
4.459466934204102
4.225549697875977
4.572200298309326
4.0131072998046875
3.6200954914093018
3.997751474380493
3.7804534435272217
3.8744637966156006
4.559208393096924
4.412356853485107
3.9827539920806885
4.314429759979248
4.673576354980469
4.196414470672607
3.9758286476135254
3.9447786808013916
3.8911640644073486
3.9407498836517334
4.469218730926514
4.007474899291992
4.077794075012207
4.413944244384766
3.9901533126831055
4.262862205505371
3.739070177078247
3.8654441833496094
3.8681888580322266
4.122035980224609
3.630289077758789
4.496960639953613
3.8257997035980225
4.047530651092529
3.950883626937866
3.6263396739959717
3.5733847618103027
4.09175968170166
3.977271556854248
3.9261393547058105
4.429078578948975
3.4688475131988525
3.9725875854492188
3.8494155406951904
4.607769012451172
3.826348304748535
4.131279945373535
3.447322

3.472062349319458
3.269742965698242
3.0994808673858643
3.2593464851379395
3.293266773223877
3.7105650901794434
3.619325876235962
3.5361454486846924
3.2206802368164062
3.423383951187134
3.6472256183624268
3.007664442062378
3.583427906036377
2.9676125049591064
3.7029879093170166
3.7680587768554688
3.2413671016693115
3.5627567768096924
3.45615291595459
3.133977174758911
3.2441859245300293
3.9644408226013184
3.4710938930511475
3.3100790977478027
3.8805346488952637
3.2122726440429688
4.019044399261475
3.824751615524292
3.234926462173462
3.731358289718628
3.4828684329986572
3.0259957313537598
3.6469907760620117
3.695901870727539
2.9453742504119873
3.2369093894958496
3.2908294200897217
3.3257369995117188
3.3919739723205566
3.277883291244507
3.4762163162231445
3.228614330291748
3.4800429344177246
3.5386927127838135
3.6131174564361572
3.375802755355835
3.2203571796417236
3.166097402572632
2.9947354793548584
3.366330623626709
3.091693162918091
3.2432010173797607
3.2063217163085938
3.261802196502

2.8456308841705322
2.7538371086120605
2.5109357833862305
2.5586423873901367
2.662992238998413
3.2699553966522217
2.39713191986084
3.127855062484741
2.6357388496398926
2.617767333984375
2.414078712463379
2.667907238006592
2.4563605785369873
2.8313775062561035
2.7556560039520264
2.9986016750335693
3.2021942138671875
2.7180588245391846
2.3251800537109375
2.6903669834136963
3.2029855251312256
2.6114115715026855
2.650925874710083
2.9070494174957275
2.981837511062622
2.7016541957855225
2.675189971923828
2.6503500938415527
2.65751051902771
3.1020126342773438
2.5909981727600098
2.9347176551818848
2.7311689853668213
2.512232542037964
2.8714044094085693
2.7103793621063232
2.604987144470215
2.7616753578186035
2.960127115249634
3.2365148067474365
2.6302988529205322
3.322997570037842
2.5136709213256836
2.4038901329040527
2.6620874404907227
2.2436628341674805
2.894465446472168
3.080726146697998
2.7663803100585938
2.8317761421203613
2.645958185195923
2.7314255237579346
2.7142863273620605
2.6017167568

2.3646984100341797
2.1906445026397705
2.331962823867798
2.7708632946014404
2.285006523132324
2.6245317459106445
2.421614170074463
2.2428359985351562
2.5431149005889893
2.5636043548583984
2.2417526245117188
1.9939689636230469
2.347167730331421
2.076061487197876
2.3158233165740967
2.432284355163574
2.37306547164917
2.9211082458496094
2.4388768672943115
2.349294900894165
2.5316975116729736
2.295234203338623
2.4178531169891357
2.1498172283172607
2.094367504119873
2.3698108196258545
2.3264331817626953
2.627167224884033
2.3173792362213135
2.1724677085876465
2.080319881439209
2.3259732723236084
2.283827066421509
2.345904588699341
2.2500193119049072
2.498084306716919
2.4281771183013916
2.2296371459960938
2.240288019180298
2.6001100540161133
2.244569778442383
2.2339131832122803
2.2940218448638916
2.7260730266571045
2.2551727294921875
2.0502712726593018
2.409636974334717
2.2020211219787598
2.319844961166382
1.9886376857757568
2.089564323425293
2.661435842514038
2.561631202697754
2.48927211761474

1.7722101211547852
1.7166502475738525
1.5601776838302612
1.813938856124878
1.8100863695144653
1.7942177057266235
2.1109917163848877
2.2029805183410645
2.110881805419922
1.756364107131958
1.777757167816162
2.221489667892456
1.7421057224273682
1.7721527814865112
1.9033981561660767
1.7809640169143677
2.041257619857788
1.7846264839172363
1.5707311630249023
2.1098973751068115
1.9129817485809326
1.703641414642334
1.6173125505447388
1.9693758487701416
2.4656729698181152
2.1207878589630127
1.8642805814743042
1.6844542026519775
1.8987021446228027
1.8061877489089966
1.935288429260254
1.7282568216323853
1.8295862674713135
1.6562645435333252
1.836205244064331
2.209872245788574
1.9226832389831543
1.8174701929092407
2.018500566482544
1.8739373683929443
1.857859492301941
1.6927950382232666
1.974280595779419
1.9467153549194336
1.846671462059021
1.8862875699996948
2.028974771499634
1.9589612483978271
1.9714117050170898
1.8432785272598267
1.7835177183151245
1.9696623086929321
1.9858050346374512
1.668984

1.4396843910217285
1.7619305849075317
1.473737359046936
1.5868626832962036
1.6430686712265015
1.5452076196670532
1.6144897937774658
1.6565735340118408
1.4749139547348022
1.6764832735061646
1.6448661088943481
1.3388980627059937
1.512075424194336
1.7886871099472046
1.5974417924880981
1.500626802444458
1.4243898391723633
1.4533106088638306
1.5608984231948853
1.4394749402999878
1.6810773611068726
1.3672524690628052
1.5003726482391357
1.6617604494094849
1.7885003089904785
1.7939234972000122
1.3742177486419678
1.401535987854004
1.4863578081130981
1.8856226205825806
1.5234788656234741
1.5098040103912354
1.699149489402771
1.7466946840286255
1.5755265951156616
1.8078515529632568
1.48793625831604
1.5387599468231201
1.613789677619934
1.4875802993774414
1.7512321472167969
1.657499074935913
1.5623767375946045
1.5829459428787231
1.568034291267395
1.5248708724975586
1.568907618522644
1.9403668642044067
1.6029316186904907
1.727192997932434
1.6731038093566895
1.8956081867218018
1.6019126176834106
1.672

1.38996422290802
1.7320414781570435
1.5475906133651733
1.2917667627334595
1.5030083656311035
1.4216439723968506
1.3183163404464722
1.6443372964859009
1.3987400531768799
1.4308935403823853
1.4174567461013794
1.4694641828536987
1.3857070207595825
1.3731106519699097
1.3809155225753784
1.3708841800689697
1.4710900783538818
1.4057210683822632
1.3377629518508911
1.1750208139419556
1.3474676609039307
1.2446227073669434
1.2953944206237793
1.3943536281585693
1.2818009853363037
1.252991795539856
1.503596544265747
1.3189975023269653
1.5094836950302124
1.4160722494125366
1.352739930152893
1.3220213651657104
1.4235858917236328
1.4250081777572632
1.220444917678833
1.5119348764419556
1.7897030115127563
1.3619388341903687
1.4370970726013184
1.311895489692688
1.3354594707489014
1.4508970975875854
1.9927202463150024
1.4583548307418823
1.442586898803711
1.1948089599609375
1.4930613040924072
1.3356050252914429
1.394526720046997
1.0570341348648071
1.334246039390564
1.4849330186843872
1.638704776763916
1.37

1.1075220108032227
0.9892150163650513
1.1029433012008667
1.0973210334777832
1.2005414962768555
1.0709452629089355
0.9974992871284485
0.9816206693649292
1.0307118892669678
1.0229960680007935
1.0876872539520264
1.0674195289611816
1.0325641632080078
1.0789984464645386
1.1379584074020386
1.1409926414489746
1.1131643056869507
1.0467504262924194
1.045341968536377
1.0707398653030396
0.9954755902290344
1.0751475095748901
1.0601716041564941
1.2066667079925537
1.0542856454849243
1.0569038391113281
0.9372356534004211
1.1548691987991333
1.008109211921692
1.0607197284698486
1.038305640220642
0.9824002981185913
1.0786226987838745
1.1007968187332153
1.150234580039978
1.2123398780822754
1.1725157499313354
1.159534215927124
1.0927202701568604
1.046735167503357
1.1801866292953491
1.0658364295959473
1.0350922346115112
1.0218030214309692
1.1034564971923828
1.0962965488433838
1.1207234859466553
0.9690706133842468
1.0532832145690918
1.0423859357833862
0.971973180770874
1.062219262123108
0.9497756958007812
1

0.8118641972541809
0.8906490802764893
0.8206893801689148
1.0155236721038818
0.7938759922981262
0.8803284764289856
0.8121793866157532
0.866904079914093
0.915050745010376
0.8442425727844238
0.8682120442390442
0.8728612065315247
0.822287917137146
0.8311588764190674
0.8060469627380371
0.9171326160430908
0.9683420658111572
1.0109647512435913
0.7497732639312744
0.827099621295929
0.9166668057441711
0.8573763370513916
0.7325689792633057
0.8732240200042725
0.8033145666122437
0.9245516657829285
0.7461654543876648
0.8659541606903076
0.8695783019065857
0.8202531337738037
0.8894567489624023
0.7882283329963684
0.9467157125473022
0.8347476720809937
0.9533082842826843
0.7484440803527832
0.8772665858268738
0.7901433706283569
0.7820915579795837
0.7898567318916321
0.8308117389678955
0.8684874176979065
0.7935225367546082
0.7597172856330872
0.8597131371498108
0.8781669735908508
0.7237095832824707
0.8181728720664978
1.018088936805725
0.7949816584587097
0.8094342350959778
0.8274657130241394
0.734370291233062

1.063804268836975
0.9352747797966003
1.0120127201080322
1.1093230247497559
1.3006316423416138
0.9630376696586609
1.0270135402679443
1.0188028812408447
1.1856794357299805
1.0561977624893188
0.9570326209068298
0.95456862449646
0.8290500044822693
0.8681653738021851
0.9823846817016602
0.9614357948303223
1.046783447265625
0.8578094840049744
0.9263325333595276
0.847405731678009
1.1664537191390991
1.0772500038146973
1.0596717596054077
0.904279887676239
0.9192148447036743
0.9485023617744446
0.980672299861908
1.0823224782943726
1.1121426820755005
1.10347580909729
0.8682475686073303
0.938096284866333
1.317232608795166
0.8198254704475403
0.9235323071479797
0.9932847619056702
0.958946168422699
0.9292382597923279
0.7984744906425476
1.112023115158081
1.3172451257705688
0.8226951956748962
0.8997275233268738
1.2642768621444702
0.9008576273918152
0.9698373675346375
0.8584187030792236
0.9460120797157288
1.1538989543914795
0.9058398008346558
0.9573050737380981
1.0500810146331787
1.533713936805725
0.91236

0.7123891711235046
0.5716068744659424
0.7551313042640686
0.8369406461715698
0.7362849116325378
0.6994541883468628
0.8007568120956421
0.7520700693130493
0.8509842753410339
0.7136957049369812
0.7326259613037109
0.7084788680076599
0.7244458198547363
0.7527125477790833
0.7016711831092834
0.7690642476081848
0.7846200466156006
0.9870568513870239
0.7905513644218445
0.6886200308799744
0.694147527217865
0.7799153923988342
0.695128321647644
0.7964924573898315
0.9885350465774536
0.6819671988487244
0.7142418622970581
0.8329972624778748
0.7314987182617188
0.8330836892127991
0.8770111203193665
0.7625212669372559
0.6808443069458008
0.7264063954353333
0.785702645778656
0.6901622414588928
0.6942127346992493
0.7268478870391846
0.6483681201934814
0.7258593440055847
0.6495645642280579
0.6797704100608826
0.7557748556137085
0.6862217783927917
0.832503080368042
0.7653828263282776
0.8967788815498352
0.6431872248649597
0.8155799508094788
0.8114484548568726
0.7692607045173645
0.7195530533790588
0.65279608964920

### Value Network

In [314]:
class ValueNetworkRNN(nn.Module):
    def __init__(self, word_to_idx, input_dim=512, wordvec_dim=512, hidden_dim=512, dtype=np.float32):
        super(ValueNetworkRNN, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.word_to_idx = word_to_idx
        self.idx_to_word = {i: w for w, i in word_to_idx.items()}
        vocab_size = len(word_to_idx)
        
        self.hidden_cell = (torch.zeros(1, 1, self.hidden_dim).to(device), torch.zeros(1, 1, self.hidden_dim).to(device))
        
        self.caption_embedding = nn.Embedding(vocab_size, wordvec_dim)
        self.lstm = nn.LSTM(wordvec_dim, hidden_dim)
        
    def forward(self, captions):
        input_captions = self.caption_embedding(captions)
        output, self.hidden_cell = self.lstm(input_captions.view(len(input_captions) ,1, -1), self.hidden_cell)
        return output
    
class ValueNetwork(nn.Module):
    def __init__(self, word_to_idx):
        super(ValueNetwork, self).__init__()
        self.valrnn = ValueNetworkRNN(word_to_idx)
        self.linear1 = nn.Linear(1024, 512)
        self.linear2 = nn.Linear(512, 1)
    
    def forward(self, features, captions):
        for t in range(captions.shape[1]):
            vrnn = self.valrnn(captions[:, t])
        vrnn = vrnn.squeeze(0).squeeze(1)
        state = torch.cat((features, vrnn), dim=1)
        output = self.linear1(state)
        output = self.linear2(output)
        return output

In [315]:
valueNetwork = ValueNetwork(data["word_to_idx"]).to(device)
# criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(valueNetwork.parameters(), lr=0.0001)

In [185]:
small_data = load_coco_data(max_train=5000)

In [319]:
batch_size = 10
for epoch in range(100):
    captions, features, _ = sample_coco_minibatch(small_data, batch_size=batch_size, split='train')
    features = torch.tensor(features, device=device).float()
    captions = torch.tensor(captions, device=device).long()
    output = valueNetwork(features, captions)

### Reward Network

In [329]:
class RewardNetworkRNN(nn.Module):
    def __init__(self, word_to_idx, input_dim=512, wordvec_dim=512, hidden_dim=512, dtype=np.float32):
        super(RewardNetworkRNN, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.word_to_idx = word_to_idx
        self.idx_to_word = {i: w for w, i in word_to_idx.items()}
        vocab_size = len(word_to_idx)
        
        self.hidden_cell = torch.zeros(1, 1, self.hidden_dim).to(device)
        
        self.caption_embedding = nn.Embedding(vocab_size, wordvec_dim)
        self.gru = nn.GRU(wordvec_dim, hidden_dim)
    
    def forward(self, captions):
        input_captions = self.caption_embedding(captions)
        output, self.hidden_cell = self.gru(input_captions.view(len(input_captions) ,1, -1), self.hidden_cell)
        return output
    
class RewardNetwork(nn.Module):
    def __init__(self, word_to_idx):
        super(RewardNetwork, self).__init__()
        self.rewrnn = RewardNetworkRNN(word_to_idx)
        self.visual_embed = nn.Linear(512, 512)
        self.semantic_embed = nn.Linear(512, 512)
        
    def forward(self, features, captions):
        for t in range(captions.shape[1]):
            rrnn = self.rewrnn(captions[:, t])
        rrnn = rrnn.squeeze(0).squeeze(1)
        se = self.semantic_embed(rrnn)
        ve = self.visual_embed(features)
        return ve, se

In [330]:
rewardNetwork = RewardNetwork(data["word_to_idx"]).to(device)
# criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(rewardNetwork.parameters(), lr=0.0001)

In [332]:
small_data = load_coco_data(max_train=5000)

In [333]:
batch_size = 10
for epoch in range(100):
    captions, features, _ = sample_coco_minibatch(small_data, batch_size=batch_size, split='train')
    features = torch.tensor(features, device=device).float()
    captions = torch.tensor(captions, device=device).long()
    ve, se = rewardNetwork(features, captions)
#     print(ve.shape, se.shape)