-
Notifications
You must be signed in to change notification settings - Fork 55
/
generators.py
73 lines (61 loc) · 3.57 KB
/
generators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedModel, GenerationConfig
from transformers import Pipeline, pipeline
import torch
class GeneratorBase:
    """Common interface for the text generators in this module.

    Subclasses implement :meth:`generate`. Instances are also callable, and
    calling one forwards to :meth:`generate`.
    """

    def generate(self, query: str, parameters: dict) -> str:
        """Return the text generated for *query*.

        Args:
            query: Prompt text.
            parameters: Generation options; their meaning is defined by the
                concrete subclass.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def __call__(self, query: str, parameters: dict = None) -> str:
        """Forward to :meth:`generate`.

        A missing *parameters* (``None``) is replaced by an empty dict.
        Subclasses that expand ``**parameters`` would otherwise fail with a
        TypeError when called as ``generator(query)``.
        """
        return self.generate(query, {} if parameters is None else parameters)
class StarCoder(GeneratorBase):
    """Generator backed by a transformers ``text-generation`` pipeline.

    Args:
        pretrained: Model name or path. It is used both to build the pipeline
            and to load the checkpoint's ``GenerationConfig``.
        device: Forwarded to ``pipeline(device=...)``.
        device_map: Forwarded to ``pipeline(device_map=...)``.
    """

    def __init__(self, pretrained: str, device: str = None, device_map: str = None):
        self.pretrained: str = pretrained
        # The pipeline loads the model weights in bfloat16.
        self.pipe: Pipeline = pipeline(
            "text-generation", model=pretrained, torch_dtype=torch.bfloat16,
            device=device, device_map=device_map)
        # Baseline generation settings shipped with the checkpoint. Per-call
        # parameters are layered on top of these in generate().
        self.generation_config = GenerationConfig.from_pretrained(pretrained)
        # Use EOS as the pad token (NOTE(review): presumably because the
        # tokenizer defines no pad token — confirm).
        self.generation_config.pad_token_id = self.pipe.tokenizer.eos_token_id

    def generate(self, query: str, parameters: dict = None) -> str:
        """Generate text for *query*.

        Args:
            query: Prompt text.
            parameters: Overrides for the checkpoint's generation config.
                Keys given here win over the stored defaults. ``None`` means
                no overrides.

        Returns:
            The ``generated_text`` field of the first pipeline result.
            NOTE(review): the text-generation pipeline includes the prompt by
            default — confirm callers expect that.
        """
        # Merge into a fresh config so self.generation_config is never mutated.
        config: GenerationConfig = GenerationConfig.from_dict({
            **self.generation_config.to_dict(),
            **(parameters or {}),
        })
        json_response: dict = self.pipe(query, generation_config=config)[0]
        return json_response['generated_text']
class SantaCoder(GeneratorBase):
    """Generator that drives ``AutoModelForCausalLM.generate`` directly.

    Args:
        pretrained: Model name or path for both the model and the tokenizer.
            Both are loaded with ``trust_remote_code=True``.
        device: Device the model and the encoded prompt are moved to
            (default ``'cuda'``).
    """

    def __init__(self, pretrained: str, device: str = 'cuda'):
        self.pretrained: str = pretrained
        self.device: str = device
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(pretrained, trust_remote_code=True)
        self.model.to(device=self.device)
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
        # Baseline generation settings derived from the model config. Per-call
        # parameters are layered on top of these in generate().
        self.generation_config: GenerationConfig = GenerationConfig.from_model_config(self.model.config)
        # Use EOS as the pad token (NOTE(review): presumably because the
        # tokenizer defines no pad token — confirm).
        self.generation_config.pad_token_id = self.tokenizer.eos_token_id

    def generate(self, query: str, parameters: dict = None) -> str:
        """Generate text for *query*.

        Args:
            query: Prompt text.
            parameters: Overrides for the stored generation config. Keys given
                here win over the defaults. ``None`` means no overrides.

        Returns:
            The first returned sequence, decoded with special tokens removed.
            NOTE(review): for a causal LM this presumably includes the prompt
            tokens — confirm callers expect that.
        """
        input_ids: torch.Tensor = self.tokenizer.encode(query, return_tensors='pt').to(self.device)
        # Merge into a fresh config so self.generation_config is never mutated.
        config: GenerationConfig = GenerationConfig.from_dict({
            **self.generation_config.to_dict(),
            **(parameters or {}),
        })
        output_ids: torch.Tensor = self.model.generate(input_ids, generation_config=config)
        # Disable tokenizer space clean-up so whitespace is returned as decoded.
        return self.tokenizer.decode(
            output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
class ReplitCode(GeneratorBase):
    """Generator that calls ``model.generate`` with plain keyword arguments.

    Args:
        pretrained: Model name or path for both the model and the tokenizer.
            Both are loaded with ``trust_remote_code=True``.
        device: Device the model (cast to bfloat16) and the encoded prompt are
            moved to (default ``'cuda'``).
    """

    def __init__(self, pretrained: str, device: str = 'cuda'):
        self.pretrained: str = pretrained
        self.device: str = device
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(pretrained, trust_remote_code=True)
        self.model.to(device=self.device, dtype=torch.bfloat16)
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
        # Default keyword arguments for model.generate(). Callers may override
        # any of them per call.
        self.default_parameter: dict = dict(
            do_sample=True, top_p=0.95, top_k=4, pad_token_id=self.tokenizer.eos_token_id,
            temperature=0.2, num_return_sequences=1, eos_token_id=self.tokenizer.eos_token_id
        )

    def generate(self, query: str, parameters: dict = None) -> str:
        """Generate text for *query*.

        Args:
            query: Prompt text.
            parameters: Keyword arguments that override ``default_parameter``.
                ``None`` means no overrides. A ``'stop'`` entry is accepted
                but discarded.

        Returns:
            The first returned sequence, decoded with special tokens removed.
            NOTE(review): for a causal LM this presumably includes the prompt
            tokens — confirm callers expect that.
        """
        input_ids: torch.Tensor = self.tokenizer.encode(query, return_tensors='pt').to(self.device)
        params = {**self.default_parameter, **(parameters or {})}
        # Drop 'stop' before calling model.generate() (NOTE(review): presumably
        # because generate() does not accept it). Passing a default avoids the
        # KeyError that a bare pop() raises when the caller omitted 'stop',
        # including the parameters=None default. Stop sequences are therefore
        # not applied to the output.
        params.pop('stop', None)
        output_ids: torch.Tensor = self.model.generate(input_ids, **params)
        # Disable tokenizer space clean-up so whitespace is returned as decoded.
        return self.tokenizer.decode(
            output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)