Skip to content

Commit

Permalink
Update gpt_neo.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
wannaphong committed Jun 16, 2021
1 parent 1616449 commit 3299c7d
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions pythainlp/gpt/gpt_neo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,19 +12,23 @@


class ListDataset(Dataset):
def __init__(self, txt_list: List[str], tokenizer: GPT2Tokenizer, max_length: int):
def __init__(
self, txt_list: List[str], tokenizer: GPT2Tokenizer, max_length: int
):
self.input_ids = []
self.attn_masks = []
self.labels = []
for txt in txt_list:
encodings_dict = tokenizer(
'<|startoftext|>' + txt + '<|endoftext|>',
'<|start|>' + txt + '<|end|>',
truncation=True,
max_length=max_length,
padding="max_length"
)
self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
self.attn_masks.append(
torch.tensor(encodings_dict['attention_mask'])
)

def __len__(self):
return len(self.input_ids)
Expand All @@ -34,7 +38,9 @@ def __getitem__(self, idx: int):


class FewShot:
def __init__(self, model_dir: str, device: str = "cuda", size: str = "125M") -> None:
def __init__(
self, model_dir: str, device: str = "cuda", size: str = "125M"
):
"""
:param str model_dir: path of model dir
:param str device: device
Expand All @@ -56,12 +62,14 @@ def init_model(self, size: str = "125M") -> None:
self.pretrained = "EleutherAI/gpt-neo-"+str(size)
self.tokenizer = GPT2Tokenizer.from_pretrained(
self.pretrained,
bos_token='<|startoftext|>',
eos_token='<|endoftext|>',
bos_token='<|start|>',
eos_token='<|end|>',
pad_token='<|pad|>'
)
self.tokenizer.save_pretrained(self.model_dir)
self.model = GPTNeoForCausalLM.from_pretrained(self.pretrained).to(self.device)
self.model = GPTNeoForCausalLM.from_pretrained(
self.pretrained
).to(self.device)
self.model.resize_token_embeddings(len(self.tokenizer))

def load_model(self):
Expand All @@ -71,8 +79,8 @@ def load_model(self):
self.model_dir = self.model_dir
self.tokenizer = GPT2Tokenizer.from_pretrained(
self.model_dir,
bos_token='<|startoftext|>',
eos_token='<|endoftext|>',
bos_token='<|start|>',
eos_token='<|end|>',
pad_token='<|pad|>'
)
self.model = GPTNeoForCausalLM.from_pretrained(
Expand Down Expand Up @@ -168,7 +176,7 @@ def gen(
:rtype: List[str]
"""
self.generated = self.tokenizer(
'<|startoftext|>' + text, return_tensors="pt"
'<|start|>' + text, return_tensors="pt"
).input_ids.to(self.device)
self.sample_outputs = self.model.generate(
self.generated,
Expand All @@ -179,4 +187,8 @@ def gen(
temperature=temperature,
num_return_sequences=num_return_sequences
)
return [self.tokenizer.decode(i, skip_special_tokens=True).replace('<|startoftext|>','') for i in self.sample_outputs]
return [
self.tokenizer.decode(
i, skip_special_tokens=True
).replace("<|start|>", "") for i in self.sample_outputs
]

0 comments on commit 3299c7d

Please sign in to comment.