In [None]:
# Fetch the project repository; later cells read its code/ and data/ directories.
# NOTE(review): no commit or tag is pinned, so this clones whatever the default branch currently holds.
# NOTE(review): the captured output below ends with an `mv: cannot stat` line although this cell has no
# `mv` command — the output looks stale; re-run the cell to refresh it.
!git clone https://github.com/CornerSiow/zero-shot-image-captioning.git

Cloning into 'zero-shot-image-captioning'...
remote: Enumerating objects: 151, done.[K
remote: Counting objects: 100% (62/62), done.[K
remote: Compressing objects: 100% (62/62), done.[K
remote: Total 151 (delta 29), reused 0 (delta 0), pack-reused 89[K
Receiving objects: 100% (151/151), 74.78 MiB | 32.64 MiB/s, done.
Resolving deltas: 100% (68/68), done.
mv: cannot stat 'zero-shot-image-captioning/Vocabulary.py': No such file or directory


In [None]:
# Copy the helper modules next to the notebook so that the import cell can load
# them by bare module name (`from Vocabulary import ...`, `from DecoderLSTM import ...`).
!cp "zero-shot-image-captioning/code/Vocabulary.py" "Vocabulary.py"
!cp "zero-shot-image-captioning/code/DecoderLSTM.py" "DecoderLSTM.py"

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
import pickle
from torch.utils.data import DataLoader
from Vocabulary import Vocabulary
from DecoderLSTM import DecoderLSTM
import random
import numpy as np
# Seed Python's, PyTorch's and NumPy's global random generators with the same fixed value.
# The torch seed is the one that governs the torch.rand scaling applied in the training loop.
random.seed(10)
torch.manual_seed(10)
np.random.seed(10)

In [None]:
# Create the vocabulary object and fill it from the file shipped in the cloned repository.
# NOTE(review): loadFile is defined in Vocabulary.py (not shown here); presumably it unpickles
# a previously saved vocabulary — confirm against that module.
vocab = Vocabulary()
vocab.loadFile("zero-shot-image-captioning/data/vocab.pickle")

In [None]:
# Load the two pickled inputs shipped in the cloned repository.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only run this
# against a checkout of the repository that you trust.
with open('zero-shot-image-captioning/data/filtered_symbolic.pickle', 'rb') as handle:
    # In the cells shown here only len(filtered_symbolic) is used (as the decoder's embed_size).
    filtered_symbolic = pickle.load(handle)  
with open('zero-shot-image-captioning/data/training_data.pickle', 'rb') as handle:
    # Handed to the DataLoader as the dataset; collate_fn unpacks each item as an (x, y) pair.
    dataList= pickle.load(handle)

In [None]:
def collate_fn(data):
    """Merge a list of ``(x, y)`` samples into a single batch.

    Args:
        data: Iterable of ``(x, y)`` tensor pairs, as handed over by the
            DataLoader. (Presumably ``x`` is an input feature row and ``y``
            a sequence of target ids — confirm against the pickled data.)

    Returns:
        A tuple ``(inputs, targets)``: every ``x`` cast to float and stacked
        row-wise with ``torch.vstack``, and every ``y`` zero-padded to the
        longest sequence in the batch, with the batch dimension first.
    """
    samples = list(data)
    inputs = [sample_x.float() for sample_x, _ in samples]
    targets = [sample_y for _, sample_y in samples]
    # Pad the variable-length targets before stacking the inputs.
    padded_targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
    return torch.vstack(inputs), padded_targets

# Report the dataset size, then wrap the samples in a shuffling loader that
# yields one sample per batch and assembles it with collate_fn.
print("Total Training Data: ", len(dataList))
trainLoader = DataLoader(
    dataList,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

Total Training Data:  5


In [None]:
# When True, the training loop multiplies each input by fresh uniform random factors.
with_RVS = True
# Number of vocabulary entries; passed to DecoderLSTM and used to flatten its outputs for the loss.
vocab_size = len(vocab)
# Passed as the first DecoderLSTM argument; taken from the number of entries in filtered_symbolic.
embed_size = len(filtered_symbolic)
# Passed as the second DecoderLSTM argument.
hidden_size = 256
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the decoder and the loss directly on the selected device
# (nn.Module.to returns the module itself, so the calls can be chained).
decoder = DecoderLSTM(embed_size, hidden_size, vocab_size).to(device)
params = decoder.parameters()
criterion = nn.CrossEntropyLoss().to(device)
# lr, betas and eps are spelled out explicitly; they match Adam's documented defaults.
optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.9,0.999), eps=1e-8)

In [None]:
# Switch the decoder to training mode before optimizing.
decoder.train()
print("Start Training")
bar = tqdm(range(1000))
for epoch in bar:
    # Sum of the batch losses seen so far in this epoch; shown in the progress bar.
    totalLoss = 0
    for x, y in trainLoader:
        if with_RVS:
            # Multiply every input element by its own uniform [0, 1) factor.
            r = torch.rand(x.shape)
            x = r * x

        # Move the batch to the training device once and reuse it below.
        inputs = x.to(device)
        targets = y.to(device)

        decoder.zero_grad()
        outputs = decoder(inputs, targets)
        # Flatten predictions to rows of vocab_size scores and targets to a flat vector.
        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()

        totalLoss += loss.item()
        bar.set_description("Epoch:{:d} Loss:{:.4f}".format(epoch, totalLoss))

Start Training


Epoch:999 Loss:0.0402: 100%|██████████| 1000/1000 [01:35<00:00, 10.50it/s]


In [None]:
print("Finish Train")
# save model
print("Save the model")
# Only the decoder's state_dict (not the module object) is written, to a path
# relative to the current working directory. The file is produced by torch.save
# despite the .pkl extension.
torch.save(decoder.state_dict(), 'lstm_decoder.pkl')

Finish Train
Save the model
