In [104]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.models as vmodels
from torch.utils.data import DataLoader
import random
import nltk

# Whether PyTorch can see a CUDA device; evaluated once when this cell runs.
USE_CUDA = torch.cuda.is_available()


## 데이터 로딩 

In [4]:
# Per-channel mean/std used to standardise the image tensors (the values
# commonly paired with ImageNet-pretrained torchvision models).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Every image is resized to 224x224: the DataLoader can only stack samples of
# identical size, and the later `view(..., 512, 196)` reshapes rely on it.
# This also matches the recorded output of this cell ([3, 224, 224]).
cap = dset.CocoCaptions(root = '../../data/CocoCaption/val2014/',
                        annFile = '../../data/CocoCaption/captions_val2014.json',
                        transform=transforms.Compose([transforms.Resize((224, 224)),
                                                      transforms.ToTensor(),
                                                      normalize]))

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target)

loading annotations into memory...
Done (t=0.39s)
creating index...
index created!
Number of samples:  40504
Image Size:  torch.Size([3, 224, 224])
['A group of men sitting around a table eating food.', 'Men sitting at a table smiling while eating food.', 'A group of people enjoying a pot-luck dinner.', 'A person at a table with some food.', 'Two men sampling chili and beer in a tent.']


In [138]:
# Mini-batches of 3 samples drawn from `cap`, reshuffled on every pass.
train_loader = DataLoader(cap,batch_size=3,shuffle=True)

In [139]:
# Pull a single mini-batch and stop immediately; `batch` stays bound to it
# so the following cells can experiment on one concrete example.
for batch in train_loader:
    break

## Build Vocab 

In [41]:
# Tokenize every caption found in the annotation file.
sents = [nltk.word_tokenize(ann['caption']) for ann in cap.coco.anns.values()]

# Sort the unique tokens so that the word -> index mapping is identical on
# every run. Iterating a plain set of strings follows hash order, which
# changes between interpreter sessions and would silently invalidate any
# saved model or encoded data.
vocab = sorted({w for s in sents for w in s})

# Indices 0-3 are reserved for the special tokens; real words follow.
word2index={'<pad>' : 0, '<unk>' : 1, '<s>' : 2, '</s>' :3}
for vo in vocab:
    if vo not in word2index:
        word2index[vo]=len(word2index)

In [74]:
def prepare_sequence(seq, to_index):
    """Encode a sequence of tokens as a 1-D LongTensor of vocabulary indices.

    Args:
        seq: iterable of tokens.
        to_index: mapping from token to integer index; must contain "<unk>"
            whenever ``seq`` holds a token that is not in the mapping.

    Returns:
        A ``Variable`` wrapping a ``torch.LongTensor`` with one index per
        token; unknown tokens are encoded with the "<unk>" index.
    """
    encoded = []
    for token in seq:
        if token in to_index:
            encoded.append(to_index[token])
        else:
            # Out-of-vocabulary token: fall back to the unknown-word index.
            encoded.append(to_index["<unk>"])
    return Variable(torch.LongTensor(encoded))

In [147]:
def prepare_batch(batch,word2index):
    """Turn one loader batch into an image tensor and a padded caption tensor.

    Args:
        batch: pair ``(x, y)``. ``x`` is an iterable of image tensors; ``y``
            is a sequence of caption groups, one of which is picked at
            random (presumably one group per caption slot, each holding one
            caption string per image -- verify against the loader's collate
            behaviour).
        word2index: token -> index mapping containing '<pad>' and '<unk>'.

    Returns:
        ``(x, y)`` where ``x`` has the images stacked along a new leading
        axis and ``y`` is a LongTensor of shape
        (number of captions, longest caption length), right-padded with the
        '<pad>' index.
    """
    images, caption_groups = batch
    x = torch.cat([image.unsqueeze(0) for image in images])

    # A single random draw selects the caption group used for the whole batch.
    captions = random.choice(caption_groups)
    encoded = [
        prepare_sequence(nltk.word_tokenize(caption), word2index).view(1, -1)
        for caption in captions
    ]
    longest = max(seq.size(1) for seq in encoded)

    rows = []
    for seq in encoded:
        shortfall = longest - seq.size(1)
        if shortfall > 0:
            # Right-pad shorter captions up to the longest one in the batch.
            filler = Variable(torch.LongTensor([word2index['<pad>']] * shortfall)).view(1, -1)
            seq = torch.cat([seq, filler], 1)
        rows.append(seq)

    y = torch.cat(rows)

    return x, y

In [148]:
# Convert the batch fetched above into (image tensor, padded caption indices).
x,y = prepare_batch(batch,word2index)

## VGGNET16 

In [10]:
# torchvision's VGG-16, created with pretrained weights (pretrained=True).
vggnet = vmodels.vgg16(pretrained=True)

In [13]:
# Keep every layer of `vggnet.features` except the last one and wrap them in
# a single sequential module (the classifier head is not used at all).
feature_extractor = nn.Sequential(*(list(vggnet.features)[:-1]))

In [18]:
# Run the sample image loaded earlier through the truncated network as a
# batch of one (unsqueeze(0) adds the batch axis).
feature = feature_extractor(Variable(img.unsqueeze(0)))

In [21]:
# Flatten the feature map to (1, 512, 196), then swap the last two axes to get
# (1, 196, 512) -- presumably one 512-d vector per spatial location.
# NOTE(review): this overwrites `feature`, so re-running the cell operates on
# the already-transposed tensor rather than on the raw CNN output.
feature = feature.view(1,512,196).transpose(1,2)

In [26]:
from attention import Attention

## Decoder 

In [117]:
class Decoder(nn.Module):
    """GRU decoder with attention that greedily unrolls word scores.

    Args:
        V: vocabulary size (embedding rows and width of the output layer).
        E: word-embedding dimension.
        H: GRU hidden size. ``init_F`` and the attention module are built
            with H as well, so the visual features fed to ``forward`` must
            have H as their last dimension.
        sos_idx: vocabulary index of the start-of-sequence token.
        max_len: default number of decoding steps.
    """
    def __init__(self,V,E,H,sos_idx,max_len=15):
        super(Decoder,self).__init__()

        self.hidden_size = H
        self.max_len = max_len
        self.sos_idx = sos_idx
        self.init_F = nn.Linear(H,H)  # mean visual feature -> initial hidden state
        self.embed = nn.Embedding(V,E)
        self.gru = nn.GRU(E+H,H,batch_first=True)  # input: [word embedding ; context]
        self.dropout = nn.Dropout(0.5)  # defined but not applied in forward()
        self.linear = nn.Linear(2*H,V)  # input: [hidden state ; context]
        self.attention = Attention(H,'general') # attention

    def start_token(self,batch_size):
        """Return a (batch_size, 1) LongTensor filled with ``sos_idx``.

        The tensor is placed on the same device as the embedding weights,
        so the lookup in ``forward`` works whether or not the module has
        been moved to the GPU. Keying this on machine-wide CUDA availability
        instead breaks a CPU-resident decoder on a CUDA-capable machine.
        """
        sos = Variable(torch.LongTensor([self.sos_idx]*batch_size)).unsqueeze(1)
        weight = self.embed.weight
        if weight.is_cuda:
            sos = sos.cuda(weight.get_device())
        return sos

    def init_hidden(self,encoder_hiddens):
        """Build the initial GRU state from the mean of the visual features.

        Returns a tensor of shape (1, B, H).
        """
        mean_hidden = torch.mean(encoder_hiddens,1)
        hidden = self.init_F(mean_hidden) # B,H
        return hidden.unsqueeze(0) # 1,B,H

    def forward(self,encoder_hiddens, max_len=None):
        """Greedily decode ``max_len`` steps and return the per-step scores.

        Args:
            encoder_hiddens: B,196,512 (visual features extracted by the CNN).
            max_len: number of decoding steps; defaults to ``self.max_len``.

        Returns:
            Tensor of shape (B * max_len, V). Rows are ordered by batch item
            first and time step second.
        """
        if max_len is None: max_len = self.max_len
        batch_size = encoder_hiddens.size(0)

        inputs = self.start_token(batch_size) # B,1
        hidden = self.init_hidden(encoder_hiddens)
        embed = self.embed(inputs) # B,1,E
        scores=[]
        for _ in range(max_len):

            # compute the context vector from the current hidden state
            context = self.attention(hidden.transpose(0,1), encoder_hiddens)

            # concatenate word embedding and context, feed to the RNN
            rnn_input = torch.cat([embed,context],2)
            _, hidden = self.gru(rnn_input,hidden)

            # concatenate new hidden state and context, feed to the output layer
            concated = torch.cat([hidden.transpose(0,1),context],2)
            score = self.linear(concated.squeeze(1))
            scores.append(score)
            decoded = score.max(1)[1] # greedy
            embed = self.embed(decoded).unsqueeze(1) # y_{t-1}

        # Each step contributes a (B, V) block; joining them column-wise gives
        # (B, max_len*V), which reshapes to one row per (batch item, step).
        scores = torch.cat(scores,1)
        return scores.view(batch_size*max_len,-1)

In [118]:
# Decoder over the full vocabulary: 100-d word embeddings, 512-d hidden state,
# '<s>' as the start-of-sequence token.
decoder = Decoder(len(word2index),100,512,word2index['<s>'])

In [149]:
# Forward-pass check on a single mini-batch: the loop exits after the first
# iteration, and no loss or optimizer step is visible in this cell.
for batch in train_loader:
    x,y = prepare_batch(batch,word2index)
    # Clear any gradients accumulated on both networks.
    feature_extractor.zero_grad()
    decoder.zero_grad()
    features = feature_extractor(Variable(x))
    # (B, 512, 196) -> (B, 196, 512), the layout the decoder's docstring expects.
    features = features.view(-1,512,196).transpose(1,2)
    # Decode exactly as many steps as the padded caption length.
    preds = decoder(features,y.size(1))
    break