In [None]:
!pip install -q git+https://github.com/huggingface/transformers
# !pip install transformers
!pip install -q  sentencepiece
!pip install -q  torchinfo

In [None]:
import glob
import json
import os
from concurrent.futures import ThreadPoolExecutor

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the previously saved JSON of Wikipedia titles.
json_path = 'wiki_title_dict.json'
# Decode explicitly as UTF-8 so that non-ASCII (Japanese) titles are read
# the same way regardless of the platform's default text encoding.
with open(json_path, encoding='utf-8') as f:
    wiki_ds_json = json.load(f)

In [None]:
def get_llm_sentence(input_text, generator, tokenizer):
  """Generate one answer for `input_text` and return it as a single line.

  The prompt sent to the generator is "ユーザー: {input_text}\\nシステム: ".

  Args:
    input_text: User message inserted into the prompt.
    generator: Callable invoked as ``generator(prompt, **generation_kwargs)``;
      its result is read as ``result[0]['generated_text']``.
    tokenizer: Object whose ``pad_token_id`` is forwarded to the generator.

  Returns:
    The generated text with the prompt part and every newline removed, or
    None when generation or post-processing raised an exception (the input
    text, the error and any partial result are printed in that case).
  """
  # Sentinel so the error handler can tell whether the generator returned.
  text = None
  try:
    prompt = f"ユーザー: {input_text}\nシステム: "
    text = generator(
        prompt,
        max_length = 400,
        do_sample = True,
        temperature = 0.7,
        top_p = 0.9,
        top_k = 0,
        repetition_penalty = 1.1,
        num_beams = 1,
        pad_token_id = tokenizer.pad_token_id,
        num_return_sequences = 1,
    )
    generated = text[0]['generated_text']
    # Note the TWO spaces: this only matches when the continuation itself
    # starts with a space after the prompt's trailing "システム: ".
    separator = 'システム:  '
    if separator in generated:
      sentence = generated.split(separator)[-1]
    elif generated.startswith(prompt):
      # Continuation did not start with a space: strip the prompt prefix
      # instead of leaking the whole prompt into the result.
      sentence = generated[len(prompt):]
    else:
      sentence = generated
    return sentence.replace('\n', '')
  except Exception as e:
    # Best effort: log what failed and let the caller store None.
    print(input_text)
    print(e)
    if text is not None:
      print(text)
    return None

In [None]:
# Set up the tokenizer and the model.
# use_fast=False requests the non-fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    "line-corporation/japanese-large-lm-3.6b-instruction-sft",
    use_fast=False,
)
model = AutoModelForCausalLM.from_pretrained(
    "line-corporation/japanese-large-lm-3.6b-instruction-sft"
)

# Move the weights to the GPU when one was detected.
if device == 'cuda':
  model.cuda()

# Text-generation pipeline used by the generation loop.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

In [None]:
# Cut the full title list into consecutive chunks of n titles each
# (the final chunk holds whatever is left over).
title_list = wiki_ds_json['title']
n = 100
split_title_list = [
    title_list[chunk_start:chunk_start + n]
    for chunk_start in range(0, len(title_list), n)
]

In [None]:
# Directory that holds one JSON result file per chunk of titles.
# Results are written to the same directory that is scanned for finished
# chunks; previously the files were written next to the notebook while the
# scan looked inside 'llm_sentence/', so finished chunks were never detected.
output_dir = 'llm_sentence'


def _next_chunk_number(directory):
  """Return one past the highest chunk number found in `directory`.

  Files are expected to be named like 'llm_sentence_<number>.json'. Names
  containing parentheses (e.g. duplicate copies such as
  'llm_sentence_3 (1).json') and names whose last '_'-separated part is not
  a number are ignored. Returns 0 when no usable file exists.
  """
  claimed_numbers = []
  for json_path in glob.glob(os.path.join(directory, '*.json')):
    file_name = os.path.basename(json_path)
    if '(' in file_name or ')' in file_name:
      continue
    suffix = file_name.split('.')[0].split('_')[-1]
    if suffix.isdigit():
      claimed_numbers.append(int(suffix))
  return max(claimed_numbers, default=-1) + 1


os.makedirs(output_dir, exist_ok=True)

# Resume after the last chunk that already has an output file.
start_number = _next_chunk_number(output_dir)
number = start_number

# Collect the LLM answers chunk by chunk.
while number < len(split_title_list):
  llm_sentence_dict = {}
  path = os.path.join(output_dir, f'llm_sentence_{number}.json')

  # Reserve this chunk number by writing an empty placeholder file first.
  with open(path, mode="w") as json_file:
    json.dump(llm_sentence_dict, json_file, indent=2)

  for title in tqdm(split_title_list[number]):
    input_text = f"{title}に関して、wikipediaの概要を真似て300文字程度で説明してください。"
    llm_sentence_dict[title] = get_llm_sentence(input_text, generator, tokenizer)

  # Save the results, overwriting the placeholder.
  with open(path, mode="w") as json_file:
    json.dump(llm_sentence_dict, json_file, indent=2)

  # Re-scan the directory so numbers claimed in the meantime are skipped.
  start_number = _next_chunk_number(output_dir)
  number = start_number