In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from common.components import LayerNorm
from IPython.core.debugger import set_trace
import os

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# 长度固定512不合理，#处再想想其他处理

In [11]:
class CLNEncoder(nn.Module):
    def __init__(self, max_seq_len, voc_size, hidden_size, layers, fake_input):
        super().__init__()
        self.layers = layers
        self.max_seq_len = max_seq_len
        # word embedding
        self.word_emb = nn.Embedding(voc_size, hidden_size)
        # position fc
        self.context_fc = nn.Linear(max_seq_len, 1) # 
        # conditional layer normalization
        self.cln = LayerNorm(fake_input.size(), fake_input.size()[-1], conditional = True)
        # position embedding
        self.position_emb = torch.zeros([max_seq_len, hidden_size]).to(fake_input.device)
        for d in range(max_seq_len):
            for i in range(hidden_size):
                if i % 2 == 0:
                    self.position_emb[d][i] = math.sin(d / 10000**(i / hidden_size))
                else:
                    self.position_emb[d][i] = math.cos(d / 10000**((i - 1) / hidden_size))
                    
    def forward(self, input_ids):
#         # padding
#         # input_ids (batch_size, seq_len) -> (batch_size, max_seq_len)
#         input_ids = F.pad(input_ids, (0, self.max_seq_len - input_ids.size()[-1]))
        seq_len = input_ids.size()[1]
        assert input_ids.size()[1] <= self.max_seq_len
    
        # input_embbedings (batch_size, seq_len, hidden_size)
        input_embbedings = self.word_emb(input_ids)
        
        input_hiddens = input_embbedings
        for i in range(self.layers):
            # position embbedding
            input_hiddens = input_hiddens + self.position_emb[None, :seq_len, :].repeat(input_hiddens.size()[0], 1, 1)
            # context: (batch_size, hidden_size)
            context = torch.tanh(self.context_fc(input_hiddens.permute(0, 2, 1))).view(input_hiddens.size()[0], input_hiddens.size()[2])
            # repeat_context: (batch_size, seq_len, hidden_size)
            repeat_context = context[:,None,:].repeat(1, input_hiddens.size()[1], 1)
            input_hiddens = self.cln(input_hiddens, repeat_context)
        
        output = input_hiddens
        return output

In [12]:
batch_size = 32
max_seq_len = 512
hidden_size = 768
voc_size = 30000
layer_num = 12
fake_input = torch.zeros([batch_size, max_seq_len, hidden_size]).to(device)
encoder = CLNEncoder(max_seq_len, voc_size, hidden_size, layer_num, fake_input).to(device)

In [10]:
inputs = torch.randint(voc_size, (48, 128)).to(device)
output = encoder(inputs)
print(output.size())

torch.Size([48, 512, 768])
