In [1]:
import torch
from transformers.optimization import get_scheduler
from transformers.data.data_collator import default_data_collator
from datasets import load_dataset, load_from_disk, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Model, AdamW

### 加载编码器

In [2]:
tokenizer = AutoTokenizer.from_pretrained('gpt2')

print(tokenizer)

# 添加pad
tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})

# 编码试算
tokenizer.batch_encode_plus([
    'hide new secretions from the parental units',
    'contains no wit, only labored gags'
])

PreTrainedTokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_len=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'})


{'input_ids': [[24717, 649, 3200, 507, 422, 262, 21694, 4991], [3642, 1299, 645, 20868, 11, 691, 2248, 1850, 308, 3775]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

### 加载数据集

In [3]:
# 加载数据
dataset = load_dataset('imdb')

# 重新切分数据集(先将train, test, unsupervised合并, 再划分)
dataset = concatenate_datasets([dataset['train'], dataset['test'], dataset['unsupervised']])

dataset = dataset.train_test_split(test_size=0.01, seed=0)

# 采样, 数据量太大了跑不动
dataset['train'] = dataset['train'].shuffle(0).select(range(80000))
dataset['test'] = dataset['test'].shuffle(0).select(range(200))

Found cached dataset imdb (C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached split indices for dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-e7af3cac5f3bb93c.arrow and C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-477ff4d300ae4b97.arrow
Loading cached shuffled indices for dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-1cc79241861da5a7.arrow
Loading cached shuffled indices for dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-92ac5a1ec5c26bfa.arrow


### 数据集处理

##### 分词

In [4]:
def f(data, tokenizer):
    # 移除<br/>
    for i in range(len(data['text'])):
        data['text'][i] = data['text'][i].replace('<br /><br />', ' ')
    
    data = tokenizer.batch_encode_plus(data['text'])

    return data

dataset = dataset.map(f,
                      batched=True,
                      num_proc=4,
                      batch_size=1000,
                      remove_columns=['text', 'label'],
                      fn_kwargs={'tokenizer': tokenizer}
)

# 过滤掉太短的句子
def f(data):
    return [sum(i) >= 25 for i in data['attention_mask']]

dataset = dataset.filter(f, batched=True, num_proc=4, batch_size=1000)

         

Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-050b01d5730135f9_00000_of_00004.arrow


 

Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-050b01d5730135f9_00001_of_00004.arrow


 

Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-050b01d5730135f9_00002_of_00004.arrow


 

Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-050b01d5730135f9_00003_of_00004.arrow


 

Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-4da8f576c9de64f4_00000_of_00004.arrow


 

Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-4da8f576c9de64f4_00001_of_00004.arrow


 

Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-4da8f576c9de64f4_00002_of_00004.arrow


 

Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1\cache-4da8f576c9de64f4_00003_of_00004.arrow


In [5]:
def f(data):
    block_size = 512

    # 展平数据
    input_ids = []
    for i in data['input_ids']:
        input_ids.extend(i)

    # 切断数据(做成段(句子)(batch))
    data = {'input_ids': [], 'attention_mask': []}
    for i in range(len(input_ids) // block_size):
        block = input_ids[i*block_size: (i+1)*block_size]
        data['input_ids'].append(block)
        data['attention_mask'].append([1] * block_size) # 全部参与attention计算

    # 设置labels
    data['labels'] = data['input_ids'].copy()

    return data

dataset = dataset.map(f, batched=True, batch_size=1000,num_proc=4)
dataset

        

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 44863
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 107
    })
})

### 数据集加载器

In [6]:
loader = torch.utils.data.DataLoader(
    dataset = dataset['train'], 
    batch_size = 4, 
    collate_fn = default_data_collator,
    shuffle = True,
    drop_last = True
)

for i, data in enumerate(loader):
    break

len(loader), data

(11215,
 {'input_ids': tensor([[  340,   351,  6088,  ...,   966,   286,   852],
          [  262,  1388,  3350,  ...,  4556, 13181,    11],
          [  281,   555, 24194,  ...,   339,   468,  4753],
          [  649, 12835,   706,  ...,    11, 27577,  6735]]),
  'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]),
  'labels': tensor([[  340,   351,  6088,  ...,   966,   286,   852],
          [  262,  1388,  3350,  ...,  4556, 13181,    11],
          [  281,   555, 24194,  ...,   339,   468,  4753],
          [  649, 12835,   706,  ...,    11, 27577,  6735]])})

### 定义模型

In [7]:
# 定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = GPT2Model.from_pretrained('gpt2')
        self.fc = torch.nn.Linear(768, tokenizer.vocab_size, bias=False)

        # 加载预训练模型
        parameters = AutoModelForCausalLM.from_pretrained('gpt2')
        self.fc.load_state_dict(parameters.lm_head.state_dict())

        self.criterion =torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels):
        logits = self.pretrained(input_ids=input_ids, attention_mask=attention_mask)
        logits = logits.last_hidden_state
        logits = self.fc(logits)

        shift_logits = logits[:, :-1].flatten(end_dim=1)
        shift_labels = labels[:, 1:].flatten()

        loss = self.criterion(shift_logits, shift_labels)

        return {
            'loss': loss,
            'logits': logits
        }

### 模型试算

In [8]:
model = Model()

# 统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)   # numel : number of elements

with torch.no_grad():
    out = model(**data)
out['loss'], out['logits'].shape

16303.7184
torch.Size([4, 512])
torch.Size([2044])


(tensor(3.7133), torch.Size([4, 512, 50257]))

### 文本生成函数

In [None]:
def generate(text):

    def generate_loop(data):
        with torch.no_grad():
            out = model(**data)
        print(out['logits'].shape)
        
        # 取最后一个字
        # [5, b, 50257]
        out = out['logits']
        # [5, 50257]
        out = out[:, -1]

        # 第50大的值, 以此为分界线, 小于该值的全部赋值为负无穷
        # [5, 50257] -> [5, 50]
        topk_value = torch.topk(out, 50).values
        # [5, 50] -> [5] -> [5, 1]
        topk_value = topk_value[:, -1].unsqueeze(dim=1)

        # 赋值
        # [5, 50257]
        out = out.masked_fill(out < topk_value, -float('inf'))  # 类似与torch.where

        # 根据概率采样, 无放回, 所以不可能重复
        # [5, 50257] -> [5, 1]
        out = out.softmax(dim=1)
        out = out.multinomial(num_samples=1)    # 值就是采样权重(概率), replacement默认为False不放回, 返回该值的下标

        data['input_ids'] = torch.cat([data['input_ids'], out], dim=1)
        data['attention_mask'] = torch.ones_like(data['input_ids'])
        data['labels'] = data['input_ids'].clone()

        if(data['input_ids'].shape[1] >= 30):   # 大于max_length
            return data
        
        return generate_loop(data)

    # 重复5遍
    data = tokenizer.batch_encode_plus([text] * 5, return_tensors = 'pt')
    data['labels'] = data['input_ids'].clone()

    data = generate_loop(data)

    for i in range(5):
        print(i, tokenizer.decode(data['input_ids'][i]))

In [36]:
# generate('I love this')
data = tokenizer(['text'], return_tensors = 'pt')
data

{'input_ids': tensor([[5239]]), 'attention_mask': tensor([[1]])}

### 训练

In [None]:
def train():
    global model
    device = 'cpu' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr = 2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)
    
    model.train()
    for i, data in enumerate(loader):
        for k in data.keys():
            data[k] = data[k].to(device)

        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # 解决梯度爆炸

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()
        model.zero_grad()

        if(i % 50 == 0):
            labels = data['labels'][:, 1:]
            out = out['logits'].argmax(dim=2)[:, :-1]

            accuracy = (out == labels).sum().item() / labels.numel()

            lr = optimizer.state_dict()['param_groups'][0]['lr']

            print(i, loss.item(), lr, accuracy)

    model = model.to('cpu')

train()
torch.save(model, './models/en_gen.model')

### 使用

In [None]:
# model = torch.load('models/en_gen.model')
generate('I love this')