In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence


class EncoderCNN(nn.Module):
    """Image encoder: a ResNet-152 feature extractor followed by a trainable projection.

    ResNet-152's final classification layer (2048 -> 1000 for ImageNet) is removed.
    The pooled backbone features are projected to ``embed_size`` by a linear layer
    and then batch-normalised.
    """

    def __init__(self, embed_size, pretrained=True):
        """Load ResNet-152 and replace its top fc layer.

        :param embed_size: size of the output feature vector. It must match the
            decoder's word-embedding size, because the decoder concatenates the image
            feature with the caption embeddings.
        :param pretrained: if True (the default, and the original behaviour), load the
            ImageNet weights (``IMAGENET1K_V1``). If False, build a randomly initialised
            backbone, which avoids a weight download (useful offline / in tests).
        """
        super(EncoderCNN, self).__init__()
        resnet = self._build_resnet152(pretrained)
        modules = list(resnet.children())[:-1]      # drop the final fc layer, keep avgpool
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    @staticmethod
    def _build_resnet152(pretrained):
        """Construct ResNet-152, using the ``weights=`` API when torchvision provides it."""
        weights_enum = getattr(models, "ResNet152_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.resnet152(pretrained=pretrained)
        # torchvision >= 0.13 deprecates ``pretrained=``; IMAGENET1K_V1 is exactly what
        # ``pretrained=True`` used to load.
        return models.resnet152(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, images):
        """Extract feature vectors from input images.

        :param images: image batch for the ResNet backbone. Presumably (B, 3, H, W),
            normalised the way ImageNet models expect; TODO confirm against the loader.
        :return: tensor of shape (B, embed_size).
        """
        # The backbone acts as a fixed feature extractor: no gradients flow into it.
        # NOTE(review): no_grad does not stop its BatchNorm layers from updating their
        # running statistics when the module is in train mode.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)   # flatten pooled (B, C, 1, 1) -> (B, C)
        features = self.bn(self.linear(features))
        return features


In [5]:
# Image encoder producing 256-d features (must match the decoder's embed_size).
# NOTE: with pretrained weights, the first construction may download them from torchvision.
encode = EncoderCNN(embed_size=256)


In [6]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed as the first time step, followed by the embedded caption
    tokens. A linear layer maps each LSTM hidden state to vocabulary logits.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers.

        :param embed_size: word-embedding size. It must equal the image-feature size,
            because features and embeddings are concatenated along the time axis.
        :param hidden_size: LSTM hidden-state size.
        :param vocab_size: number of tokens in the vocabulary.
        :param num_layers: number of stacked LSTM layers.
        :param max_seq_length: default number of tokens generated by ``sample``.
        """
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)   # one embed_size vector per token id
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

        # Attribute name ("seg", sic) kept unchanged so external code reading it keeps working.
        self.max_seg_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Decode image feature vectors and score captions with teacher forcing.

        :param features: image features, shape (B, embed_size).
        :param captions: padded token ids, shape (B, T).
        :param lengths: per-sample number of time steps to keep (list or 1-D tensor of
            length B, on any device). The input sequence is the image feature followed
            by the T caption embeddings (T + 1 steps). Each sample is truncated to
            ``lengths[i]`` steps. Lengths must be sorted in decreasing order (the
            ``pack_padded_sequence`` default).
        :return: logits of shape (sum(lengths), vocab_size), in packed (time-major) order.
        """
        embeddings = self.embed(captions)                                # (B, T, embed_size)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)   # (B, T + 1, embed_size)

        # pack_padded_sequence rejects lengths that live on a GPU; accept any device.
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()
        # Packing keeps the LSTM from running over the padding of shorter captions.
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
        hiddens, _ = self.lstm(packed)
        outputs = self.linear(hiddens.data)   # one row of logits per non-padded step
        return outputs

    def sample(self, features, states=None, max_length=None):
        """Generate captions for given image features using greedy search.

        :param features: image features, shape (B, embed_size).
        :param states: optional initial LSTM (h, c) state.
        :param max_length: number of tokens to generate; defaults to ``self.max_seg_length``.
        :return: token ids of shape (B, max_length). Generation always runs the full
            number of steps; it does not stop early at any end token.
        """
        if max_length is None:
            max_length = self.max_seg_length
        if max_length <= 0:
            # torch.stack([]) would raise; return an empty id matrix instead.
            return torch.zeros(features.size(0), 0, dtype=torch.long, device=features.device)

        sampled_ids = []
        inputs = features.unsqueeze(1)                           # (B, 1, embed_size)
        for _ in range(max_length):
            hiddens, states = self.lstm(inputs, states)          # (B, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))            # (B, vocab_size)
            _, predicted = outputs.max(1)                        # (B,)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)          # (B, 1, embed_size)
        return torch.stack(sampled_ids, 1)                       # (B, max_length)

In [8]:
# embed_size must equal the encoder's output size (256): the decoder concatenates the
# image feature with the caption embeddings along the time axis.
# NOTE(review): vocab_size is hard-coded; presumably len(vocab) of the training
# vocabulary — TODO derive it from the vocabulary object instead.
decoder = DecoderRNN(
    vocab_size=9956,
    embed_size=256,
    hidden_size=512,
    num_layers=1,
    max_seq_length=20,   # the class default, spelled out explicitly
)

NameError: name 'captions' is not defined