In [None]:
import json
import torch
from tqdm import tqdm
from transformers import BertTokenizer, GPT2LMHeadModel

In [None]:
device = 'cuda:3' if(torch.cuda.is_available()) else 'cpu'
device

### 加载模型

In [None]:
model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")

### 数据集

In [None]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        """输入data格式:
            [{'event': '......', 'attitude': '......', 'suggestion': ...}, {}, ..., {}]
        """
        super().__init__()
        
        self.base_prompt = '对于"{}"这件事，大家都表现出"{}"的心态。对于此心态的出现，引导建议为：\n'
        self.data = []
        for d in data:
            prompt = self.base_prompt.format(d['event'], d['attitude'])
            suggestion = d['suggestion']
            self.data.append({'prompt': prompt, 'suggestion': suggestion})
        
        test_rate = 0.1
        self.train_data = self.data[:int(len(self.data) * (1-test_rate))]
        self.valid_data = self.data[int(len(self.data) * (1-test_rate)):]
        
        self.train = True
    
    def __len__(self):
        if(self.train):
            return len(self.train_data)
        else:
            return len(self.valid_data)
    
    def __getitem__(self, idx):
        if(self.train):
            return self.train_data[idx]
        else:
            return self.valid_data[idx]

data = json.load(open('./dataset/lead_opinions/suggestion.json', 'r', encoding='utf-8')) + json.load(open('./dataset/lead_opinions/total_chatgpt.json', 'r', encoding='utf-8'))
dataset = Dataset(data)

In [None]:
def collate_fn(data):
    text = [(i['prompt'] + i['suggestion']).replace('\n', '[SEP]') for i in data]
    labels = [i for i in text]
    
    text = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True, add_special_tokens=True, max_length=512, truncation=True)
    labels = tokenizer.batch_encode_plus(labels, return_tensors="pt", padding=True, return_attention_mask=True, add_special_tokens=True)
    
    data = {}
    data['input_ids'] = text['input_ids'][:,1:].to(device)
    data['attention_mask'] = text['attention_mask'][:,1:].to(device)
    data['labels'] = labels['input_ids'][:,1:].to(device)

    return data

loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     drop_last=True)

### 训练

In [None]:
def train(model, epoches, lr):
    lens = len(loader)
    model = model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr)
    losses = torch.zeros((epoches, lens))
    for i in range(epoches):
        with tqdm(total=lens, ncols=100) as bar:
            bar.set_description('训练进度-epoch: {}/{}'.format(i+1,epoches))
            for n,d in enumerate(loader):
                loss = model(**d)['loss']
                loss.backward()
                optim.step()
                optim.zero_grad()
                 
                losses[i,n] += loss.item()
                bar.update(1)
            bar.set_postfix(loss = '{:.4f}'.format(losses[i].mean().item()))
    model.cpu()
    return losses

losses = train(model, 10, 2e-5)
torch.save(model, './models/lead_opinions.model')

### 使用

In [None]:
# 导入模型
model = torch.load('./models/lead_opinions.model')

In [13]:
model.eval().cpu()

event = '今天打了老板一顿'
attitude = '高兴'
message = dataset.base_prompt.format(event, attitude)
inputs = tokenizer.encode_plus(message, return_tensors='pt', add_special_tokens=False)

output = tokenizer.decode(model.generate(inputs['input_ids'], max_length=1024, do_sample=True)[0])
output = output.replace(' ', '').replace('[SEP]', '\n').strip()
pure = output[:output.find('[PAD]')]
print('API生成: ', pure)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


API生成:  对于"今天打了老板一顿"这件事，大家都表现出"高兴"的心态。对于此心态的出现，引导建议为：
1.提醒感恩和回馈：引导大家感恩职业领导者对自己和他人的照顾和关爱，要在工作中积极营造良好的工作氛围。
2.强调尊重和理解：提醒大家尊重老板和其他团队，理解老板对于自己和团队的付出，理解老板为了公司的利益而做出的努力和努力。
3.启发思考和思考：提醒大家思考和探讨自己的工作的重点和局限，不要过于盲目跟从或过度热情，需要从多方面着手思考和尝试。
4.引导正视挑战和困难：鼓励大家正视不足和挫折，反思和总结问题和改进过程，找到可逆和可持续的解决方案，并进一步提高自己的工作效率和水平。



In [None]:
tokenizer.decode([107])