# Load pretrained weights to T2I Text_Encoder

In [7]:
import torch
import torch.nn as nn
from data_loader import Text_Dataset
from text_encoder import Text_Encoder as LM
import trainer

# Text Dataset to get the number of whole words
hr_dataset = Text_Dataset(data_dir='data/bird/',
                            split='train',
                            words_num=15,
                            print_shape=False)
                            
# hyper parameters same to T2I code
n_word = hr_dataset.n_word
print('n_word : ', n_word)
embedding_dim = 1024
hidden_size = 1024
n_layers = 1
dropout = .5
max_length = 15

Load filenames from: data/bird//filenames/train/filenames.pickle (8855)
Load filenames from: data/bird//filenames/val/filenames.pickle (2933)
Load from:  data/bird/captions.pickle
n_word :  5450


In [2]:
# T2I's Text_Encoder Class
class Text_Encoder(nn.Module):
    def __init__(self, vocab_size, rnn_type):
        super(Text_Encoder, self).__init__()

        self.rnn_type = rnn_type
        self.vocap_size = vocab_size
        self.embedding_dim = 1024 #embedding size
        self.drop_rate = 0.5 #dropout rate
        self.hidden_dim = 1024 # word dim
        self.num_layers = 1
        self.bidirectional = True # bidirectional option

        if self.bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

        self.hidden_dim = self.hidden_dim // self.num_directions
        self.embedding_layer = nn.Embedding(num_embeddings=self.vocap_size, embedding_dim=self.embedding_dim)
        self.dropout = nn.Dropout(self.drop_rate)

        if self.rnn_type == 'LSTM':
            self.rnn = nn.LSTM(input_size = self.embedding_dim,
                               hidden_size=self.hidden_dim,
                               num_layers=self.num_layers,
                               batch_first=True,
                               dropout=self.drop_rate,
                               bidirectional=self.bidirectional)
        else:
            self.rnn = nn.GRU(input_size = self.embedding_dim,
                              hidden_size=self.hidden_dim,
                              num_layers=self.num_layers,
                              batch_first=True,
                              dropout=self.drop_rate,
                              bidirectional=self.bidirectional)

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (weight.new(self.num_layers * self.num_directions, batch_size, self.hidden_dim).zero_(),
                    weight.new(self.num_layers * self.num_directions, batch_size, self.hidden_dim).zero_())
        else:
            return weight.new(self.num_layers * self.num_directions, batch_size, self.hidden_dim).zero_()

    def forward(self, captions, cap_lens, hidden, mask=None):
        embed = self.embedding_layer(captions)
        embed = self.dropout(embed)

        cap_lens = cap_lens.data.tolist()
        embed = pack_padded_sequence(embed, cap_lens, batch_first=True)

        out, hidden = self.rnn(embed, hidden)
        out = pad_packed_sequence(out, batch_first=True)[0]
        words_emb = out.transpose(1, 2)

        if self.rnn_type == 'LSTM':
            sentence_emb = hidden[0].transpose(0, 1).contiguous()
        else:
            sentence_emb = hidden.transpose(0, 1).contiguous()
        sentence_emb = sentence_emb.view(-1, self.hidden_dim * self.num_directions)

        return words_emb, sentence_emb

In [15]:
model = Text_Encoder(vocab_size=5450, rnn_type='LSTM')

In [24]:
load_model = torch.load('basic.14.0.06-1.06.0.15-1.16.pt')['model']

In [25]:
load_model

OrderedDict([('embedding_layer.weight',
              tensor([[ 2.2491e-01,  1.2387e+00, -1.6484e+00,  ..., -1.8124e-01,
                        8.2356e-01, -5.2841e-01],
                      [ 6.9791e-02,  6.4899e-03,  6.0361e-01,  ...,  1.0983e+00,
                        1.2417e+00,  9.0354e-04],
                      [ 1.1542e+00,  1.7226e-02,  3.0986e-01,  ..., -5.2637e-01,
                        9.2248e-01, -5.5523e-01],
                      ...,
                      [ 2.7472e-01, -1.7736e+00, -8.7598e-01,  ...,  1.5634e+00,
                       -2.0526e+00,  7.4461e-01],
                      [-9.0740e-01, -2.1600e+00, -5.4583e-01,  ..., -1.7708e+00,
                        1.0711e+00,  4.1583e-01],
                      [ 2.8366e-01, -2.9290e-02, -8.6192e-01,  ...,  9.7635e-01,
                       -1.4004e+00, -7.4893e-01]], device='cuda:0')),
             ('out.weight',
              tensor([[ 0.0157,  0.0038, -0.0051,  ...,  0.0176, -0.0093,  0.0066],
               

In [26]:
model.load_state_dict(load_model, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['out.weight', 'out.bias'])