Skip to content

Commit

Permalink
Assistant token counter
Browse files: browse the repository at this point in the history
  • Loading branch information
artitw committed Sep 23, 2023
1 parent 412c52c commit f0ca8cc
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion LICENSE.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
The MIT License (MIT)

Copyright (c) Artit Wangperawong
Copyright (c) artitw

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

setuptools.setup(
name="text2text",
version="1.3.0",
author="Artit Wangperawong",
version="1.3.1",
author="artitw",
author_email="artitw@gmail.com",
description="Text2Text: Crosslingual NLP/G toolkit",
long_description=long_description,
Expand Down
12 changes: 11 additions & 1 deletion text2text/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, **kwargs):
quantize_config=None
)

def transform(self, input_lines, src_lang='en', retriever=None, **kwargs):
def preprocess(self, input_lines, src_lang='en', retriever=None, **kwargs):
input_lines = t2t.Transformer.transform(self, input_lines, src_lang, **kwargs)
df = pd.DataFrame({"input_line": input_lines})
if src_lang != 'en':
Expand All @@ -31,6 +31,16 @@ def transform(self, input_lines, src_lang='en', retriever=None, **kwargs):
df["knowledge"] = retriever.retrieve(df["input_line"].str.lower().tolist(), k=k)
df["input_line"] = df["knowledge"].apply(' '.join) + " - " + df["input_line"]
df["input_line"] = "USER: " + df["input_line"] + "\nASSISTANT:"
return df

def num_tokens(self, input_lines, src_lang='en', retriever=None, **kwargs):
    """Return the tokenized length of the prompt(s) built from ``input_lines``.

    The prompts are built by ``self.preprocess`` with the same arguments
    that ``transform`` forwards to it, so the count reflects the text that
    would actually be sent to the model (including retrieved knowledge
    when a ``retriever`` is supplied).

    Args:
        input_lines: Input text(s), passed through to ``preprocess``.
        src_lang: Source language code, passed through to ``preprocess``.
        retriever: Optional retriever, passed through to ``preprocess``.
            Defaults to ``None``, which matches the previous behavior.
        **kwargs: Extra options passed through to ``preprocess``.

    Returns:
        ``len()`` of the first row of the tokenizer's ``input_ids``.
    """
    prompts = self.preprocess(input_lines, src_lang, retriever, **kwargs)["input_line"].tolist()
    tokenizer = type(self).tokenizer
    encoded = tokenizer(prompts, return_tensors="pt", padding=True)
    # NOTE(review): with padding=True the rows are presumably padded to a
    # common length, so row 0 would give the longest prompt in the batch --
    # confirm against the tokenizer documentation.
    return len(encoded.input_ids[0])

def transform(self, input_lines, src_lang='en', retriever=None, **kwargs):
df = self.preprocess(input_lines, src_lang, retriever, **kwargs)
temperature = kwargs.get('temperature', 0.7)
top_p = kwargs.get('top_p', 0.95)
top_k = kwargs.get('top_k', 0)
Expand Down

0 comments on commit f0ca8cc

Please sign in to comment.