In [8]:
import json
import torch
from tqdm import tqdm
from transformers import BertTokenizer, GPT2LMHeadModel

In [2]:
device = 'cuda' if(torch.cuda.is_available()) else 'cpu'
device

'cuda'

### 加载模型

In [3]:
model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")

### 数据集

In [6]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        """输入data格式:
            [{'event': '......', 'attitude': '......', 'suggestion': ...}, {}, ..., {}]
        """
        super().__init__()
        
        self.base_prompt = '对于"{}"这件事，大家都表现出"{}"的心态。对于此心态的出现，引导建议为：\n'
        self.data = []
        for d in data:
            prompt = self.base_prompt.format(d['event'], d['attitude'])
            suggestion = d['suggestion']
            self.data.append({'prompt': prompt, 'suggestion': suggestion})
        
        test_rate = 0.1
        self.train_data = self.data[:int(len(self.data) * (1-test_rate))]
        self.valid_data = self.data[int(len(self.data) * (1-test_rate)):]
        
        self.train = True
    
    def __len__(self):
        if(self.train):
            return len(self.train_data)
        else:
            return len(self.valid_data)
    
    def __getitem__(self, idx):
        if(self.train):
            return self.train_data[idx]
        else:
            return self.valid_data[idx]

data = json.load(open('./dataset/lead_opinions/suggestion.json', 'r', encoding='utf-8')) + json.load(open('./dataset/lead_opinions/total_chatgpt.json', 'r', encoding='utf-8'))
dataset = Dataset(data)

In [5]:
def collate_fn(data):
    text = [(i['prompt'] + i['suggestion']).replace('\n', '[SEP]') for i in data]
    labels = [i for i in text]
    
    text = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True, add_special_tokens=True, max_length=512, truncation=True)
    labels = tokenizer.batch_encode_plus(labels, return_tensors="pt", padding=True, return_attention_mask=True, add_special_tokens=True)
    
    data = {}
    data['input_ids'] = text['input_ids'][:,1:].to(device)
    data['attention_mask'] = text['attention_mask'][:,1:].to(device)
    data['labels'] = labels['input_ids'][:,1:].to(device)

    return data

loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     drop_last=True)

### 训练

In [None]:
def train(model, epoches, lr):
    lens = len(loader)
    model = model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr)
    losses = torch.zeros((epoches, lens))
    for i in range(epoches):
        with tqdm(total=lens, ncols=100) as bar:
            bar.set_description('训练进度-epoch: {}/{}'.format(i+1,epoches))
            for n,d in enumerate(loader):
                loss = model(**d)['loss']
                loss.backward()
                optim.step()
                optim.zero_grad()
                 
                losses[i,n] += loss.item()
                bar.update(1)
            bar.set_postfix(loss = '{:.4f}'.format(losses[i].mean().item()))
    model.cpu()
    return losses

losses = train(model, 10, 2e-5)
torch.save(model, './models/lead_opinions2.model')

### 使用

In [4]:
# 导入模型
model = torch.load('./models/lead_opinions.model')

In [9]:
model.eval().to(device)

event = '很喜欢宫崎骏的一段话'
attitude = '高兴'
message = dataset.base_prompt.format(event, attitude)
inputs = tokenizer.encode_plus(message, return_tensors='pt', add_special_tokens=True).to(device)

output = tokenizer.decode(model.generate(inputs['input_ids'], max_length=350, do_sample=True, num_beams=5)[0])
output = output.replace(' ', '').replace('1', '\n1').replace('2', '\n2').replace('3', '\n3').replace('4', '\n4').replace('5', '\n5').strip()
pure = output[:output.find('[PAD]')].replace('[CLS]', '').replace('[SEP]', '').replace('[UNK]', '')
print('API生成: ', pure)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


API生成:  对于"很喜欢宫崎骏的一段话"这件事，大家都表现出"高兴"的心态。对于此心态的出现，引导建议为：
1.鼓励分享：鼓励大家积极分享自己对宫崎骏的喜爱和欣赏，包括绘画、摄影等方面。
2.强调尊重：提醒大家要尊重不同的审美观点和文化背景，不要将自己的观点强加于他人。
3.提倡珍惜：引导大家珍惜这段喜欢宫崎骏的经历，并用心地品味其中的精神内涵。
4.引导理性思考：提醒大家要理性思考自己的行为举止是否符合社会道德规范。同时也要尊重他人的隐私和权益。
