In [None]:
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim

import os
import tensorflow as tf
from IPython.display import clear_output

In [None]:
%load_ext tensorboard

In [None]:

# Training hyper-parameters.
EPOCHS = 100
BATCH_SIZE = 32
LR = 2e-5
# Model size: embedding width, feed-forward hidden width, attention heads, decoder layers.
D_MODEL = 768
DFF = 2048
N_HEAD = 12
N_LAYER = 6
# Cadence (in epochs) of checkpointing and of sampling a reply, and the
# generation-length cap passed to model.answer.
SAVE_EVERY_EPOCHS = 5
INFER_EVERY_EPOCHS = 20
INFER_MAX_LEN = 50
LOG_DIR = './logs/'
OUTPUT_PATH = r'./out_weight'
# Fixed seed so weight initialisation and DataLoader shuffling are reproducible.
SEED = 42

torch.manual_seed(SEED)
np.random.seed(SEED)

# makedirs also creates missing parents; exist_ok keeps the cell re-runnable.
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [None]:
# The vocabulary file holds a JSON value that unpacks into (word2id, id2word).
# Presumably word2id maps token -> id and id2word is indexed by id; verify
# against the file. It is read as UTF-8 because the corpus is CJK text, and
# the file is closed deterministically by the `with` block.
with open('./dict_data.json', 'r', encoding='utf-8') as vocab_file:
    word2id, id2word = json.load(vocab_file)
vocab_size = len(word2id)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(vocab_size)
print(device)

In [None]:
def make_data(datas):
    """Convert raw dialogue lines into lists of token ids.

    Each line is stripped and split into single characters. Every tab
    character is replaced by the ``"<sep>"`` token, and a trailing ``"<sep>"``
    is appended. Lines that are empty after stripping are skipped: they would
    otherwise yield a one-token sample whose input and target (see
    BuildDataset) are both empty.

    Args:
        datas: iterable of strings, e.g. the result of ``readlines()``.

    Returns:
        list[list[int]]: one id sequence per non-blank line, in input order.

    Raises:
        KeyError: if a character (or ``"<sep>"``) is missing from ``word2id``.
    """
    sep = "<sep>"
    train_data_num = []
    for data in datas:
        data = data.strip()
        if not data:
            continue
        tokens = [ch if ch != '\t' else sep for ch in data] + [sep]
        train_data_num.append([word2id[tok] for tok in tokens])
    return train_data_num

class BuildDataset(Dataset):
    """Next-token-prediction dataset over pre-tokenised id sequences.

    Item ``idx`` is the pair ``(seq[:-1], seq[1:])``: the model input and the
    same sequence shifted left by one, used as the target.
    """

    def __init__(self, datas):
        # List of id sequences (e.g. the output of make_data).
        self.datas = datas

    def __getitem__(self, idx):
        seq = self.datas[idx]
        return seq[:-1], seq[1:]

    def __len__(self):
        return len(self.datas)

    def padding_batch(self, batch):
        """Collate function: right-pad every (input, target) pair with ``<pad>``.

        Inputs are padded to the longest input in the batch, and targets to
        the longest target. The lists inside ``batch`` are extended in place;
        when they come from ``__getitem__`` they are fresh slices. They are
        then stacked into two LongTensors of shape (batch, padded_len).
        """
        longest_inp = max(len(pair[0]) for pair in batch)
        longest_out = max(len(pair[1]) for pair in batch)

        pad_id = word2id["<pad>"]
        for inp_ids, out_ids in batch:
            inp_ids.extend([pad_id] * (longest_inp - len(inp_ids)))
            out_ids.extend([pad_id] * (longest_out - len(out_ids)))

        inputs = torch.LongTensor([pair[0] for pair in batch])
        targets = torch.LongTensor([pair[1] for pair in batch])
        return inputs, targets


In [None]:
def get_attn_pad_mask(seq_q, seq_k, pad_idx=0):
    """Boolean mask that hides padded *key* positions from every query.

    Args:
        seq_q: (batch, len_q) token ids on the query side; only its shape is used.
        seq_k: (batch, len_k) token ids on the key side.
        pad_idx: token id treated as padding. Defaults to 0, the value that was
            hard-coded before. NOTE(review): confirm it equals word2id["<pad>"].

    Returns:
        Bool tensor of shape (batch, len_q, len_k); True where the key is padding.
    """
    batch_size, len_q = seq_q.size(0), seq_q.size(1)
    len_k = seq_k.size(1)
    pad_mask = seq_k.eq(pad_idx).unsqueeze(1)  # (batch, 1, len_k)
    return pad_mask.expand(batch_size, len_q, len_k)

def get_attn_subsequence_mask(seq):
    """Causal (look-ahead) mask: query position i may not see key position j > i.

    Args:
        seq: (batch, seq_len) tensor; only its shape and device are used.

    Returns:
        float64 tensor of shape (batch, seq_len, seq_len) on ``seq``'s device,
        with 1.0 strictly above the diagonal and 0.0 elsewhere. The dtype is
        float64, the same as the earlier numpy-based version. Callers that add
        this to a bool mask and threshold with ``> 0`` therefore keep working.
    """
    attn_shape = (seq.size(0), seq.size(1), seq.size(1))
    # Build directly on the input's device rather than allocating on the host
    # with numpy and copying on every forward pass. torch.triu acts on the
    # last two dims, so the whole batch is masked at once.
    ones = torch.ones(attn_shape, dtype=torch.float64, device=seq.device)
    return torch.triu(ones, diagonal=1)


def ScaleDotProduct(q, k, v, attn_mask=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q, k, v: tensors whose last two dims are (length, features). ``k`` and
            ``v`` must share the same length; leading dims broadcast via matmul.
        attn_mask: optional bool tensor broadcastable to the score matrix;
            True entries are filled with -1e9 before the softmax.

    Returns:
        (out, attn_weights): the attended values and the softmax weights.
    """
    scale = np.sqrt(k.size(-1))
    # The scale is applied to the transposed keys before the matmul.
    scores = torch.matmul(q, k.transpose(-1, -2) / scale)
    if attn_mask is not None:
        scores.masked_fill_(attn_mask, -1e9)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights

class MultiHeadAttention(nn.Module):
    """Multi-head attention with a post-norm residual connection.

    Steps:
    - Project q, k and v with their own bias-free linear layers (w_q, w_k, w_v).
    - Split into ``n_heads`` heads of size ``d_model // n_heads``.
    - Run ``ScaleDotProduct`` on every head.
    - Merge the heads and apply ``out_linear``.
    - Return ``LayerNorm(out + q)``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``. Otherwise
            the merged heads could not be fed to ``out_linear``.
    """

    def __init__(self, d_model, n_heads):
        super(MultiHeadAttention, self).__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                "d_model ({}) must be divisible by n_heads ({})".format(d_model, n_heads))
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_heads = d_model // n_heads

        self.w_q = nn.Linear(self.d_model, self.d_heads * self.n_heads, bias=False)
        self.w_k = nn.Linear(self.d_model, self.d_heads * self.n_heads, bias=False)
        self.w_v = nn.Linear(self.d_model, self.d_heads * self.n_heads, bias=False)
        self.out_linear = nn.Linear(self.d_model, self.d_model, bias=False)
        self.layer_norm = nn.LayerNorm(self.d_model)

    def forward(self, q, k, v, attn_mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: (batch, len_q, d_model); also used as the residual.
            k, v: (batch, len_k, d_model).
            attn_mask: optional bool (batch, len_q, len_k); True = not attended.

        Returns:
            (out, attn_map): ``out`` is (batch, len_q, d_model) after LayerNorm;
            ``attn_map`` is (batch, n_heads, len_q, len_k).
        """
        residual, batch_size = q, q.size(0)

        # (batch, len, d_model) -> (batch, n_heads, len, d_heads); each input
        # goes through its own projection.
        Q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_heads).transpose(1, 2)
        K = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_heads).transpose(1, 2)
        V = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_heads).transpose(1, 2)
        if attn_mask is not None:
            # Share the same mask across heads without materialising copies.
            attn_mask = attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)

        context, attn_map = ScaleDotProduct(Q, K, V, attn_mask)
        # (batch, n_heads, len_q, d_heads) -> (batch, len_q, n_heads * d_heads)
        context = context.permute(0, 2, 1, 3)
        context = context.reshape(batch_size, -1, self.n_heads * self.d_heads)

        out = self.out_linear(context)
        return self.layer_norm(out + residual), attn_map

class FFN(nn.Module):
    """Position-wise feed-forward sublayer with a post-norm residual.

    Computes ``LayerNorm(x + W2 @ ReLU(W1 @ x))``. Both linear layers are
    bias-free, and the hidden width is ``dff``.
    """

    def __init__(self, d_model, dff):
        super().__init__()
        # Keep the Sequential layout (fc.0 / fc.2) so state_dict keys are unchanged.
        self.fc = nn.Sequential(
            nn.Linear(d_model, dff, bias=False),
            nn.ReLU(),
            nn.Linear(dff, d_model, bias=False),
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp):
        return self.layer_norm(inp + self.fc(inp))


In [None]:
class DecoderLayer(nn.Module):
    """One GPT block: masked self-attention followed by a feed-forward sublayer.

    MultiHeadAttention and FFN each apply their own residual + LayerNorm, so
    this layer only chains them.
    """

    def __init__(self, d_model, dff, n_head):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head)
        self.ffn = FFN(d_model, dff)

    def forward(self, x, attn_mask):
        """Return (hidden, attn_map), where attn_map comes from the self-attention."""
        hidden, attn_map = self.self_attn(x, x, x, attn_mask)
        return self.ffn(hidden), attn_map

class Decoder(nn.Module):
    """Embeds token ids, adds learned positions, and runs a stack of DecoderLayers.

    Args:
        d_model: embedding / hidden width.
        dff: feed-forward hidden width of each layer.
        n_head: attention heads per layer.
        n_layer: number of stacked DecoderLayers.
        max_pos: number of learned positional embeddings, i.e. the longest
            sequence ``forward`` accepts. Defaults to 300, the previous fixed size.
    """

    def __init__(self, d_model, dff, n_head, n_layer, max_pos=300):
        super(Decoder, self).__init__()

        self.emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_pos, d_model)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, dff, n_head) for _ in range(n_layer)
        ])

    def forward(self, inp):
        """Run the decoder stack.

        Args:
            inp: (batch, seq_len) LongTensor of token ids.

        Returns:
            (hidden, attn_maps): ``hidden`` has shape (batch, seq_len, d_model);
            ``attn_maps`` holds one attention map per layer, in layer order.

        Raises:
            ValueError: if seq_len exceeds the positional-embedding table.
        """
        seq_len = inp.size(1)
        max_pos = self.pos_emb.num_embeddings
        if seq_len > max_pos:
            # Fail with a clear message instead of an out-of-range index
            # inside the positional nn.Embedding lookup.
            raise ValueError(
                "sequence length {} exceeds the {} learned positions".format(seq_len, max_pos))
        # Positions are created on the input's device, not a global default.
        pos = torch.arange(seq_len, device=inp.device).unsqueeze(0).expand_as(inp)

        dec_inp = self.emb(inp) + self.pos_emb(pos)

        # True where a query must not attend: a padded key or a future position.
        attn_pad_mask = get_attn_pad_mask(inp, inp)
        attn_subsequence_mask = get_attn_subsequence_mask(inp)
        dec_total_mask = torch.gt((attn_pad_mask + attn_subsequence_mask), 0)

        attn_maps = []
        for layer in self.layers:
            dec_inp, dec_attn_map = layer(dec_inp, dec_total_mask)
            attn_maps.append(dec_attn_map)

        return dec_inp, attn_maps

class GPT(nn.Module):
    """Decoder-only language model: a Decoder stack plus a projection to vocabulary logits."""

    def __init__(self, d_model, dff, n_head, n_layer):
        super(GPT, self).__init__()
        self.decoder = Decoder(d_model, dff, n_head, n_layer)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, inp):
        """Return logits flattened to (batch * seq_len, vocab) and the per-layer attention maps.

        Args:
            inp: (batch, seq_len) LongTensor of token ids.
        """
        dec_out, dec_attn = self.decoder(inp)
        dec_logit = self.proj(dec_out)
        return dec_logit.view(-1, self.proj.out_features), dec_attn

    def greedy_decode(self, dec_inp, max_len):
        """Greedily extend ``dec_inp`` until ``<sep>`` is produced or ``max_len`` tokens were added.

        Args:
            dec_inp: (1, prompt_len) LongTensor of token ids on the model's
                device; a batch of one is assumed.
            max_len: maximum number of generated tokens. When the budget runs
                out, a ``<sep>`` is appended, so the result always ends with ``<sep>``.

        Returns:
            (1, prompt_len + n) LongTensor: the prompt followed by the new ids.

        Raises:
            ValueError: if the prompt is empty.
        """
        start_len = dec_inp.size(1)
        if start_len == 0:
            raise ValueError("greedy_decode needs a non-empty prompt")
        sep_id = word2id['<sep>']
        # Only the most recent positions fit in the learned positional table.
        max_ctx = self.decoder.pos_emb.num_embeddings

        # Decoding never needs gradients; skip building the autograd graph.
        with torch.no_grad():
            while dec_inp.size(1) - start_len < max_len:
                dec_out, _ = self.decoder(dec_inp[:, -max_ctx:])
                # Only the prediction at the last position is needed.
                next_word = int(self.proj(dec_out[:, -1]).argmax(dim=-1)[0])
                dec_inp = torch.cat([dec_inp, dec_inp.new_tensor([[next_word]])], dim=-1)
                if next_word == sep_id:
                    return dec_inp

        # Generation budget exhausted: terminate the turn explicitly.
        return torch.cat([dec_inp, dec_inp.new_tensor([[sep_id]])], dim=-1)

    def answer(self, sentence, max_len):
        """Generate a reply to ``sentence`` and return it as a string.

        Tokenisation:
        - ``sentence`` is split into single characters.
        - Tab characters become ``<sep>``.
        - Characters missing from ``word2id`` map to id 1.
          NOTE(review): id 1 is presumably the unknown token — confirm.

        The reply is the text between the last two ``<sep>`` tokens of the
        decoded sequence. If the sequence holds only one ``<sep>``, it is
        everything before that final ``<sep>``.

        Raises:
            ValueError: if ``sentence`` is empty.
        """
        if not sentence:
            raise ValueError("answer needs a non-empty sentence")
        sep_id = word2id['<sep>']
        prompt_ids = [sep_id if ch == '\t' else word2id.get(ch, 1) for ch in sentence]
        dec_inp = torch.tensor([prompt_ids], dtype=torch.long, device=self.proj.weight.device)

        out_ids = self.greedy_decode(dec_inp, max_len)[0].tolist()
        out = [id2word[i] for i in out_ids]

        sep_index = [i for i, tok in enumerate(out) if tok == '<sep>']
        start = sep_index[-2] + 1 if len(sep_index) >= 2 else 0
        # Drop the terminating <sep>.
        return "".join(out[start:-1])

In [None]:
# One training sample per line; tab characters become <sep> in make_data
# (presumably separating dialogue turns).
with open("./text_data.txt", 'r', encoding='utf-8') as f:
    datas = f.readlines()

dataset = BuildDataset(make_data(datas))
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                         collate_fn=dataset.padding_batch, shuffle=True)

model = GPT(D_MODEL, DFF, N_HEAD, N_LAYER).to(device)

# Targets are right-padded with <pad> by padding_batch. Those positions must
# not contribute to the loss or the gradient; otherwise short sequences teach
# the model to predict padding.
criterion = nn.CrossEntropyLoss(ignore_index=word2id["<pad>"]).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
summary_Writer = tf.summary.create_file_writer(LOG_DIR)
model.train()

In [None]:
import time
from tqdm.auto import tqdm

# Upper bound on optimizer steps per epoch, so one epoch stays bounded on a
# large corpus. The remaining batches of that epoch are skipped.
MAX_STEPS_PER_EPOCH = 3000

avg_losses = []
out_sentences = []
for epoch in range(1, EPOCHS + 1):
    start = time.time()
    losses = []

    for step, (inp_seq, tgt_seq) in enumerate(tqdm(data_loader, desc="epoch {}".format(epoch))):
        inp_seq, tgt_seq = inp_seq.to(device), tgt_seq.to(device)

        optimizer.zero_grad()
        # GPT.forward flattens logits to (batch * seq_len, vocab), so the
        # targets are flattened to match.
        logits, _ = model(inp_seq)
        loss = criterion(logits, tgt_seq.view(-1))
        losses.append(loss.item())

        loss.backward()
        # Clip the global gradient norm to 1 to keep updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        if step + 1 >= MAX_STEPS_PER_EPOCH:
            break

    avg_loss = np.mean(losses)
    avg_losses.append(avg_loss)
    # Keep the output short: only the latest epoch's messages stay visible.
    clear_output()

    with summary_Writer.as_default():
        tf.summary.scalar('loss', avg_loss, step=epoch)

    if epoch % SAVE_EVERY_EPOCHS == 0:
        print("curr loss: {:.3f}".format(avg_loss))
        torch.save(model.state_dict(),
                   os.path.join(OUTPUT_PATH,
                                "epoch_{}_loss_{:.3f}.dat".format(epoch, avg_loss)))

    if epoch % INFER_EVERY_EPOCHS == 0:
        inp_temp = "你好～～" + "\t"
        # Sample a reply without tracking gradients, then resume training mode.
        model.eval()
        with torch.no_grad():
            model_out = model.answer(inp_temp, INFER_MAX_LEN)
        model.train()
        out_sentences.append(model_out)
        print("chatbot:", model_out)

    end = time.time()
    print("curr epoch: {}, time: {:.2f} min".format(epoch, (end - start) / 60))

print('Finished Training')


In [None]:
import matplotlib.pyplot as plt

# avg_losses[0] belongs to epoch 1, so the x-axis uses 1-based epoch numbers.
fig, ax = plt.subplots()
ax.plot(range(1, len(avg_losses) + 1), avg_losses)
ax.set(title="GPT_loss", xlabel="epoch", ylabel="loss")
plt.show()

In [None]:
# Replay every sample reply collected by the training loop
# (one per INFER_EVERY_EPOCHS epochs).
for reply in out_sentences:
    print("chatbot:", reply)

In [None]:
# Interactive chat loop, kept inert by wrapping it in a string literal so that
# "Run All" does not block on input(). To use it:
#   - remove the surrounding quotes;
#   - set PATH to a checkpoint written by the training loop
#     (OUTPUT_PATH/epoch_*_loss_*.dat);
#   - run the cell, and type "q" to quit.
# The whole conversation is accumulated in `sentences`, with '\t' after every
# turn, and is re-fed to model.answer on each turn.
# NOTE(review): the history grows without bound; check that it stays within
# the decoder's positional-embedding size.
'''
PATH = "Put your model weight here!!!"
sentences = ''
model.load_state_dict(torch.load(PATH))
model.eval()
while True:
    print("input q to quit !!")
    inp = input("pleas input your sentence:")
    if inp == "q":
        break
    
    sentences = sentences + inp + '\t'
    with torch.no_grad():
        model_out = model.answer(sentences, INFER_MAX_LEN)
    
    print("inp: {}".format(sentences))
    print('model out:{}'.format(model_out))
    
    sentences += model_out
    sentences += '\t'
    
'''