In [1]:
import os
import pandas as pd
import torch
from utils import get_tokenizer
from vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader, Dataset


device = "cuda" if torch.cuda.is_available() else "cpu"
data_types = ["train", "test"]
wiki_datas = {}
for data_type in data_types:
    df = pd.read_csv(f"{data_type}.csv")
    print(len(df))
    df.columns = ["idx", "text"]
    res = []
    for i, item in df.iterrows():
        text = item.to_dict()["text"]
        res.append(text)
    wiki_datas[data_type] = res

def read_wikitext(data_type):
    return wiki_datas[data_type]

def yield_tokens(data_iter):
    for sentence in data_iter:
        # tokenizer 处理一个语句（包含多个单词）
        yield tokenizer(sentence)

tokenizer = get_tokenizer('basic_english')
train_iter = read_wikitext("train")
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<pad>", "<sos>", "<eos>"])
vocab.set_default_index(vocab['<pad>'])

print("词汇表大小",     len(vocab))
print("词汇示例(word2idx)", vocab['am'])

17514
2181
词汇表大小 65987
词汇示例(word2idx) 1707


In [2]:
len(vocab)

65987

In [3]:
class ChatDataset(Dataset):
    def __init__(self, file_path, tokenizer, vocab):
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.input_data, self.target_data = self.load_chat_data(file_path)

    def load_chat_data(self, file_path):
        lines = []
        with open(file_path, "r") as f:
            lines = f.readlines() 
        input_data, target_data = [], []
        for i, line in enumerate(lines):
            if line.startswith("User:"):
                prefix = "User: "
                tokens = self.tokenizer(line.strip()[len(prefix):])
                tokens = ["<sos>"] + tokens + ["<eos>"]
                idxs = [self.vocab[token] for token in tokens]
                input_data.append(torch.tensor(idxs, dtype=torch.long))
            elif line.startswith("AI:"):
                prefix = "AI: "
                tokens = self.tokenizer(line.strip()[len(prefix):])
                tokens = ["<sos>"] + tokens + ["<eos>"]
                idxs = [self.vocab[token] for token in tokens]
                target_data.append(torch.tensor(idxs, dtype=torch.long))

        return input_data, target_data

    def __len__(self):
        return len(self.input_data)
    
    def __getitem__(self, idx):
        return self.input_data[idx], self.target_data[idx]

    def get_token_strs(self, tokens):
        return [self.vocab.get_itos()[token] for token in tokens]
        
    def item(self, idx):
        inp, tgt = self.input_data[idx], self.target_data[idx]
        
        inp_strs = self.get_token_strs(inp)
        tgt_strs = self.get_token_strs(tgt)
        inp_s = " ".join(inp_strs)
        tgt_s = " ".join(tgt_strs)
        return inp_s, tgt_s

In [4]:
file_path = "chat.txt"
chat_dataset = ChatDataset(file_path, tokenizer, vocab)

for i in range(3):
    input_s, target_s = chat_dataset.item(i)
    input_sample, target_sample = chat_dataset[i]
    print(f"{i+1}个case")
    print("input sentence:", input_s)
    print("input token:", input_sample)
    print("input len", len(input_sample))
    print("target sentence:", target_s)
    print("target data:", target_sample)
    print("target len", len(target_sample))
    print("=" * 30)

1个case
input sentence: <sos> hi , how are you ? <eos>
input token: tensor([   1, 9635,    4,  412,   35,  178,  853,    2])
input len 8
target sentence: <sos> i am doing well , thank you . how about you ? <eos>
target data: tensor([    1,    65,  1707,  1616,   120,     4, 14003,   178,     5,   412,
           73,   178,   853,     2])
target len 14
2个case
input sentence: <sos> i am good , thanks for asking . what can you do ? <eos>
input token: tensor([   1,   65, 1707,  416,    4, 6372,   18, 4148,    5,  185,  112,  178,
         283,  853,    2])
input len 15
target sentence: <sos> i am an ai language model . i can help you answer questions . <eos>
target data: tensor([   1,   65, 1707,   31, 2051,  840, 1681,    5,   65,  112,  634,  178,
        5949, 4186,    5,    2])
target len 16
3个case
input sentence: <sos> what is the weather like today ? <eos>
input token: tensor([   1,  185,   24,    3, 1504,  139,  802,  853,    2])
input len 9
target sentence: <sos> please check a weat

In [5]:
def pad_sequence(sequences, padding_value=0, length=None):
    """
    填充序列，目的sequences token序列长度相同
    """
    max_length = max(len(seq) for seq in sequences) if length is None else length
    # 全零张量
    result = torch.full((len(sequences), max_length), padding_value, dtype=torch.long)

    for i, seq in enumerate(sequences):
        end = len(seq)
        result[i, :end] = seq[:end]
    return result

def collate_fn(batch):
    """
    对batch数据进行预处理，让其src、tgt长度一致
    """
    # batch: [(src1, tgt1), (src2, tgt2), ...]
    sources, targets = zip(*batch)

    tmps = []
    tmps.extend(sources)
    tmps.extend(targets)
    max_length = max([len(s) for s in tmps])
    # print("max_length", max_length)
    pad_val = vocab["<pad>"]

    sources = pad_sequence(sources, padding_value=pad_val, length=max_length)
    targets = pad_sequence(targets, padding_value=pad_val, length=max_length)
    return sources, targets

In [6]:
batch_size = 2
chat_dataloader = DataLoader(chat_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [7]:
from gpt_model import GPT
voc_size = len(vocab)
max_seq_len = 256
n_layers = 6

model = GPT(voc_size, max_seq_len, n_layers)
model.load_state_dict(torch.load("wikitext_best.pth", map_location=torch.device("cpu")))
print(model)

GPT(
  (decoder): Decoder(
    (src_emb): Embedding(65987, 512)
    (pos_emb): Embedding(256, 512)
    (layers): ModuleList(
      (0-5): 6 x DecoderLayer(
        (self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True)
          (W_K): Linear(in_features=512, out_features=512, bias=True)
          (W_V): Linear(in_features=512, out_features=512, bias=True)
          (linear): Linear(in_features=512, out_features=512, bias=True)
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (feed_forward): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,))
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
 

In [11]:
import torch.nn as nn
import torch.optim as optim

learning_rate = 1e-4

loss_fn = nn.CrossEntropyLoss() # 损失函数

def freeze_layers(model, n):
    params_to_update = [] # 参数
    for name, param in model.named_parameters():
        # print(name) # debug 看下参数名
        # try:
        #     if int(name.split(".")[2]) >= n:
        #         params_to_update.append(param)
        # except:
        #     pass
        params_to_update.append(param)
    return params_to_update

params_to_update = freeze_layers(model, n=2) # 冻结前两层
# print(len(params_to_update))
optimizer = optim.Adam(params_to_update, lr=learning_rate)

min_loss = float("inf")
save_path = "light_chatgpt_best.pth"

for epoch in range(200): # 开始训练
    for batch_idx, (input_batch, target_batch) in enumerate(chat_dataloader):
        optimizer.zero_grad()
        # print(input_batch.shape)
        # print(target_batch.shape)
        # print(input_batch)
        # print(target_batch)
        input_batch, target_batch = input_batch.to(device), target_batch.to(device)
        outputs = model(input_batch)
        # print(outputs.view(-1, len(vocab)).shape)
        # print(target_batch.view(-1).shape)
        loss = loss_fn(outputs.view(-1, len(vocab)), target_batch.view(-1))
        single_loss = loss.item() / len(input_batch)
        if single_loss < min_loss:
            min_loss = single_loss
            torch.save(model.state_dict(), save_path)
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f"Epoch: {epoch + 1:04d}, cost = {loss:6f}")

Epoch: 0020, cost = 0.263370
Epoch: 0040, cost = 0.020510
Epoch: 0060, cost = 0.003986
Epoch: 0080, cost = 0.557733
Epoch: 0100, cost = 0.090689
Epoch: 0120, cost = 0.006321
Epoch: 0140, cost = 0.071676
Epoch: 0160, cost = 0.024505
Epoch: 0180, cost = 0.002932
Epoch: 0200, cost = 0.024535


In [12]:
model.load_state_dict(torch.load(save_path))

<All keys matched successfully>

In [14]:
# 集束搜索
def print_candidate(candidate, prefix=""):
    special_tokens = {'<pad>', '<eos>', '<bos>', '<unk>'}
    s = " ".join([vocab.get_itos()[token] for token in candidate if vocab.get_itos()[token] not in special_tokens])
    print(prefix + f"序列: {s}")

def generate_text_beam_search(model, input_str, max_len=20, beam_width=5, debug=False):
    # model.eval()
    input_tokens = [vocab[token] for token in input_str.split()]
    # 初始化候选列表
    candidates = [(input_tokens, 0.0)]
    if debug:
        print(len(input_tokens))
    if debug:
        print_candidate(candidates[0][0], prefix="输入")
    with torch.no_grad():
        final_results = []
        for i in range(max_len): # 最多max_len个token
            new_candidates = []
            for candidate, candidate_score in candidates:
                inputs = torch.LongTensor(candidate).unsqueeze(0).to(device)
                # inputs: [1, seq_len]
                outputs = model(inputs)
                # outputs: [1, seq_len, vocab_size]
                logits = outputs[:, -1, :] # 只关心最后一步的数据
                # logits [1, vocab_size]
                scores, next_tokens = torch.topk(logits, beam_width, dim=-1)
                # scores: [1, beam_width]
                # next_tokens: [1, beam_width]
                for score, next_token in zip(scores.squeeze(), next_tokens.squeeze()):
                    new_candidate = candidate + [next_token.item()]
                    new_score = candidate_score + score.item()
                    if next_token.item() == vocab['<eos>']:
                        final_results.append((new_candidate, new_score))
                    else:
                        new_candidates.append((new_candidate, new_score))
            # print(f"第{i+1}次预测, 共有 {len(new_candidates)} 个候选 {len(final_results)}个结果集")
            # 从新生成的候选中选择最好的 beam_width 个
            candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            if debug:
                print_candidate(candidates[0][0])
            # print(f"最佳候选序列的token: {[vocab.get_itos()[token] for token in best_candidate]}")
        # 将过程中的遇到<eos>的结果数据也放到候选中
        candidates.extend(final_results)
        best_candidate, _ = sorted(candidates, key=lambda x: x[1], reverse=True)[0]

        special_tokens = {'<pad>', '<eos>', '<bos>', '<unk>'}
        best_candidate_strs = [vocab.get_itos()[token] for token in best_candidate if vocab.get_itos()[token] not in special_tokens]
        
        if debug:
            print(len(best_candidate))
        return ' '.join(best_candidate_strs)


input_strs = ["what is the weather like today ?", "hi , how are you ?"]
for inp_s in input_strs:
    gen_text = generate_text_beam_search(model, inp_s)
    print("input_str", inp_s)
    print("gen_text", gen_text)
    print("=" * 30)

input_str what is the weather like today ?
gen_text what is the weather like today ? application weather weather weather weather weather weather weather weather weather weather weather weather weather weather weather weather weather weather weather
input_str hi , how are you ?
gen_text hi , how are you ? thank you , am am am am am am am am am am am am am am am am am
