In [None]:
!pip install transformers
!pip install sentencepiece
!pip install nn_pruning

In [None]:
from transformers import AutoModelForCausalLM

# Hugging Face Hub id of the pretrained Japanese causal LM (also read by the tokenizer cell).
model_name = "rinna/japanese-gpt-1b"
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
from transformers import T5Tokenizer

# SentencePiece-based tokenizer for the same checkpoint as the model cell above.
tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
import torch

# Dynamic int8 quantization: weights of the listed layer types are stored as int8 and
# activations are quantized on the fly at inference time.
# NOTE(review): the quantized replacements are not nn.Linear instances, so the later
# prune_transform cell (which matches nn.Linear) finds nothing to prune once this cell
# has run. Prune before quantizing if both steps are wanted.
model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)

In [None]:
import warnings

from torch import nn
import torch.nn.utils.prune as prune

# Default fraction of weights (smallest L1 magnitude) zeroed in each pruned layer.
PRUNE_RATE = 0.2


def prune_transform(model: nn.Module, amount=None, module_types: tuple = (nn.Linear,)) -> nn.Module:
    """Permanently zero the smallest-magnitude weights of matching layers.

    For every submodule that is an instance of ``module_types``, applies
    ``prune.l1_unstructured`` to its ``weight`` and immediately makes the result
    permanent with ``prune.remove`` (no ``weight_orig`` / ``weight_mask`` remain).
    The model is modified in place and also returned for convenience.

    Args:
        model: Module to prune.
        amount: Fraction (float in [0, 1]) or absolute count (int) of weights to
            prune per layer, forwarded to ``prune.l1_unstructured``. ``None``
            (default) uses the module-level ``PRUNE_RATE`` at call time.
        module_types: Layer class(es) to prune. Defaults to ``(nn.Linear,)``.
            Models whose projections use other layer classes need them listed here.

    Returns:
        ``model`` itself.

    Warns:
        UserWarning: if no submodule matched ``module_types``. This happens e.g. after
            ``torch.quantization.quantize_dynamic``, whose quantized replacements
            are not ``nn.Linear`` instances.
    """
    rate = PRUNE_RATE if amount is None else amount
    pruned_count = 0
    for module in model.modules():
        if isinstance(module, module_types):
            prune.l1_unstructured(module, name="weight", amount=rate)
            # Fold the mask into the weight so the model keeps a plain `weight` parameter.
            prune.remove(module, "weight")
            pruned_count += 1
    if pruned_count == 0:
        # Surface the no-op instead of silently returning an unpruned model.
        warnings.warn(
            f"prune_transform: no modules of type {module_types!r} found; nothing was pruned. "
            "If the model was dynamically quantized, prune before quantizing.",
            stacklevel=2,
        )
    return model

In [None]:
# Prune the model (prune_transform modifies it in place and returns it).
model = prune_transform(model=model)

In [None]:
import time
class ChatBot(torch.nn.Module):
    """Console role-play chatbot built on a causal language model.

    The whole conversation is kept as one growing prompt of the form
    ``...人間:「<user>」<name>:「<reply>」人間:「`` and the model is asked to
    continue it after every user turn.

    Args:
        model: Optional object providing ``generate(...)`` and ``device`` (e.g. a
            Hugging Face causal LM). ``None`` (default) uses the notebook-global
            ``model``, looked up at call time.
        tokenizer: Optional tokenizer providing ``encode``/``decode`` and
            ``pad_token_id``/``bos_token_id``/``eos_token_id``/``unk_token_id``.
            ``None`` (default) uses the notebook-global ``tokenizer``.
    """

    def __init__(self, model=None, tokenizer=None):
        super().__init__()
        self._lm = model
        self._tok = tokenizer

    def _resolve(self):
        """Return ``(model, tokenizer)``, falling back to the notebook globals when not injected."""
        lm = self._lm if self._lm is not None else model
        tok = self._tok if self._tok is not None else tokenizer
        return lm, tok

    @staticmethod
    def _bad_words_ids(tok):
        """Token-id sequences generation must never emit: the unknown token and "..."."""
        banned = []
        if tok.unk_token_id is not None:
            banned.append([tok.unk_token_id])
        # generate() expects token ids, not raw strings, so "..." has to be encoded first.
        # NOTE(review): which pieces "..." encodes to depends on the tokenizer; confirm it bans what is intended.
        ellipsis_ids = tok.encode("...", add_special_tokens=False)
        if ellipsis_ids:
            banned.append(list(ellipsis_ids))
        return banned

    @staticmethod
    def _build_prompt(name, hobby, work):
        """Build the persona + example-dialogue prompt; it ends inside the user's first quote."""
        name_text = f"あなたは{name}で、名前は{name}といいます。"
        hobby_text = f"{name}の趣味は{hobby}で、休日は{hobby}をして過ごしています。"
        work_text = f"{name}の職業は{work}で、{work}として生活しています。"
        setting_text_1 = f"{name}:「今日はいい天気ですわね。」"
        # Closing 」 keeps every example quote balanced so the model sees well-formed dialogue.
        setting_text_2 = f"{name}:「わたくしは{name}ですわ！」"
        return (
            name_text + hobby_text + work_text + setting_text_1 + setting_text_2
            + f"以下は人間と{name}の会話です。人間:「こんにちは!」{name}:「よろしくお願いしますわ！」人間:「"
        )

    @staticmethod
    def _extract_reply(generated):
        """Return ``generated`` up to and including its first 」, appending one if absent."""
        end = generated.find("」")
        if end == -1:
            # e.g. generation hit the token limit mid-sentence; close the quote so the transcript stays balanced.
            return generated + "」"
        return generated[: end + 1]

    # Text-generation function: takes the prompt text and the max / min number of new tokens.
    def generate(self, text, max_length, min_length, return_full_text=True):
        """Sample a continuation of ``text``.

        Args:
            text: Prompt to continue.
            max_length: Maximum number of *new tokens* to generate (not characters).
            min_length: Minimum number of new tokens to generate.
            return_full_text: If True (default), return the decoded prompt plus
                continuation. If False, return only the newly generated text with
                special tokens removed.

        Returns:
            The decoded text as a ``str``.
        """
        lm, tok = self._resolve()
        token_ids = tok.encode(text, add_special_tokens=False, return_tensors="pt")
        bad_words = self._bad_words_ids(tok)
        with torch.no_grad():
            # `padding` (a tokenizer option) and the misspelt `bad_word_ids` are not generate()
            # arguments; transformers raises ValueError for unused model kwargs.
            output_ids = lm.generate(
                token_ids.to(lm.device),
                max_new_tokens=max_length,
                min_new_tokens=min_length,
                do_sample=True,
                top_k=500,
                top_p=0.95,
                pad_token_id=tok.pad_token_id,
                bos_token_id=tok.bos_token_id,
                eos_token_id=tok.eos_token_id,
                bad_words_ids=bad_words or None,
            )
        sequence = output_ids[0]
        if return_full_text:
            return tok.decode(sequence.tolist())
        # Slice the prompt off by token count rather than string matching: the decoded
        # prompt need not equal `text` character-for-character.
        return tok.decode(sequence[token_ids.shape[-1]:].tolist(), skip_special_tokens=True)

    def chat(self, max_length=25, min_length=20, verbose=True):
        """Run an interactive console chat until the user submits an empty line.

        Asks for the AI's name, hobby and occupation, builds a persona prompt, then
        alternates between reading user input and printing the model's reply.

        Args:
            max_length: Maximum new tokens per reply (default 25).
            min_length: Minimum new tokens per reply (default 20).
            verbose: If True (default), also print the prompt token count, the raw
                generated text and the elapsed seconds for each turn.
        """
        # Profile setup
        name = input("AIの名前:")
        hobby = input("AIの趣味:")
        work = input("AIの職業:")
        text = self._build_prompt(name, hobby, work)

        print("AIに言いたい事を入力してください。終了したいときは未入力のままEnter")
        while True:
            user_input = input(">>> ")
            if user_input == "":
                print("会話を終了します")
                break
            start = time.perf_counter()
            text += user_input + f"」{name}:「"
            if verbose:
                _, tok = self._resolve()
                print(len(tok.encode(text, add_special_tokens=False, return_tensors="pt")[0]))

            generated = self.generate(text, max_length, min_length, return_full_text=False)
            if verbose:
                print(generated)

            # Keep only the AI's first utterance and feed it back as conversation history.
            reply = self._extract_reply(generated)
            text += reply + "人間:「"
            message = reply[:-1]
            if verbose:
                print(time.perf_counter() - start)
            print(message)
      

In [None]:
# Start the interactive session: profile prompts first, then chat until an empty line is entered.
bot = ChatBot()
bot.chat()