In [1]:
import torch
from torch import nn
import numpy as np


device = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda:0 device


## LoadData

In [2]:
import torchtext 

train_data = list(torchtext.datasets.IMDB(split='train',root=r'L:\Datasets'))[12500:12600]
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
def yield_tokens(data):
    for (_,text) in data:
        yield tokenizer(text)

vocab = torchtext.vocab.build_vocab_from_iterator(yield_tokens(train_data),specials=['<unk>','<sos>','<eos>'],min_freq=3) ## '<pad>',
vocab.set_default_index(vocab['<unk>'])

In [3]:
def collate_batch_noLable(data_batch):
    text_lst = []
    for _, _text in data_batch:
        tk_text = vocab(['<sos>'] + tokenizer(_text) + ['<eos>'])
        text_lst.append(torch.tensor(tk_text,dtype=torch.int64))
    text_lst = torch.nn.utils.rnn.pad_sequence(text_lst, padding_value=float(vocab['<eos>']) )  ## pad 0 to equal length
    return text_lst.to(device)

test_dl_noLable = torch.utils.data.DataLoader(train_data, batch_size=20, shuffle=True, collate_fn = collate_batch_noLable) 
## [seq_len, batch_size]

## Pre-train Embedder
主要是为了后续 Retrieve from Embedding Vector；此处只是个玩具，建议使用pre-train的word2vec等模型

In [4]:
class Emb2Class(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)
    def forward(self,inputs):
        emb = self.emb(inputs)
        return self.fc(emb)#.argmax(dim=2)

embed_dim = 300
vocab_size = len(vocab)
Emb2Class_model = Emb2Class(vocab_size, embed_dim).to(device)
Emb2Class_model

Emb2Class(
  (emb): Embedding(1110, 300)
  (fc): Linear(in_features=300, out_features=1110, bias=True)
)

In [5]:
ce = nn.CrossEntropyLoss()
def Emb2Class_lossfn(pred,real):
    loss = 0
    for i in range(real.shape[0]):
        loss += ce(pred[i],real[i])
    return loss/real.shape[0]

In [6]:
Emb2Class_optimizer = torch.optim.SGD(Emb2Class_model.parameters(), lr=1e-3)

def Emb2Class_train(dataloader, model, loss_fn, optimizer):
    lossSum = 0
    model.train()                              ### set training mode
    for inputs in dataloader:
        pred = model(inputs)
        loss = loss_fn(pred,inputs)
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print("=".format(epoch+1),end='')
    print("last batch loss:{}".format(loss))

In [7]:
for epoch in range(20):
    print("Epoch:{}".format(epoch+1),end='\t')
    Emb2Class_train(test_dl_noLable, Emb2Class_model, Emb2Class_lossfn, Emb2Class_optimizer)

Epoch:1	=====last batch loss:6.195950508117676
Epoch:2	=====last batch loss:5.4163336753845215
Epoch:3	=====last batch loss:5.043501853942871
Epoch:4	=====last batch loss:4.313214302062988
Epoch:5	=====last batch loss:3.695521593093872
Epoch:6	=====last batch loss:3.0953471660614014
Epoch:7	=====last batch loss:2.9020485877990723
Epoch:8	=====last batch loss:3.0827300548553467
Epoch:9	=====last batch loss:3.065711736679077
Epoch:10	=====last batch loss:2.2469937801361084
Epoch:11	=====last batch loss:2.782947540283203
Epoch:12	=====last batch loss:2.455803632736206
Epoch:13	=====last batch loss:3.3548364639282227
Epoch:14	=====last batch loss:2.3363473415374756
Epoch:15	=====last batch loss:2.288115978240967
Epoch:16	=====last batch loss:3.062293529510498
Epoch:17	=====last batch loss:2.3599071502685547
Epoch:18	=====last batch loss:2.2705986499786377
Epoch:19	=====last batch loss:2.7228190898895264
Epoch:20	=====last batch loss:2.1954903602600098


In [8]:
pre_param_emb = Emb2Class_model.emb.state_dict()

## RNN Encoder-Decoder

```
Encoder:
ht_0 --> RNN_Cell --> ht_1 --> RNN_Cell --> ht_2 --> RNN_Cell --> .... --> RNN_Cell --> ht_end
             ^                    ^                      ^                    ^
         word1_batch             word2                  word3                word_end


Decoder:
ht_0 --> RNN_Cell --> ht_1 --> RNN_Cell --> ht_2 --> RNN_Cell --> .... --> RNN_Cell --> ht_stop <eos>
          ^            |       ^              |       ^
<SOS>   __|            o_1   __|              o_2   __|       也可以每步将 ht_0 + o_t or [o_1,...o_t] 作为输入


GRU/Transformer 同理类似
```

### 参考
https://zhuanlan.zhihu.com/p/80866196

LSTM: https://curow.github.io/blog/LSTM-Encoder-Decoder/

In [9]:
# tt = nn.RNNCell(8,6)
# tt( torch.zeros((1,8)) ,torch.zeros((1,6)) ).shape   ##==>torch.Size([1, 6])

In [10]:
class RNN_encoder(nn.Module):
    def __init__(self, embed_dim, hidden_unitE):
        super().__init__()
        self.rnn_cell = nn.RNNCell(embed_dim,hidden_unitE)
        self.hidden_unitE = hidden_unitE
    def forward(self,inputs):
        htE = torch.zeros((inputs.shape[1],self.hidden_unitE)).to(device)  ## (batch_size,hidden_unitE)
        for word in inputs:
            # if len(torch.nonzero(word-vocab['eos']))== 0:      ## when all words == vocab['eos']
            #     break
            htE = self.rnn_cell(word,htE)
        return htE

class RNN_decoder(nn.Module):
    def __init__(self, embed_dim, hidden_unitD):
        super().__init__()
        self.rnn_cell = nn.RNNCell(embed_dim, hidden_unitD)
        self.fc = nn.Linear(hidden_unitD, embed_dim)         ## output word's embedding vector
    def forward(self, inputs, htD):
        out_lst = []
        batch_size = htD.shape[0]
        for word in inputs:
            htD = self.rnn_cell(word,htD)           ## htD:(batch_size,hidden_unitD)
            out = self.fc(htD)   
            out_lst.append(out)
        return torch.stack(out_lst),htD

In [11]:
class RNN_AE_Net(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_unitE, hidden_unitD):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embed_dim)           ## 可以先预训练一下？
        self.rnnEncoder = RNN_encoder(embed_dim, hidden_unitE)
        self.fc_E2D = nn.Linear(hidden_unitE, hidden_unitD)
        self.rnnDecoder = RNN_decoder(embed_dim, hidden_unitD)
    def forward(self,inputs):
        emb = self.emb(inputs)
        x = self.rnnEncoder(emb)
        x = self.fc_E2D(x)
        x,htD = self.rnnDecoder(emb,x)
        return x.to(device),htD.to(device)


embed_dim = 300
hidden_unit = 200
vocab_size = len(vocab)

model = RNN_AE_Net(vocab_size, embed_dim, hidden_unit, hidden_unit).to(device)
model

RNN_AE_Net(
  (emb): Embedding(1110, 300)
  (rnnEncoder): RNN_encoder(
    (rnn_cell): RNNCell(300, 200)
  )
  (fc_E2D): Linear(in_features=200, out_features=200, bias=True)
  (rnnDecoder): RNN_decoder(
    (rnn_cell): RNNCell(300, 200)
    (fc): Linear(in_features=200, out_features=300, bias=True)
  )
)

In [12]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model.emb.load_state_dict(pre_param_emb,strict=False)
model.emb.requires_grad = False

mse =  nn.MSELoss()
def loss_fn(pred,real):
    loss = 0
    for i in range(real.shape[0]):
        for j in range(real.shape[1]):
            loss += mse(pred[i,j,:],real[i,j,:])           ## loss for each word
    return loss/(real.shape[0]*real.shape[1])

def train(dataloader, model, loss_fn, optimizer):
    lossSum = 0
    model.train()                              ### set training mode
    for inputs in dataloader:
        pred,htD = model(inputs)
        real = model.emb(inputs)               ### torch.Size([seqlen, batch, embedsize])  otherwise .transpose(0, 1)
        loss = loss_fn(pred,real)
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print("=".format(epoch+1),end='')
    print("last batch loss:{}".format(loss))

In [13]:
for epoch in range(20):
    print("Epoch:{}".format(epoch+1),end='\t')
    train(test_dl_noLable, model, loss_fn, optimizer)

Epoch:1	=====last batch loss:1.1063904762268066
Epoch:2	=====last batch loss:1.1015305519104004
Epoch:3	=====last batch loss:1.0962250232696533
Epoch:4	=====last batch loss:1.0912203788757324
Epoch:5	=====last batch loss:1.0891518592834473
Epoch:6	=====last batch loss:1.086581826210022
Epoch:7	=====last batch loss:1.0821523666381836
Epoch:8	=====last batch loss:1.078304409980774
Epoch:9	=====last batch loss:1.0785313844680786
Epoch:10	=====last batch loss:1.070932388305664
Epoch:11	=====last batch loss:1.0637710094451904
Epoch:12	=====last batch loss:1.0724480152130127
Epoch:13	=====last batch loss:1.0683114528656006
Epoch:14	=====last batch loss:1.0558257102966309
Epoch:15	=====last batch loss:1.0469310283660889
Epoch:16	=====last batch loss:1.0455999374389648
Epoch:17	=====last batch loss:1.0410704612731934
Epoch:18	=====last batch loss:1.0318052768707275
Epoch:19	=====last batch loss:1.0438741445541382
Epoch:20	=====last batch loss:1.040648102760315


## Retrieve from Embedding Vector

?? 似乎完全没有用

In [14]:
for inputs in test_dl_noLable:
    break

In [15]:
emb_out,_ = model(inputs)
pred_out = Emb2Class_model.fc(emb_out).argmax(dim=2)

In [17]:
str = ""
for ii in inputs[:,0].cpu().numpy():
    str += vocab.lookup_token(ii)
    str += " "
str

'<sos> i liked this movie a lot . it really intrigued me how deanna and alicia <unk> friends over such a tragedy . alicia was just a <unk> soul and deanna was so happy just to see someone after being shot . my only <unk> was that in the beginning it was kind of slow and it took <unk> to get to the <unk> of things . other than that it was great . <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos

In [19]:
str = ""
for ii in pred_out[:,0].cpu().numpy():
    str += vocab.lookup_token(ii)
    str += " "
str

'von ones what society equally favorite answers <eos> incredibly hidden horror knew entertainment live storyline describe interest imdb soon plays superb apart <eos> plays choice twists lars record lighting storyline live lighting fight tricks scene walter bed wave world investigating sides <eos> after main was lighting , lost tells much incredibly choice helps busy fantastic storyline incredibly camera recommended walter doll best believe recommended stop town <eos> open knew company incredibly john script <eos> sorry often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often often o

In [20]:
ii

612

In [21]:
vocab['<eos>']

2

In [22]:
vocab.lookup_token(vocab['<eos>'])

'<eos>'