In [2]:
# from lora_module import run_lora_finetuning, run_lora_generate, run_lora_test
from lora import load_pretrained_model, load_datasets, train, evaluate, generate

  from .autonotebook import tqdm as notebook_tqdm


# Train/Fine-tuning

## Loading pre-trained model and datasets

In [3]:
# Model / LoRA settings consumed by load_pretrained_model(...) further down.
model_path = "./models/Mistral-7B-Instruct-v0.2-4bit-mlx"
lora_layers = 16
seed = 0
resume_adapter_file = None

# Keyword arguments that the training cell unpacks into train(...).
# NOTE(review): steps_per_eval (200) is larger than iters (20); if train() only
# evaluates every `steps_per_eval` steps, no validation pass would run - confirm.
model_kwargs = dict(
    batch_size=4,
    iters=20,
    adapter_file="adapters.npz",
    learning_rate=1e-5,
    val_batches=25,
    steps_per_report=10,
    steps_per_eval=200,
)

# Directory handed to load_datasets(...).
data_path = "./data/"

In [4]:
# Load the model using the settings from the config cell and unpack the three
# returned values; `model` and `tokenizer` are reused by the train, evaluate
# and generate cells below.
model, tokenizer, config = load_pretrained_model(model_path, lora_layers, seed, resume_adapter_file)

Total parameters 1244.041M
Trainable parameters 1.704M


In [5]:
# Load the three dataset splits from `data_path`.
train_set, valid_set, test_set = load_datasets(data_path)

In [6]:
# Report the size of every split plus the first training example, one item per line.
print(
    f'train_set length: {len(train_set)}',
    f'train_set sample: {train_set[0]}',
    f'valid_set length: {len(valid_set)}',
    f'test_set length: {len(test_set)}',
    sep='\n',
)

train_set length: 3348
train_set sample: <s>[INST] Assist a non-verbal autistic individual in communicating their thoughts or needs through selected images. Your task: infer and articulate the message in first-person, using simple, direct language with empathy and clarity. Be empathetic and direct. Look for deeper meanings in the input. Keep the tone practical and straight forward. Q: What is a good sentence a user might want to say given the following words: stars, fascinating?[/INST] A: Stars fascinate me.</s>
valid_set length: 418
test_set length: 419


In [None]:
# Call train() with the loaded model, the train/validation splits and the
# tokenizer; the settings in `model_kwargs` are unpacked as keyword arguments.
# NOTE(review): this cell has no execution count in the saved notebook, so it
# was not part of the recorded run - confirm before relying on later outputs.
train(model, train_set, valid_set, tokenizer, **model_kwargs)

# Testing

In [7]:
# Evaluation settings.
# Only `batch_size` and `test_batches` are passed to evaluate(...) in the next
# cell; the remaining names are (re)declared here but not read by that call.
batch_size = 4
test_batches = 100
adapter_file = 'adapters.npz'
data_path = "./data/"
model_path = "./models/Mistral-7B-Instruct-v0.2-4bit-mlx"

In [8]:
# Evaluate the in-memory `model` on `test_set` and print the value evaluate() returns.
# NOTE(review): the training cell above has no execution count in the saved
# notebook, so confirm whether this loss reflects a fine-tuned model.
test_loss = evaluate(model, test_set, tokenizer, batch_size, test_batches)
print(f"Test loss: {test_loss}")

Test loss: 5.043103066025578


# Generate

In [9]:
# Generate args
model_path = './models/Mistral-7B-Instruct-v0.2-4bit-mlx'
num_tokens = 100
temp = 0.8
adapter_file = "adapters.npz"

prompt = """Assist a non-verbal autistic individual in communicating their thoughts or needs through selected images.
Your task: infer and articulate the message in first-person, using simple, direct language with empathy and clarity.
- Be empathetic and direct.
- Look for deeper meanings in the input.
- Keep the tone practical and straightforward.
Only give the output for the input provided. Do not come up with new inputs after.
input: chicken, soup, dinner
output: """

In [10]:
# Generate from `prompt` with the in-memory model and tokenizer, using the
# token budget, temperature and adapter file set in the previous cell.
generated_text = generate(model, prompt, tokenizer, num_tokens, temp, adapter_file)

Assist a non-verbal autistic individual in communicating their thoughts or needs through selected images.
Your task: infer and articulate the message in first-person, using simple, direct language with empathy and clarity.
- Be empathetic and direct.
- Look for deeper meanings in the input.
- Keep the tone practical and straightforward.
Only give the output for the input provided. Do not come up with new inputs after.
input: chicken, soup, dinner
output: 
I see a chicken and soup. It looks like we're having dinner tonight.


# Generate with raw prompt
Generate from a hand-written prompt, used as-is (no chat-template preprocessing).

In [None]:
# Generation settings for the hand-written ("raw") prompt.
# Alternative model directory tried previously:
# model_path = './models/Mistral-7B-Instruct-v0.2'
model_path = './models/OpenHermes-2.5-Mistral-7B/'
num_tokens = 100
temp = 0.1
adapter_file = "adapters.npz"

# Layout of the literal below: the first line ends with a real newline, every
# line that ends in a backslash is joined to the next one without a newline,
# and the "[/INST]" line is followed by a newline before " output: </s>".
# NOTE(review): the prompt uses [INST] ... [/INST] markup while `model_path`
# points at the OpenHermes directory, for which the following section applies a
# ChatML template - confirm which format this model expects.
# NOTE(review): the prompt ends with "</s>", which the prompt text itself calls
# the EOS token; ending the input with it may stop generation immediately - confirm.
prompt = """<s> [INST] Assist a non-verbal autistic individual in communicating their thoughts or needs through selected images.
 Your task: infer and articulate the message in first-person, using simple, direct language with empathy and clarity.\
 Be empathetic and direct.\
 Look for deeper meanings in the input.\
 Keep the tone practical and straightforward.\
 Only give the output for the input provided. Do not come up with new inputs after.\
 Write the output in first-person.\
 Write the output in simple, direct language. Use 2 sentences or less. Only give the answer, no explanation.\
 Only give one output. Predict the STOP EOS token called </s> after your output.\
 input: love, slippers, outside [/INST]
 output: </s>"""

# Generate with preprocessed prompt
Preprocess the prompt into ChatML format before generating.

In [None]:
from transformers import AutoTokenizer

# Load a tokenizer from `model_path` (whatever value that name holds when this cell runs).
# NOTE(review): this rebinds `tokenizer`, which was earlier assigned from
# load_pretrained_model(...); cells re-run after this one will see this object instead.
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Set a ChatML chat template by hand. Each message is rendered as
# "<|im_start|>{role}\n{content}<|im_end|>\n", and "<|im_start|>assistant\n" is
# appended when add_generation_prompt is true (it defaults to false if undefined).
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"


# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def preprocess_chat_ml(sample):
    """Render `sample` as a chat prompt string via the notebook's `tokenizer`.

    Builds a two-message conversation - a fixed system instruction followed by
    `sample` as the user turn - and renders it with the chat template currently
    set on the module-level `tokenizer`, with the generation prompt appended.

    Args:
        sample: Text placed verbatim in the user message,
            e.g. "love, slippers, outside".

    Returns:
        The value returned by `tokenizer.apply_chat_template(...)` called with
        `tokenize=False`, i.e. the rendered prompt rather than token ids.
    """
    # This system + user layout targets OpenHermes 2.5 Mistral 7B. For
    # Mistral 7B Instruct v0.2 the original author instead prepended the
    # instruction to the user message, noting that its chat template has no
    # "system" part.
    # The instruction wording (including "straight forward") is kept verbatim.
    system_instruction = (
        "Assist a non-verbal autistic individual in communicating their thoughts or needs through selected images. "
        "Your task: infer and articulate the message in first-person, using simple, direct language with empathy and clarity. "
        "Be empathetic and direct. "
        "Look for deeper meanings in the input. "
        "Keep the tone practical and straight forward."
    )
    messages = [
        {"role": "system", "content": system_instruction},
        {"role": "user", "content": sample},
    ]

    # tokenize=False asks for rendered text, so no `return_tensors` is passed:
    # a tensor format only applies to tokenized output.
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Render an example word list through preprocess_chat_ml and print the result for inspection.
formatted_prompt = preprocess_chat_ml("love, slippers, outside")
print(formatted_prompt)