diff --git a/pythainlp/gpt/gpt_neo.py b/pythainlp/gpt/gpt_neo.py index 2695bc94b..6a7235e7f 100644 --- a/pythainlp/gpt/gpt_neo.py +++ b/pythainlp/gpt/gpt_neo.py @@ -12,19 +12,23 @@ class ListDataset(Dataset): - def __init__(self, txt_list: List[str], tokenizer: GPT2Tokenizer, max_length: int): + def __init__( + self, txt_list: List[str], tokenizer: GPT2Tokenizer, max_length: int + ): self.input_ids = [] self.attn_masks = [] self.labels = [] for txt in txt_list: encodings_dict = tokenizer( - '<|startoftext|>' + txt + '<|endoftext|>', + '<|start|>' + txt + '<|end|>', truncation=True, max_length=max_length, padding="max_length" ) self.input_ids.append(torch.tensor(encodings_dict['input_ids'])) - self.attn_masks.append(torch.tensor(encodings_dict['attention_mask'])) + self.attn_masks.append( + torch.tensor(encodings_dict['attention_mask']) + ) def __len__(self): return len(self.input_ids) @@ -34,7 +38,9 @@ def __getitem__(self, idx: int): class FewShot: - def __init__(self, model_dir: str, device: str = "cuda", size: str = "125M") -> None: + def __init__( + self, model_dir: str, device: str = "cuda", size: str = "125M" + ): """ :param str model_dir: path of model dir :param str device: device @@ -56,12 +62,14 @@ def init_model(self, size: str = "125M") -> None: self.pretrained = "EleutherAI/gpt-neo-"+str(size) self.tokenizer = GPT2Tokenizer.from_pretrained( self.pretrained, - bos_token='<|startoftext|>', - eos_token='<|endoftext|>', + bos_token='<|start|>', + eos_token='<|end|>', pad_token='<|pad|>' ) self.tokenizer.save_pretrained(self.model_dir) - self.model = GPTNeoForCausalLM.from_pretrained(self.pretrained).to(self.device) + self.model = GPTNeoForCausalLM.from_pretrained( + self.pretrained + ).to(self.device) self.model.resize_token_embeddings(len(self.tokenizer)) def load_model(self): @@ -71,8 +79,8 @@ def load_model(self): self.model_dir = self.model_dir self.tokenizer = GPT2Tokenizer.from_pretrained( self.model_dir, - bos_token='<|startoftext|>', - eos_token='<|endoftext|>', + bos_token='<|start|>', + eos_token='<|end|>', pad_token='<|pad|>' ) self.model = GPTNeoForCausalLM.from_pretrained( @@ -168,7 +176,7 @@ def gen( :rtype: List[str] """ self.generated = self.tokenizer( - '<|startoftext|>' + text, return_tensors="pt" + '<|start|>' + text, return_tensors="pt" ).input_ids.to(self.device) self.sample_outputs = self.model.generate( self.generated, @@ -179,4 +187,8 @@ def gen( temperature=temperature, num_return_sequences=num_return_sequences ) - return [self.tokenizer.decode(i, skip_special_tokens=True).replace('<|startoftext|>','') for i in self.sample_outputs] + return [ + self.tokenizer.decode( + i, skip_special_tokens=True + ).replace("<|start|>", "") for i in self.sample_outputs + ]