In [1]:
from data_loader import get_loader
from build_vocab import Vocabulary
from torchvision import transforms
import pickle

# Locations of the training captions and images, plus the batch size used
# throughout this walkthrough.
train_annote_path = './data/annotations/train/captions.json'
train_image_path = './data/images/train'
batch_size = 5

# Restore the pickled vocabulary object.
# NOTE(review): pickle can execute arbitrary code on load -- only open a
# vocab.pkl you produced yourself. The otherwise-unused `Vocabulary` import is
# presumably what lets pickle resolve the stored class; confirm before
# removing it.
with open('./data/vocab.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

# Training-time preprocessing: random 224x224 crop, random horizontal flip,
# conversion to a tensor, then per-channel normalisation with the channel
# statistics commonly used for ImageNet-pretrained torchvision models.
transform = transforms.Compose([
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

# Shuffled loader over the training set, built by the project's get_loader.
data_loader = get_loader(
    train_annote_path,
    train_image_path,
    vocab,
    transform,
    batch_size,
    shuffle=True,
    num_workers=2,
)

In [2]:
# Pull a single batch to experiment with. next(iter(...)) fetches exactly one
# batch and raises StopIteration right away if the loader is empty, instead of
# silently leaving the three names undefined as a for/break loop would.
images, captions, lengths = next(iter(data_loader))

# Show the first caption without its leading token, then the full first
# caption as the cell's displayed value.
print(captions[0, 1:])
captions[0]

tensor([   4,   76,   62, 2408,    4,    3,    3,   17,  375,  961,   12,    2])


tensor([   1,    4,   76,   62, 2408,    4,    3,    3,   17,  375,  961,   12,
           2])

In [3]:
# Decoder hyperparameters shared by the cells below.
vocab_size = len(vocab)  # number of entries in the loaded vocabulary
embed_size = 512         # width of each token embedding
feature_num = 512        # first size argument passed to Attention; also f_beta's output width
hidden_size = 512        # LSTMCell hidden/cell state width
attention_dim = 512      # third size argument passed to Attention

In [4]:
from model import EncoderCNN

# Instantiate the project's CNN encoder and encode the sampled image batch.
encoder = EncoderCNN()
features = encoder(images)
# Compare input and encoded shapes
# (recorded run: [5, 3, 224, 224] -> [5, 512, 196]).
print(images.shape, features.shape)

torch.Size([5, 3, 224, 224]) torch.Size([5, 512, 196])


In [59]:
import torch
# Initial LSTM hidden and cell states: one zero row per sample in the batch.
# Sized from `features` and `hidden_size` instead of the literals 5 and 512 so
# the states stay in step with the actual batch and the decoder configuration.
h = torch.zeros(features.shape[0], hidden_size)
c = torch.zeros(features.shape[0], hidden_size)

In [47]:
from model import Attention
from torch import nn

# Decoder building blocks. They are created in this fixed order; keep the
# order stable, since each constructor initialises its own parameters.

# Token embedding table: vocabulary index -> embed_size vector.
embed = nn.Embedding(vocab_size, embed_size)
# Project-local attention module; later cells call it as
# attention(features, h) and unpack the result as (z, alpha).
attention = Attention(feature_num, hidden_size, attention_dim)
# Linear layer + sigmoid used as a gate on the attended context:
# later cells compute z * sigmoid(f_beta(h)).
f_beta = nn.Linear(hidden_size, feature_num)
sigmoid = nn.Sigmoid()
# Single-step LSTM cell whose input is the token embedding concatenated with
# the gated context vector.
lstm = nn.LSTMCell(embed_size + feature_num, hidden_size, bias=True)
# Called once per time step with a 2-D (batch, embed_size + feature_num) input
# and (h, c) states; there is no sequence axis at this level.
# TODO: handle the case of an LSTM with >= 2 hidden layers.
dropout = nn.Dropout(p=0.5)
# The paper uses dropout and early stopping on BLEU score as regularizers.
# Projects the hidden state to one score per vocabulary entry.
linear = nn.Linear(hidden_size, vocab_size)

In [55]:
# Batch size and the size of the last `features` axis
# (5 and 196 in the recorded run).
batch_size = features.size(0)
feature_dim = features.size(2)

# Embed every caption token; result is (batch, caption_len, embed_size).
embedding = embed(captions)
print(embedding.shape)

# Each caption is decoded for one step fewer than its length.
decode_lengths = [length - 1 for length in lengths]
print(len(decode_lengths))

# Zero-filled output buffers covering the longest decode length in the batch:
# per-step vocabulary scores and per-step attention weights.
max_decode_len = max(decode_lengths)
predictions = torch.zeros(batch_size, max_decode_len, vocab_size)
alphas = torch.zeros(batch_size, max_decode_len, feature_dim)
print(predictions.shape, alphas.shape)

torch.Size([5, 13, 512])
5
torch.Size([5, 12, 2699]) torch.Size([5, 12, 196])


In [56]:
# Dry run of a single attention step on the first four samples only.
batch_idx = 4
z, alpha = attention(features[:batch_idx], h[:batch_idx])
print(z.shape, alpha.shape)

# Scale the context vector by a sigmoid gate computed from the hidden state,
# then display the resulting shape as the cell's value.
gate = sigmoid(f_beta(h[:batch_idx]))
z = z * gate
z.shape

torch.Size([4, 512]) torch.Size([4, 196])


torch.Size([4, 512])

In [60]:
# Start every run of this cell from a zero state. The loop below rebinds h and
# c to progressively smaller batches, so without this reset a second run would
# start from the shrunken, non-zero state left behind by the first one.
h = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)

for i in range(max(decode_lengths)):
    # Number of captions still being decoded at step i. Taking the first
    # batch_idx rows assumes the batch is ordered longest caption first
    # -- TODO confirm in the loader's collate function.
    batch_idx = sum(length > i for length in decode_lengths)

    # Attend over the image features of the still-active samples.
    z, alpha = attention(features[:batch_idx], h[:batch_idx])
    # In the paper, multiplying the attention output z by a gating scalar beta
    # improves performance (Sec. 4.2.1).
    z = z * sigmoid(f_beta(h[:batch_idx]))

    # z already has batch_idx rows, so it needs no further slicing before being
    # concatenated with the embedding of the i-th caption token.
    lstm_input = torch.cat([embedding[:batch_idx, i, :], z], dim=1)
    h, c = lstm(lstm_input, (h[:batch_idx], c[:batch_idx]))

    # Vocabulary scores for this step, with dropout applied to the hidden state.
    output = linear(dropout(h))
    print(output.shape)

    # Store this step's scores and attention weights for the active samples;
    # rows of finished captions keep their zero fill.
    predictions[:batch_idx, i, :] = output
    alphas[:batch_idx, i, :] = alpha

torch.Size([5, 2699])
torch.Size([5, 2699])
torch.Size([5, 2699])
torch.Size([5, 2699])
torch.Size([5, 2699])
torch.Size([5, 2699])
torch.Size([5, 2699])
torch.Size([5, 2699])
torch.Size([5, 2699])
torch.Size([5, 2699])
torch.Size([3, 2699])
torch.Size([2, 2699])
