In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# Padding needs a pad token. Registering a brand-new "[PAD]" token would grow
# the tokenizer vocabulary without resizing the model's embedding matrix, so
# the new id would be out of range as soon as a batch is actually padded.
# Reuse a token that already exists in the vocabulary instead.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token

# Load the base model with 4-bit weights and float16 as the torch dtype;
# device_map="auto" delegates layer placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

In [ ]:
# LoRA hyper-parameters: adapter rank, scaling factor (twice the rank), dropout.
LORA_R = 8
LORA_ALPHA = LORA_R * 2
LORA_DROPOUT = 0.1

# Adapter configuration for causal language modelling.
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    target_modules=["w1", "w2", "w3"],  # only targeting the MoE layers.
)

# Prepare the model for k-bit training, then wrap it with the LoRA adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), config)

In [None]:
import pickle


def pickle_to_tensor(x):
    """Decode a hex-encoded pickle payload back into the object it holds.

    Every literal ``\\x`` marker in ``x`` is stripped, the remaining text is
    parsed as hexadecimal bytes, and those bytes are unpickled.

    NOTE(review): unpickling can execute arbitrary code, so this must only be
    fed data from a trusted source.
    """
    hex_text = "".join(x.split("\\x"))
    raw_bytes = bytes.fromhex(hex_text)
    return pickle.loads(raw_bytes)

In [None]:
from datasets import load_dataset
import os

parquet_dir = "data"
train_parquet_file = os.path.join(parquet_dir, "train.parquet")

# Load a regular (map-style) dataset rather than a streaming one: a streamed
# dataset has no length, and Trainer refuses a length-less train_dataset
# unless max_steps is set, whereas this notebook trains by num_train_epochs.
# The "token_id" column is dropped right after loading.
initial_dataset = load_dataset(
    "parquet", data_files=train_parquet_file
).remove_columns("token_id")

In [None]:
def transform(examples):
    """Deserialize the pickled columns of a batch in place and return the batch.

    ``examples`` maps column names to lists of values; the ``attention_mask``
    and ``input_ids`` columns are replaced by the result of running
    ``pickle_to_tensor`` on each of their entries.
    """
    for column in ("attention_mask", "input_ids"):
        examples[column] = [pickle_to_tensor(raw) for raw in examples[column]]
    return examples


# Apply the decoding batch-wise over the loaded dataset.
dataset = initial_dataset.map(transform, batched=True)

In [None]:
# Pick out the training split and print a summary of the decoded dataset.
train_data = dataset["train"]
print("dataset", dataset)

In [None]:
from transformers import DataCollatorForLanguageModeling


class CustomDataCollator(DataCollatorForLanguageModeling):
    """Collate pre-tokenized examples into padded causal-LM training batches.

    Each example is expected to provide ``input_ids`` and ``attention_mask``
    whose first element is the actual sequence (the ``[0]`` index drops a
    leading axis; presumably a batch dimension of size 1 — confirm against
    the stored data).
    """

    def __call__(self, examples):
        """Pad ``examples`` to the longest sequence and attach ``labels``.

        Args:
            examples: iterable of mappings with ``input_ids`` and
                ``attention_mask`` entries.

        Returns:
            The padded batch with ``input_ids``, ``attention_mask`` and
            ``labels``. Labels are a copy of ``input_ids`` with padded
            positions set to -100 so the loss ignores them.
        """
        features = {
            "input_ids": [example["input_ids"][0] for example in examples],
            "attention_mask": [example["attention_mask"][0] for example in examples],
        }
        # Use the collator's own tokenizer rather than a notebook-level global.
        batch = self.tokenizer.pad(features, return_tensors="pt", padding="longest")

        # Without labels the model returns no loss and Trainer cannot train.
        # Mask via the attention mask (not the pad id) so real tokens that
        # happen to share the pad token's id are still learned.
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100
        batch["labels"] = labels

        return batch

In [ ]:
# Hyper-parameters for the fine-tuning run; checkpoints go to "aidx-mixtral".
training_args = TrainingArguments(
    output_dir="aidx-mixtral",
    num_train_epochs=6,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    optim="adamw_torch",
    logging_steps=2,
)

# Collator configured for causal (non-masked) language modelling.
data_collator = CustomDataCollator(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    data_collator=data_collator,
)

trainer.train()

In [1]:
from evaluate import load

# Load the "bertscore" metric through `evaluate.load`.
bertscore = load("bertscore")



In [15]:
# Score one prediction/reference pair using the SciBERT checkpoint.
bertscore.compute(
    predictions=["hello world"],
    references=["hi there!"],
    model_type="nfliu/scibert_basevocab_uncased",
)

{'precision': [0.56596839427948],
 'recall': [0.5439147353172302],
 'f1': [0.5547224879264832],
 'hashcode': 'nfliu/scibert_basevocab_uncased_L9_no-idf_version=0.3.12(hug_trans=4.36.2)'}