In [1]:
import torch
from torch import nn

import numpy as np
from os.path import join

from tqdm import tqdm

import pickle
import re

from collections import Counter

In [2]:
datadir = 'C:\\associative_represenations_data\\'

In [3]:
f = open('tfeats.pkl', 'rb')
tfeats = pickle.load(f)
    
f = open('vfeats.pkl', 'rb')
vfeats = pickle.load(f)

f = open('tcapts.pkl', 'rb')
tcapts = pickle.load(f)
    
f = open('vcapts.pkl', 'rb')
vcapts = pickle.load(f)

In [4]:
print(tfeats.shape)
print(vfeats.shape)

(118287, 2048)
(5000, 2048)


In [5]:
print(tcapts.shape)
print(vcapts.shape)

(118287, 5)
(5000, 5)


In [6]:
tcapts

array([['A man is in a funny position during a tennis match',
        'A tennis player at the net after his play on the court. ',
        'A man near the net playing tennis with official looking on.',
        'a man by a tennis net getting ready to hit a ball',
        'A man is attempting to return the ball'],
       ['A white van is following an orange and white bus down the road. ',
        'A van following behind a bus in the street. ',
        'A white and orange bus driving down a city street.',
        'A van follows behind a bus on a rural road.',
        'A passenger bus that is driving down a street.'],
       ['A group of children sitting around each other.',
        'four children looking at each other one holding long object',
        'A girl holding a tube talking to another girl.',
        'Group of children sitting on a bench petting a dog.',
        'The children are grouped together waiting their turn.'],
       ...,
       ['the elephants are  all next to each other 

# Preprocessing captions

In [7]:
PAD = "#PAD#"
UNK = "#UNK#"
START = "#START#"
END = "#END#"

In [8]:
# split sentence into tokens (split into lowercased words)
def split_sentence(sentence):
    return list(filter(lambda x: len(x) > 0, re.split('\W+', sentence.lower())))

def generate_vocabulary(train_captions):

    w_bag = [PAD, START, UNK, END] + [item.lower() for sublist in [split_sentence(sent) for sent in train_captions] for item in sublist]
    vocab = Counter(w_bag)
    vocab = [PAD, START, UNK, END] + [token for token, cnt in zip(vocab.keys(), vocab.values()) if cnt >= 5]
    vocab = {key: i for i, key in enumerate(vocab)}

    return vocab

def caption_tokens_to_indices(captions, vocab):
    return  [[vocab[START]] + [vocab[word] if word in vocab else vocab[UNK] for word in split_sentence(capt)] + [vocab[END]] for capt in captions]

In [9]:
# prepare vocabulary
vocab = generate_vocabulary(tcapts[:, 3])
vocab_inverse = {idx: w for w, idx in vocab.items()}
print(len(vocab))

5154


In [10]:
# we will use this during training
def batch_captions_to_matrix(batch_captions, pad_idx, max_len=None):
    
    if max_len is None:
        pad_len = max(map(len, batch_captions))
    else:
        pad_len = min(max(map(len, batch_captions)), max_len)
        
    matrix = []
    for capt in batch_captions:
        if pad_len-len(capt) >= 0:
            matrix.append(np.pad(capt, (0, pad_len-len(capt)), mode='constant', constant_values=pad_idx))
        else:
            matrix.append(capt[:pad_len])
    
    return np.array(matrix)

In [11]:
def decode_captions(out):
    softmax = nn.Softmax(dim=1)(out).argmax(axis=1)
    softmax_to_tokens = lambda batch: np.array(list(map(lambda token: vocab_inverse[token], batch.numpy().reshape(-1)))).reshape(*batch.shape)
    return softmax_to_tokens(softmax)

In [12]:
vocab = generate_vocabulary(tcapts[:, 3])
tcapts_enc = batch_captions_to_matrix(caption_tokens_to_indices(tcapts[:, 3], vocab), vocab[PAD], max_len=50)
vcapts_enc = batch_captions_to_matrix(caption_tokens_to_indices(vcapts[:, 3], vocab), vocab[PAD], max_len=50)

In [13]:
tcapts_len = [len(x) for x in caption_tokens_to_indices(tcapts[:, 3], vocab)]
vcapts_len = [len(x) for x in caption_tokens_to_indices(vcapts[:, 3], vocab)]

In [14]:
tcapts_enc.shape

(118287, 49)

In [15]:
with open('tcapts_encoded.pkl', 'wb') as file_embeds:
    pickle.dump(tcapts_enc, file_embeds)
    
with open('vcapts_encoded.pkl', 'wb') as file_capts:
    pickle.dump(vcapts_enc, file_capts)

# Decoder

In [16]:
class RNNModel(nn.Module):
    
    IMG_EMBED_SIZE = tfeats.shape[1]
    IMG_EMBED_BOTTLENECK = 120
    WORD_EMBED_SIZE = 100
    LSTM_UNITS = 300
    LOGIT_BOTTLENECK = 120
    pad_idx = vocab[PAD]

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):
        
        super(RNNModel, self).__init__()
        
        self.drop = nn.Dropout(dropout)
        
        self.img_embed_to_bottleneck = nn.Linear(self.IMG_EMBED_SIZE, self.IMG_EMBED_BOTTLENECK)
        self.img_embed_bottleneck_to_h0 = nn.Linear(self.IMG_EMBED_BOTTLENECK, nhid)
        
        self.embedding = nn.Embedding(ntoken, ninp)
        
        if rnn_type == 'LSTM':
            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        elif rnn_type == 'GRU':
            self.rnn = nn.GRU(ninp, nhid, nlayers, dropout=dropout)
            
        self.token_logits_bottleneck = nn.Linear(nhid, self.LOGIT_BOTTLENECK)
        self.token_logits = nn.Linear(self.LOGIT_BOTTLENECK, ntoken)

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        
        self.img_embed_to_bottleneck.bias.data.fill_(0)
        self.img_embed_to_bottleneck.weight.data.uniform_(-initrange, initrange)
        
        self.img_embed_bottleneck_to_h0.bias.data.fill_(0)
        self.img_embed_bottleneck_to_h0.weight.data.uniform_(-initrange, initrange)
        
        self.token_logits_bottleneck.bias.data.fill_(0)
        self.token_logits_bottleneck.weight.data.uniform_(-initrange, initrange)
        
        self.token_logits.bias.data.fill_(0)
        self.token_logits.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, input_lengths, hidden=None):
        hidden = self.img_embed_bottleneck_to_h0(self.img_embed_to_bottleneck(hidden))
        
        word_embeds = self.drop(self.embedding(x.T))
        word_embeds = nn.utils.rnn.pack_padded_sequence(word_embeds, batch_len, enforce_sorted=False)
        
        output, hidden = self.rnn(word_embeds, (hidden, hidden))
        output, _ = nn.utils.rnn.pad_packed_sequence(output, total_length=x.shape[1])
        output = self.drop(output)

        output_bottleneck = self.token_logits_bottleneck(output.view(output.shape[0]*output.shape[1], output.shape[2]))
        output = self.token_logits(output_bottleneck)
        return output, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
                    weight.new(self.nlayers, bsz, self.nhid).zero_())
        else:
            return weight.new(self.nlayers, bsz, self.nhid).zero_()

In [41]:
decoder = RNNModel('LSTM', ntoken=len(vocab), ninp=RNNModel.WORD_EMBED_SIZE, nhid=RNNModel.LSTM_UNITS, nlayers=1, dropout=0.3).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)

  "num_layers={}".format(dropout, num_layers))


In [29]:
import pdb

In [42]:
images_count = tcapts.shape[0]
BATCH_SIZE = 32
BATCH_COUNT = int(np.ceil(images_count/BATCH_SIZE))

EVERY_BATCHES_TOSHOW = 100

EPOCHS = 10

decoder.train()

for e in range(EPOCHS):

    print('EPOCH %d' % (e+1))

    for i in range(BATCH_COUNT):
        capts = tcapts_enc[i*BATCH_SIZE:(i+1)*BATCH_SIZE, :]
        target = torch.Tensor(np.hstack((capts, np.array([vocab[PAD]]*capts.shape[0]).reshape(-1, 1)))[:, 1:]).long()
        im_feats = tfeats[i*BATCH_SIZE:(i+1)*BATCH_SIZE, :]

        batch_len = tcapts_len[i*BATCH_SIZE:(i+1)*BATCH_SIZE]

        capts = torch.Tensor(capts).long()
        im_feats = torch.Tensor(np.expand_dims(im_feats, 0))

        decoder.zero_grad()

        out, hidden = decoder(capts.cuda(), batch_len, im_feats.cuda())
        loss = criterion(out, target.view(-1).cuda())

        if i > 0 and i % EVERY_BATCHES_TOSHOW == 0:
            print("Batch {}: loss: {}".format(i, loss.cpu().detach().numpy()))

        loss.backward()        
        optimizer.step()

EPOCH 1
Batch 100: loss: 5.249606609344482
Batch 200: loss: 1.9148378372192383
Batch 300: loss: 1.8509377241134644
Batch 400: loss: 1.7441773414611816
Batch 500: loss: 1.8146336078643799
Batch 600: loss: 1.7934309244155884
Batch 700: loss: 1.7180540561676025
Batch 800: loss: 1.8004348278045654
Batch 900: loss: 1.8061316013336182
Batch 1000: loss: 1.8162428140640259
Batch 1100: loss: 1.7713053226470947
Batch 1200: loss: 1.8517801761627197
Batch 1300: loss: 1.8685294389724731
Batch 1400: loss: 1.9520533084869385
Batch 1500: loss: 1.7668880224227905
Batch 1600: loss: 1.7999995946884155
Batch 1700: loss: 1.835193395614624
Batch 1800: loss: 1.7093985080718994
Batch 1900: loss: 1.7847278118133545
Batch 2000: loss: 1.8700073957443237
Batch 2100: loss: 1.8455947637557983
Batch 2200: loss: 1.7625147104263306
Batch 2300: loss: 1.7297617197036743
Batch 2400: loss: 1.7644435167312622
Batch 2500: loss: 1.7758479118347168
Batch 2600: loss: 1.7405884265899658
Batch 2700: loss: 1.7635899782180786
Batc

KeyboardInterrupt: 

In [43]:
x = torch.Tensor(tcapts_enc[:30, :]).long().cuda()
batch_len = tcapts_len[:30]
hidden = torch.Tensor(np.expand_dims(tfeats[:30, :], 0)).cuda()

decoder.eval()

out, hid = decoder(x, batch_len, hidden)

### ... и тут что-то пошло не так( у нас одни паддинги

In [44]:
Counter(decode_captions(out.detach().cpu()))

Counter({'#PAD#': 1470})

### ... а попытка сделать инференс вообще ломает куду

In [45]:
x = torch.Tensor([[vocab[START]]]).long()
x.shape

torch.Size([1, 1])

In [46]:
hidden = np.expand_dims(tfeats[0], 0)
hidden = np.expand_dims(hidden, 0)
hidden = torch.tensor(hidden)
hidden.shape

torch.Size([1, 1, 2048])

In [47]:
decoder(x.cuda(), [1], hidden.cuda())

RuntimeError: tabulate: failed to synchronize: cudaErrorAssert: device-side assert triggered