# 1 准备数据
## 1.1 下载数据集
    运行 ./download.sh
## 1.2 Resize图片
    运行 python resize.py 将所有图片缩放(resize)至指定大小
## 1.3 构建词表 (Vocab)
    运行 python build_vocab.py

In [6]:
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
from data_loader import get_loader
from build_vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms


# Device configuration: use the GPU when one is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Root directory of the datasets. It can be overridden with the DATASET_ROOT
# environment variable instead of editing the notebook. Joining with '' adds a
# trailing separator when missing, so appending relative sub-paths to this
# string (plain concatenation or os.path.join) always yields a valid path.
dataset_root = os.path.join(os.environ.get('DATASET_ROOT', '/Volumes/SD/Dataset/'), '')

In [7]:
# Paths and logging/checkpoint frequencies
model_path = 'models/'  # directory where trained model checkpoints are saved
crop_size = 224         # side length of the random crop applied to each image
vocab_path = os.path.join(dataset_root, 'coco/vocab.pkl')        # pickled vocabulary wrapper
image_dir = os.path.join(dataset_root, 'coco/resized2014')       # directory of resized images
caption_path = os.path.join(dataset_root, 'coco/annotations/captions_train2014.json')  # training caption annotations (JSON)
log_step = 10      # print the training log every `log_step` steps
save_step = 1000   # save model checkpoints every `save_step` steps

# Model parameters (the embedding size 256 and LSTM hidden size 512 are
# passed where the models are built)
num_layers = 1  # number of LSTM layers in the decoder

# Training parameters
num_epochs = 5
batch_size = 128
num_workers = 2        # passed to get_loader as num_workers
learning_rate = 0.001  # learning rate for the Adam optimizer

In [8]:
# Create the checkpoint directory (no-op when it already exists; avoids the
# check-then-create race of os.path.exists + os.makedirs)
os.makedirs(model_path, exist_ok=True)

# Image preprocessing: random crop and horizontal flip for augmentation, then
# per-channel normalization with the mean/std used for the pretrained ResNet
transform = transforms.Compose([
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))])

# Load the vocabulary wrapper.
# NOTE: pickle.load can execute arbitrary code stored in the file; only load
# vocabulary files you generated yourself (e.g. with build_vocab.py).
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

# Build the training data loader (shuffled mini-batches of images and captions)
data_loader = get_loader(image_dir, caption_path, vocab,
                         transform, batch_size,
                         shuffle=True, num_workers=num_workers)

loading annotations into memory...
Done (t=1.49s)
creating index...
index created!


In [24]:
# Vocabulary size; the same value is passed to DecoderRNN when the models are built
vocab_size = len(vocab)
vocab_size

9956

In [22]:
# Build the models.
# The encoder turns each image into a feature vector that summarizes its
# content (objects, positions, actions, ...); the decoder translates that
# feature into a textual description.
embed_size = 256   # dimension of word embedding vectors / image features
hidden_size = 512  # dimension of LSTM hidden states
encoder = EncoderCNN(embed_size=embed_size).to(device)
# Arguments are positional in the order (embed_size, hidden_size, vocab_size,
# num_layers); the original call mixed keyword-then-positional arguments,
# which is a SyntaxError. TODO confirm the order against model.DecoderRNN.
decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers).to(device)

In [10]:
# Loss and optimizer.
# The optimizer only receives the decoder's parameters plus those of the
# encoder's `linear` and `bn` sub-modules; other encoder parameters are not
# updated by it.
criterion = nn.CrossEntropyLoss()
params = [
    *decoder.parameters(),
    *encoder.linear.parameters(),
    *encoder.bn.parameters(),
]
optimizer = torch.optim.Adam(params, lr=learning_rate)

In [12]:
# Train the models.
# debug_single_step makes the early stop explicit: when True the loop stops
# after the first mini-batch of the first epoch, leaving `images`, `captions`,
# `lengths` and `features` for the inspection cells that follow.
# Set it to False to run the full training schedule.
debug_single_step = True
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, captions, lengths) in enumerate(data_loader):

        # Move the mini-batch (images and padded captions) to the device
        images = images.to(device)      # e.g. torch.Size([128, 3, 224, 224])
        captions = captions.to(device)  # e.g. torch.Size([128, 23])
        # Pack the padded captions into a flat target tensor (padding dropped)
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]  # e.g. torch.Size([1673])

        # Forward pass: image features, then the decoder's caption predictions
        features = encoder(images)                      # e.g. torch.Size([128, 256])
        outputs = decoder(features, captions, lengths)  # e.g. torch.Size([1673, 9956])

        # Loss between predictions and targets, backward pass, parameter update
        loss = criterion(outputs, targets)
        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        optimizer.step()

        # Print log info
        if i % log_step == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch, num_epochs, i, total_step, loss.item(), np.exp(loss.item())))

        # Save the model checkpoints
        if (i+1) % save_step == 0:
            torch.save(decoder.state_dict(), os.path.join(
                model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
            torch.save(encoder.state_dict(), os.path.join(
                model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))

        if debug_single_step:
            break
    if debug_single_step:
        break


Epoch [0/5], Step [0/3236], Loss: 9.2004, Perplexity: 9900.9128


In [67]:
# Inspect how the padded caption batch is packed using the per-caption lengths
pack_padded_sequence(input=captions, lengths=lengths, batch_first=True)

PackedSequence(data=tensor([ 1,  1,  1,  ..., 40, 19,  2]), batch_sizes=tensor([128, 128, 128, 128, 128, 128, 128, 128, 128, 127, 115,  94,  67,  44,
         31,  20,  11,   5,   2,   2,   1,   1,   1]), sorted_indices=None, unsorted_indices=None)

## Decoder过程详解

In [83]:
# Print the size of the padded caption batch, then display the tensor itself
print(captions.shape)
captions


torch.Size([128, 23])


128

## 1.1 将Caption编码, 用256维的embedding表示

In [57]:
# Represent every caption token with the decoder's embedding layer (256-d here)
embeddings = decoder.embed(captions)
print(embeddings.shape)
# The image feature with an extra time axis inserted at dim 1
print(features.unsqueeze(dim=1).shape)

torch.Size([128, 23, 256])
torch.Size([128, 1, 256])


In [58]:
# Concatenate along the time axis: the image feature becomes step 0, followed
# by the caption-token embeddings
embeddings = torch.cat([features.unsqueeze(1), embeddings], dim=1)
embeddings.shape

torch.Size([128, 24, 256])

In [69]:
# Pack with the original caption lengths. Because the image feature occupies
# step 0, each sequence keeps the feature plus all but its last caption token.
packed = pack_padded_sequence(input=embeddings, lengths=lengths, batch_first=True)
packed

PackedSequence(data=tensor([[ 0.3116, -0.4476, -0.2999,  ..., -1.0705,  0.7326,  1.0773],
        [ 0.0202,  0.0738,  0.6870,  ...,  0.8710, -0.6984,  1.4274],
        [ 0.0242, -0.3649,  1.1551,  ..., -1.2311,  0.7355, -1.1059],
        ...,
        [-0.3794,  0.4380,  0.8859,  ..., -0.8869, -1.2137, -0.2564],
        [ 1.3821,  0.4445, -0.5598,  ...,  0.6688, -0.3859, -0.2020],
        [ 1.9841,  1.4484,  1.4450,  ...,  1.1245,  0.8445, -1.0468]],
       grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([128, 128, 128, 128, 128, 128, 128, 128, 128, 127, 115,  94,  67,  44,
         31,  20,  11,   5,   2,   2,   1,   1,   1]), sorted_indices=None, unsorted_indices=None)

In [73]:
# torch.nn.utils.rnn.pad_packed_sequence(packed, batch_first = True)

In [100]:
# Run the packed sequence through the decoder LSTM, then map each packed
# hidden state (the PackedSequence `data` field) to word scores
hiddens, _ = decoder.lstm(packed)
outputs = decoder.linear(hiddens.data)

torch.Size([9956])

In [97]:
# Flat target token ids: the `data` field of the packed caption batch
targets = pack_padded_sequence(input=captions, lengths=lengths, batch_first=True).data
targets

tensor([ 1,  1,  1,  ..., 40, 19,  2])

In [95]:
# Sanity check (replaces a no-op `sum(lengths)` expression): outputs and
# targets were both produced from sequences packed with `lengths`, so each
# must have exactly sum(lengths) rows.
assert outputs.size(0) == targets.size(0) == sum(lengths)
loss = criterion(outputs, targets)
