cog: raw completion mode for the model
nomagick committed Jul 1, 2023
1 parent 53f0106 commit fe2c55c
Showing 6 changed files with 193 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
__pycache__/
.cog/
.ipynb_checkpoints/
4 changes: 4 additions & 0 deletions .gitmodules
@@ -0,0 +1,4 @@
[submodule "chatglm2-6b"]
	path = chatglm2-6b
	url = https://huggingface.co/THUDM/chatglm2-6b
	branch = main
1 change: 1 addition & 0 deletions chatglm2-6b
Submodule chatglm2-6b added at c57e89
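Note: the model weights arrive as a Hugging Face submodule, so a fresh clone would typically need git lfs install and git submodule update --init before the local ./chatglm2-6b path used below is populated. This is an assumption about the intended setup, not something stated in the commit.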
34 changes: 34 additions & 0 deletions cog.yaml
@@ -0,0 +1,34 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  # set to true if your model requires a GPU
  gpu: true
  cuda: "11.6.2"

  # a list of ubuntu apt packages to install
  # system_packages:
  #   - "libgl1-mesa-glx"
  #   - "libglib2.0-0"

  # python version in the form '3.8' or '3.8.12'
  python_version: "3.8"

  # a list of packages in the format <package-name>==<version>
  python_packages:
    - "protobuf"
    - "transformers==4.30.2"
    - "cpm_kernels"
    - "torch>=2.0"
    - "gradio"
    - "mdtex2html"
    - "sentencepiece"
    - "accelerate"

  # commands run after the environment is setup
  # run:
  #   - "echo env is ready!"
  #   - "echo another command if needed"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
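With this configuration in place, the image would typically be built and exercised with the stock Cog CLI (cog build, then cog predict -i prompt="..."); the exact flags depend on the installed Cog version and are an assumption here, not part of the commit.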
113 changes: 113 additions & 0 deletions patch_chat_glm.py
@@ -0,0 +1,113 @@
import types

import torch

from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


@torch.no_grad()
def completion(
    self,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 8192,
    num_beams=1,
    do_sample=True,
    top_p=0.8,
    temperature=0.8,
    logits_processor=None,
    **kwargs
):
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
        "top_p": top_p,
        "temperature": temperature,
        "logits_processor": logits_processor,
        **kwargs,
    }
    inputs = tokenizer([prompt], return_tensors="pt").to(self.device)
    outputs = self.generate(**inputs, **gen_kwargs)
    outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
    response = tokenizer.decode(outputs)
    return response


@torch.no_grad()
def stream_completion(
    self,
    tokenizer,
    prompt: str,
    past_key_values=None,
    max_new_tokens: int = 8192,
    do_sample=True,
    top_p=0.8,
    temperature=0.8,
    logits_processor=None,
    return_past_key_values=False,
    **kwargs
):
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "top_p": top_p,
        "temperature": temperature,
        "logits_processor": logits_processor,
        **kwargs,
    }
    if past_key_values is None and not return_past_key_values:
        inputs = tokenizer([prompt], return_tensors="pt").to(self.device)
    else:
        input_ids = tokenizer.encode("\n\n" + prompt, add_special_tokens=False)
        input_ids = input_ids[1:]
        inputs = tokenizer.batch_encode_plus(
            [(input_ids, None)], return_tensors="pt", add_special_tokens=False
        ).to(self.device)
    if past_key_values is not None:
        past_length = past_key_values[0][0].shape[0]
        inputs.position_ids += past_length
        attention_mask = inputs.attention_mask
        attention_mask = torch.cat(
            (attention_mask.new_ones(1, past_length), attention_mask), dim=1
        )
        inputs["attention_mask"] = attention_mask
    offset = 0
    for outputs in self.stream_generate(
        **inputs,
        past_key_values=past_key_values,
        return_past_key_values=return_past_key_values,
        **gen_kwargs
    ):
        if return_past_key_values:
            outputs, past_key_values = outputs
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
        response = tokenizer.decode(outputs)
        if response and response[-1] != "�":
            if return_past_key_values:
                yield response[offset:], past_key_values
            else:
                yield response[offset:]
            offset = len(response)


def patch(model):
    model.stream_completion = types.MethodType(stream_completion, model)
    model.completion = types.MethodType(completion, model)
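For reference, a minimal sketch of how the patched methods could be driven directly, outside the Cog runtime. It assumes the ./chatglm2-6b submodule is populated and a CUDA GPU is available; the snippet is illustrative and not part of this commit.

# Standalone usage sketch (assumes local ./chatglm2-6b weights and a CUDA GPU).
from transformers import AutoModel, AutoTokenizer

import patch_chat_glm

tokenizer = AutoTokenizer.from_pretrained(
    "./chatglm2-6b", trust_remote_code=True, local_files_only=True
)
model = AutoModel.from_pretrained(
    "./chatglm2-6b", trust_remote_code=True, local_files_only=True
).cuda().eval()
patch_chat_glm.patch(model)

prompt = "[Round 1]\n\n问:你好\n\n答:"

# Raw one-shot completion: the prompt is tokenized verbatim; no chat-history
# templating is applied by the helper.
print(model.completion(tokenizer, prompt, max_new_tokens=256))

# Streaming: each yield carries only the newly decoded suffix.
for delta in model.stream_completion(tokenizer, prompt):
    print(delta, end="", flush=True)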
38 changes: 38 additions & 0 deletions predict.py
@@ -0,0 +1,38 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

from cog import BasePredictor, Input, Path, ConcatenateIterator
from transformers import AutoModel, AutoTokenizer

import patch_chat_glm


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.tokenizer = AutoTokenizer.from_pretrained(
            "./chatglm2-6b", trust_remote_code=True, local_files_only=True
        )
        model = AutoModel.from_pretrained(
            "./chatglm2-6b", trust_remote_code=True, local_files_only=True
        ).cuda()
        patch_chat_glm.patch(model)
        self.model = model.eval()

    def predict(
        self,
        prompt: str = Input(
            description="Prompt for completion",
            default="[Round 1]\n\n问:请使用英文重复这段话:\"为了使模型生成最优输出,当使用 ChatGLM2-6B 时需要使用特定的输入格式,请按照示例格式组织输入。\"\n\n答:",
        ),
        max_tokens: int = Input(
            description="Max new tokens to generate", default=2048, ge=1, le=32768
        ),
        temperature: float = Input(description="Temperature", default=0.75, ge=0, le=5),
        top_p: float = Input(description="Top_p", default=0.8, ge=0, le=1),
    ) -> ConcatenateIterator[str]:
        """Run a single prediction on the model"""

        yield from self.model.stream_completion(
            self.tokenizer, prompt, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p
        )
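Because predict returns a ConcatenateIterator, Cog streams each yielded chunk and reports their concatenation as the final output. A hypothetical local smoke test (not part of this commit; assumes the weights and a GPU are present) could look like:

# Hypothetical harness for exercising the predictor outside Cog.
from predict import Predictor

predictor = Predictor()
predictor.setup()
for chunk in predictor.predict(
    prompt="[Round 1]\n\n问:你好\n\n答:",
    max_tokens=128,
    temperature=0.75,
    top_p=0.8,
):
    print(chunk, end="", flush=True)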
