In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Date    : Dec-29-21 15:45
# @Author  : Kan HUANG (kan.huang@connect.ust.hk)

import numpy as np
import torch
import torch.nn as nn

In [None]:
features = np.linspace(-1.5, 0.3, num=(N * D)).reshape(N, D)
captions = (np.arange(N * T) % V).reshape(N, T)

In [84]:
class CaptioningRNN(nn.Module):
    def __init__(self,  word_to_idx, input_dim=512, wordvec_dim=128,
                 hidden_dim=128, cell_type='rnn', dtype=torch.float32):
        super(CaptioningRNN, self).__init__()
        if cell_type not in {'rnn', 'lstm'}:
            raise ValueError('Invalid cell_type "%s"' % cell_type)

        self.cell_type = cell_type
        self.dtype = dtype
        self.word_to_idx = word_to_idx
        self.idx_to_word = {i: w for w, i in word_to_idx.items()}
        self.params = {}

        vocab_size = len(word_to_idx)

        self.embed = nn.Embedding(vocab_size, wordvec_dim)
        self.fc_proj = nn.Linear(input_dim, hidden_dim)
        
        dim_mul = {'lstm': 4, 'rnn': 1}[cell_type]
        if cell_type == "rnn":
            self.rnn = nn.RNN(wordvec_dim, hidden_dim, dim_mul)

        self.fc_vocab = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions):
        """forward
        Inputs:
        - features: Input image features, of shape (N, D)
        - captions: Ground-truth captions; an integer array of shape (N, T) 
        """
        captions_in = captions[:, :-1] # (N, T)
        captions_out = captions[:, 1:]

        # Initial hidden state
        h_prev = self.fc_proj(features) # (N, D) -> (N, H)
        h_prev = h_prev.unsqueeze(0) # (N, H) -> (1, N, H)
        # print(f"h_prev.shape: {h_prev.shape}")

        # Use a word embedding layer to transform the words in captions_in from indices to vectors, giving an array of shape (N, T, W).
        word_vectors = self.embed(captions_in) # (N, T) -> (N, T, W)

        # Must transpose first!
        word_vectors = word_vectors.transpose(1, 0) # (N, T, W) -> (T, N, W)
        (T, N, W) = word_vectors.shape

        # print(f"word_vectors.shape: {word_vectors.shape}")
        # print(f"h_prev.shape: {h_prev.shape}")

        # print(f"word_vectors.shape: {word_vectors.shape}")
        # process the sequence of input word vectors and produce hidden state vectors for all timesteps
        for i in range(T):
            # step once
            output, h_next = self.rnn(word_vectors[i].unsqueeze(0), h_prev)
            h_prev = h_next
        
        loss = None

        return loss

In [86]:
N, D, W, H = 10, 20, 30, 40
word_to_idx = {'<NULL>': 0, 'cat': 2, 'dog': 3}
V = len(word_to_idx)
T = 13 # max_length

batch_size = N
timesteps = T
input_dim = D
wordvec_dim = W
hidden_dim = H

model = CaptioningRNN(word_to_idx,
          input_dim=input_dim,
          wordvec_dim=wordvec_dim,
          hidden_dim=hidden_dim,
          cell_type='rnn',
          dtype=torch.float32)
np.random.seed(231)
word_to_idx = {'<NULL>': 0, 'cat': 2, 'dog': 3}
vocab_size = len(word_to_idx)

# captions: int
captions = torch.randint(vocab_size, size=(batch_size, timesteps))
features = torch.randn(batch_size, input_dim)

loss = model(features, captions)

h_prev.shape: torch.Size([1, 10, 40])
word_vectors.shape: torch.Size([12, 10, 30])
h_prev.shape: torch.Size([1, 10, 40])
word_vectors.shape: torch.Size([12, 10, 30])


TypeError: cannot unpack non-iterable NoneType object