# @Time : 2020/12/14
# @Author : Yuanhang Zhou
# @Email : sdzyh002@gmail.com

# UPDATE
# @Time : 2021/1/7
# @Author : Xiaolei Wang
# @email : wxl1999@foxmail.com

r"""
GPT2
====
References:
    Radford, Alec, et al. `"Language Models are Unsupervised Multitask Learners."`_.

.. _`"Language Models are Unsupervised Multitask Learners."`:
   https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf

"""

import os

import torch
from torch.nn import CrossEntropyLoss
from transformers import GPT2LMHeadModel

from crslab.config import PRETRAIN_PATH
from crslab.data import dataset_language_map
from crslab.model.base import BaseModel
from crslab.model.pretrained_models import resources


class GPT2Model(BaseModel):
    """

    Attributes:
        context_truncate: An integer indicating the maximum length of the dialogue context.
        response_truncate: An integer indicating the maximum length of the dialogue response.
        pad_id: An integer indicating the id of the padding token.

    """

    def __init__(self, opt, device, vocab, side_data):
        """

        Args:
            opt (dict): A dictionary recording the hyper parameters.
            device (torch.device): A variable indicating which device to place the data and model.
            vocab (dict): A dictionary recording the vocabulary information.
            side_data (dict): A dictionary recording the side data.

        """
        self.context_truncate = opt['context_truncate']
        self.response_truncate = opt['response_truncate']
        self.pad_id = vocab['pad']

        language = dataset_language_map[opt['dataset']]
        resource = resources['gpt2'][language]
        dpath = os.path.join(PRETRAIN_PATH, "gpt2", language)
        super(GPT2Model, self).__init__(opt, device, dpath, resource)

    def build_model(self):
        """Build the pretrained GPT-2 language model and the generation loss."""
        self.model = GPT2LMHeadModel.from_pretrained(self.dpath)
        # padding positions are excluded from the loss via ignore_index
        self.loss = CrossEntropyLoss(ignore_index=self.pad_id)

    def forward(self, batch, mode):
        _, _, input_ids, context, _, _, y = batch

        if mode != 'test':
            # lm_logits shape = (bs, seq_len, vocab_size)
            lm_logits = self.model(input_ids).logits

            # the logit at position t predicts the token at position t + 1, so the
            # last (response_truncate - 1) logits are scored against the response
            # tokens from index 1 to response_truncate
            loss = self.calculate_loss(
                lm_logits[:, -self.response_truncate:-1, :],
                input_ids[:, -self.response_truncate + 1:])

            pred = torch.max(lm_logits, dim=2)[1]  # (bs, seq_len)
            pred = pred[:, -self.response_truncate:]

            return loss, pred
        else:
            return self.generate(context)

    def generate(self, context):
        """Greedy decoding: generate the response one token at a time, reusing past_key_values.

        Args:
            context: torch.tensor, shape=(bs, context_truncate)

        Returns:
            generated_response: torch.tensor, shape=(bs, response_truncate - 1)
        """
        generated_response = []
        former_hidden_state = None
        context = context[..., -self.response_truncate + 1:]

        for i in range(self.response_truncate - 1):
            outputs = self.model(context, former_hidden_state)  # logits: (bs, seq_len, vocab_size)
            last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values
            next_token_logits = last_hidden_state[:, -1, :]  # (bs, vocab_size)
            preds = next_token_logits.argmax(dim=-1).long()  # (bs)

            context = preds.unsqueeze(1)  # only the new token is fed back; the history is carried in past_key_values
            generated_response.append(preds)

        generated_response = torch.stack(generated_response).T

        return generated_response

    def calculate_loss(self, logit, labels):
        """

        Args:
            logit: torch.FloatTensor, shape=(bs, response_truncate - 1, vocab_size)
            labels: torch.LongTensor, shape=(bs, response_truncate - 1)
        """
        loss = self.loss(logit.reshape(-1, logit.size(-1)), labels.reshape(-1))
        return loss
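
    # Note: unlike ``generate`` above, ``generate_bs`` below does not reuse
    # past_key_values; at every step each beam candidate is concatenated with
    # the original context and re-encoded from scratch.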
    def generate_bs(self, context, beam=4):
        context = context[..., -self.response_truncate + 1:]
        context_former = context
        batch_size = context.shape[0]
        # one (token_list, cumulative_probability) pair per candidate, per example
        sequences = [[[list(), 1.0]]] * batch_size

        for i in range(self.response_truncate - 1):
            if sequences != [[[list(), 1.0]]] * batch_size:
                context = []
                for i in range(batch_size):
                    for cand in sequences[i]:
                        # since past_key_values is not kept, each candidate is
                        # concatenated with the original context and re-encoded
                        text = torch.cat(
                            (context_former[i], torch.tensor(cand[0]).to(self.device)))
                        context.append(text)
                context = torch.stack(context)

            with torch.no_grad():
                outputs = self.model(context)
                last_hidden_state, state = outputs.logits, outputs.past_key_values
                next_token_logits = last_hidden_state[:, -1, :]
                next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

            topk = torch.topk(next_token_probs, beam, dim=-1)
            probs = topk.values.reshape([batch_size, -1, beam])  # (bs, candidate, beam)
            preds = topk.indices.reshape([batch_size, -1, beam])  # (bs, candidate, beam)

            for j in range(batch_size):
                all_candidates = []
                for n in range(len(sequences[j])):
                    for k in range(beam):
                        seq = sequences[j][n][0]
                        prob = sequences[j][n][1]
                        seq_tmp = seq.copy()
                        seq_tmp.append(preds[j][n][k])
                        candidate = [seq_tmp, prob * probs[j][n][k]]
                        all_candidates.append(candidate)
                # keep the beam highest-probability candidates per example
                ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
                sequences[j] = ordered[:beam]

        # return the best candidate of each example
        res = []
        for i in range(batch_size):
            res.append(torch.stack(sequences[i][0][0]))
        res = torch.stack(res)

        return res
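

# ---------------------------------------------------------------------------
# Usage sketch (not part of CRSLab): a minimal, self-contained illustration of
# the same incremental greedy decoding loop used in ``GPT2Model.generate``,
# applied to an off-the-shelf Hugging Face GPT-2. The checkpoint name, prompt
# and step count below are illustrative assumptions, not taken from the CRSLab
# configuration.
if __name__ == '__main__':
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    lm = GPT2LMHeadModel.from_pretrained('gpt2').eval()

    input_ids = tokenizer('The movie I watched last night was', return_tensors='pt').input_ids
    past = None
    generated = []
    with torch.no_grad():
        for _ in range(20):
            outputs = lm(input_ids, past_key_values=past)
            past = outputs.past_key_values
            next_token = outputs.logits[:, -1, :].argmax(dim=-1)  # greedy pick, shape (bs,)
            generated.append(next_token)
            input_ids = next_token.unsqueeze(1)  # feed back only the new token
    print(tokenizer.decode(torch.stack(generated, dim=1)[0]))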