In [1]:
from lora import generate, load_model
from models import ModelArgs, Model, LoRALinear
from mlx.utils import tree_map, tree_flatten, tree_unflatten

In [2]:
class Args:
    def __init__(self, prompt, model_path="mistral-mlx", adapter_path="adapters.npz", num_tokens=2, temp=0):
        self.model = model_path
        self.adapter_file = adapter_path
        self.num_tokens = num_tokens
        self.prompt = prompt
        self.temp = temp


In [3]:
model, tokenizer = load_model("mistral-mlx")

# Freeze all layers other than LORA linears
model.freeze()
for l in model.layers[-8:]:
    l.attention.wq = LoRALinear.from_linear(l.attention.wq)
    l.attention.wv = LoRALinear.from_linear(l.attention.wv)

p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
print(f"Trainable parameters {p:.3f}M")

   # Load the LoRA adapter weights which we assume should exist by this point
model.load_weights("adapters.npz")

Total parameters 7242.584M
Trainable parameters 0.852M


In [4]:
import json
data = []
with open('data/test.jsonl', 'r') as f:
    for line in f:
        data.append(json.loads(line))



In [5]:
data = [(item['text'].rsplit(' ', 1)[0], item['text'].rsplit(' ', 1)[1]) for item in data]
# shuffle data
from random import shuffle

shuffle(data)

In [6]:
from sklearn.metrics import classification_report

results_by_temp = {}
for temperature in [0]:
    y_true = []
    y_pred = []
    for text, label in data[:200]:
        args = Args(prompt=text, temp=temperature, num_tokens=1)
        res = generate(model, text, tokenizer, args)
        if "True" in label:
            y_true.append(True)
        else:
            y_true.append(False)
        if "True" in res:
            y_pred.append(True)
        else:
            y_pred.append(False)

    print(f"Classification report for temperature {temperature}:")
    print(classification_report(y_true, y_pred))
    results_by_temp[temperature] = classification_report(y_true, y_pred, output_dict=True)


Classification report for temperature 0:
              precision    recall  f1-score   support

       False       0.46      0.39      0.42        95
        True       0.52      0.59      0.55       105

    accuracy                           0.49       200
   macro avg       0.49      0.49      0.49       200
weighted avg       0.49      0.49      0.49       200



In [23]:
results_by_temp

{0: 0.645, 0.3: 0.62, 0.7: 0.575, 0.9: 0.575}