In [None]:
!pip install trl

In [None]:
!pip install -q wordfreq fugashi[unidic-lite]

In [None]:
from wordfreq import top_n_list
import fugashi, itertools




In [None]:
import re

# Matches words made only of ASCII letters, digits and underscore; combined with
# the isalpha() filter applied to the word list, this drops romaji/English words.
ascii_pat = re.compile(r'^[A-Za-z0-9_]+$')
tagger = fugashi.Tagger()
# Accepted top-level POS tags: noun, adjective, and 形状詞 (na-adjective stem).
good_pos = {"名詞", "形容詞", "形状詞"}

def is_good(word):
    """Return True if *word* is a non-ASCII noun / adjective / 形状詞.

    The part of speech is read from the FIRST morpheme the tagger produces,
    so multi-morpheme words are classified by their leading token.
    Returns False for ASCII-only words and for input that yields no tokens.
    """
    # (1) Reject ASCII-only words.
    if ascii_pat.fullmatch(word):
        return False
    tokens = tagger(word)
    # (2) Empty input produces no tokens; indexing [0] would raise IndexError.
    if not tokens:
        return False
    # (3) POS check on the first morpheme (nouns, adjectives, 形状詞 only).
    return tokens[0].feature.pos1 in good_pos

# Candidate words: the 10,000 most frequent Japanese words that are purely
# alphabetic (per str.isalpha) and at least 3 characters long.
raw = [
    word
    for word in top_n_list("ja", 10_000)
    if len(word) >= 3 and word.isalpha()
]
# Final vocabulary: the candidates that pass the ASCII / part-of-speech filter.
VOCAB = list(filter(is_good, raw))

In [None]:
import random
from datasets import Dataset

def build_pairs(vocab, size=50_000, seed=None):
    """Sample shiritori (w1 -> w2) training pairs from *vocab* as a Dataset.

    Each row is {"prompt": "# 指示{w1} → ", "completion": w2}. Here w2 starts
    with the last character of w1 and does not itself end with "ん". Draws whose
    w1 has no valid follow-up are skipped, so the result may contain fewer than
    ``size`` rows.

    Args:
        vocab: sequence of non-empty words to sample from.
        size: number of draws to attempt.
        seed: if given, sampling uses a private random.Random(seed) so the
            dataset is reproducible. If None, the global ``random`` module is
            used (the original behaviour).

    NOTE(review): chaining is on the raw last *character*, which for kanji
    words is not a kana reading; the reward function compares kana.
    """
    rng = random if seed is None else random.Random(seed)

    # Index valid follow-up words by their first character once, instead of
    # rescanning the whole vocab on every draw (previously O(size * len(vocab))).
    # Lists keep vocab order, so they equal the per-draw lists built before.
    by_head = {}
    for w in vocab:
        if w[-1] != "ん":
            by_head.setdefault(w[0], []).append(w)

    pairs = []
    for _ in range(size):
        w1 = rng.choice(vocab)
        cand = by_head.get(w1[-1])
        if not cand:
            continue
        w2 = rng.choice(cand)
        pairs.append({"prompt": f"# 指示{w1} → ", "completion": w2})
    return Dataset.from_list(pairs)

# Seeded so that Restart & Run All reproduces the same training pairs.
dataset = build_pairs(VOCAB, seed=42)


In [None]:
# Peek at one generated pair; a bare last expression uses Jupyter's rich display.
dataset[1]

## 報酬の定義

In [None]:
import re, math, fugashi
import numpy as np
tagger = fugashi.Tagger()  # NOTE(review): re-creates the tagger from the vocabulary-filter cell; nothing in this cell uses it.
# Hiragana (ぁ-ゖ) and katakana (ァ-ヶ). The long-vowel mark "ー" is deliberately
# not matched, so a word ending in "ー" is scored by the kana before it.
kana_pat = re.compile('[ぁ-ゖァ-ヶ]')

# In shiritori, small kana are read as their full-size forms (e.g. しゃ -> や).
_SMALL_TO_LARGE = str.maketrans('ぁぃぅぇぉゃゅょっゎゕゖ', 'あいうえおやゆよつわかけ')

# Punctuation/brackets stripped from around the extracted answer word.
_ANSWER_PUNCT = "。、．，！？!?.,「」『』\"'"


def _kana_seq(word):
    """Return the kana of *word*, in order, as full-size hiragana."""
    seq = []
    for ch in kana_pat.findall(word):
        if 'ァ' <= ch <= 'ヶ':
            # Katakana block maps onto hiragana with a fixed 0x60 code-point offset.
            ch = chr(ord(ch) - 0x60)
        seq.append(ch.translate(_SMALL_TO_LARGE))
    return seq


def last_kana(word):
    """Return the last kana of *word* as full-size hiragana, or '' if none."""
    seq = _kana_seq(word)
    return seq[-1] if seq else ''


def first_kana(word):
    """Return the first kana of *word* as full-size hiragana, or '' if none."""
    seq = _kana_seq(word)
    return seq[0] if seq else ''


def _answer_word(completion):
    """Extract the model's answer: the first whitespace-delimited token, minus punctuation."""
    parts = completion.split()
    return parts[0].strip(_ANSWER_PUNCT) if parts else ''


def shiritori_reward(prompts, completions, vocab=None, **kw):
    """Score each (prompt, completion) pair as a shiritori move.

    The answer is the first whitespace-delimited token of the completion.
    Its reward is the sum of:
      * +1.0 if its first kana equals the last kana of the prompt, else -1.0
      * -1.0 if its last kana is "ん" (losing move)
      * -0.5 if it is not in the vocabulary (previously this was +0.5, which
        rewarded out-of-vocabulary / verbose output)

    Args:
        prompts: prompt strings, e.g. "# 指示りんご → ".
        completions: generated completion strings, parallel to *prompts*.
        vocab: optional collection of valid words; defaults to global VOCAB.
        **kw: extra keyword arguments passed by the trainer (ignored).

    Returns:
        list of float rewards, one per pair.
    """
    vocab_set = set(VOCAB if vocab is None else vocab)  # O(1) membership per completion
    rewards = []
    for p, c in zip(prompts, completions):
        answer = _answer_word(c)
        head = first_kana(answer)
        tail = last_kana(p)
        # Base score: chain success +1 / failure -1.
        ok = 1.0 if head and head == tail else -1.0
        # Foul: answer ending in "ん" is -1.
        penalty = -1.0 if last_kana(answer) == 'ん' else 0.0
        # Out-of-vocabulary answer is -0.5.
        oov = -0.5 if answer not in vocab_set else 0.0
        rewards.append(ok + penalty + oov)
    return rewards




In [None]:
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "google/gemma-3-1b-it"

# Tokenizer and weights come from the same checkpoint; device_map={"": 0}
# places every module of the model on device 0.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map={"": 0},
)



In [None]:
# Train on a reproducible subset of 1,000 shuffled pairs. min() guards against
# build_pairs having produced fewer rows than requested (it skips dead-end draws).
N_TRAIN_SAMPLES = 1_000
sample_dataset = dataset.shuffle(seed=42).select(range(min(N_TRAIN_SAMPLES, len(dataset))))

In [None]:
# Display the training subset.
sample_dataset


In [None]:
grpo_cfg = GRPOConfig(
    output_dir="./shiritori-grpo",  # checkpoint/log directory, explicit rather than the library default
    num_generations=8,              # 8 sibling completions sampled per prompt (the GRPO group)
    temperature=0.8,                # sampling temperature for those completions
    beta=0.02,                      # KL-penalty coefficient
    max_prompt_length=32,           # prompts are short: "# 指示{w1} → "
    max_completion_length=32,       # only a single word is expected back
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=sample_dataset,
    reward_funcs=[shiritori_reward],
    args=grpo_cfg,
)

In [None]:
trainer.train()


In [None]:
SAVE_DIR = "./shiritori-model"

# Write the fine-tuned weights, then the tokenizer into the same directory so
# it can be reloaded as a single checkpoint.
trainer.save_model(output_dir=SAVE_DIR)
tokenizer.save_pretrained(save_directory=SAVE_DIR)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


def play(word, max_new_tokens=32):
    """Sample the model's shiritori reply to *word* and return it as text.

    The prompt uses the same template as the training pairs ("# 指示{word} → ").
    Feeding the bare word, as before, gives the model an input format it was
    never trained on.

    Args:
        word: the previous word in the chain.
        max_new_tokens: cap on generated tokens (32 matches training's
            max_completion_length).

    Returns:
        Only the newly generated text (prompt tokens are sliced off), decoded
        without special tokens and with surrounding whitespace stripped.
    """
    prompt = f"# 指示{word} → "  # keep in sync with the prompt template in build_pairs
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, top_p=0.9)
    new_tokens = out[0][inputs.input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Output is sampled and varies per run; a good reply starts with the last kana
# of the input (e.g. "りんご" -> a word starting with "ご").
print(play("りんご"))
# "みかん" ends in "ん", which already loses in shiritori: there is no valid reply.
print(play("みかん"))