## Toki-Pona-ChatGLM-Finetune

This notebook sets up LoRA fine-tuning of ChatGLM-6B on Toki Pona text, with the aim of producing an LLM that can speak Toki Pona.

In [None]:
!python -m pip install -r requirements.txt

In [None]:
import os
import dataset.GLM 
import torch
import loralib as lora
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
from lora_utils.insert_lora import get_lora_model
import dataset.GLM 


In [None]:


# Checkpoint to start from, and the device the rest of the notebook targets.
checkpoint = "THUDM/chatglm-6b"
device = 'cpu'

# The tokenizer and the model are fetched with identical hub options so that
# both come from the same checkpoint and revision.
_hub_options = {'trust_remote_code': True, 'revision': 'main'}
tokenizer = AutoTokenizer.from_pretrained(checkpoint, **_hub_options)
model = AutoModel.from_pretrained(checkpoint, **_hub_options)


In [None]:
# LoRA hyper-parameters handed to get_lora_model.
# NOTE(review): `enable_lora` presumably toggles adapters per fused projection
# (e.g. query / key / value) — confirm against lora_utils.insert_lora.
lora_config = dict(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    enable_lora=[True, False, True],
)

# Rebind `model` to whatever get_lora_model returns for this configuration.
model = get_lora_model(model, lora_config)
# Restore previously saved weights on top of the LoRA-augmented model.
# - map_location='cpu' lets a checkpoint that was saved from GPU tensors load on
#   a machine without CUDA; load_state_dict then copies the values into the
#   model's existing parameters, wherever those currently live.
# - strict=False tolerates keys that are missing from, or extra in, the file.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.load_state_dict(
    torch.load('saved/chatglm-6b_alpaca_5.pt', map_location='cpu'),
    strict=False,
)

In [None]:

# Set the module-level `device` attribute of dataset.GLM to this notebook's device.
# NOTE(review): presumably read by the encoding/collate helpers — confirm in dataset/GLM.py.
dataset.GLM.device = device
# Optional override of dataset.GLM.pad_to, currently left disabled:
#dataset.GLM.pad_to = 8

In [None]:
# A single prompt/completion example used to exercise the training path.
pairs = [
    {
        'prompt': 'toki!',
        'completion': 'toki! mi pona tan ni: mi sona e sina.',
    },
]

# Encode the pairs, wrap them in a dataset, and serve one example per batch.
pairs_encoded = dataset.GLM.encode_pairs(pairs, tokenizer)
train_dataset = dataset.GLM.SimpleDataset(pairs_encoded)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=dataset.GLM.collate_fn,
)

In [None]:
# Pick the floating-point precision from the target device: many half-precision
# ops are not implemented for CPU tensors, so the model is cast to float32 on
# CPU and to fp16 everywhere else. The result is assigned rather than left as
# the cell's last expression so the full module repr is not dumped as output.
if torch.device(device).type == 'cpu':
    model = model.float().to(device)
else:
    model = model.half().to(device)

In [None]:
# Take one collated batch from the training loader, then move every entry to `device`.
batch = next(iter(train_dataloader))
batch = {key: value.to(device) for key, value in batch.items()}

In [None]:
# Run a forward pass on the prepared batch and display the `.loss` attribute of
# the model output as this cell's result.
model(**batch).loss

## Inference


In [None]:
# Prompt/completion example used to exercise generation.
pairs = [
    {'prompt': 'toki!', 'completion': 'toki! mi pona tan ni: mi sona e sina.'}
]

# NOTE(review): with_eos=False presumably stops the encoder from appending an
# end-of-sequence token so the model can continue the text — confirm in
# dataset/GLM.py.
pairs_encoded = dataset.GLM.encode_pairs(pairs, tokenizer, with_eos=False)
test_dataset = dataset.GLM.SimpleDataset(pairs_encoded)
# Inference gains nothing from shuffling; a fixed order keeps the batch taken
# from this loader reproducible once more than one pair is listed above.
test_dataloader = DataLoader(
    dataset=test_dataset,
    collate_fn=dataset.GLM.collate_fn,
    shuffle=False,
    batch_size=1,
)

In [None]:
# Take one collated batch from the inference loader, then move every entry to `device`.
batch = next(iter(test_dataloader))
batch = {key: value.to(device) for key, value in batch.items()}

In [None]:
# Sampling options for generation, collected in one place so they are easy to tune.
# NOTE(review): 130005 is presumably the checkpoint's end-of-sequence token id —
# confirm against the tokenizer / model config for the revision in use.
_generation_options = dict(
    max_length=1024,
    eos_token_id=130005,
    do_sample=True,
    temperature=0.55,
    top_p=0.75,
    top_k=10000,
    repetition_penalty=1.5,
    num_return_sequences=1,
)

# The batch entries are passed as keyword arguments alongside the sampling options.
outputs = model.generate(**batch, **_generation_options)

In [None]:
# Decode every generated sequence through the tokenizer's `sp_tokenizer` and print it.
for output in outputs:
    decoded_text = tokenizer.sp_tokenizer.decode(output)
    print(decoded_text)