In [None]:
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Hub repository used for both the base checkpoint and the adapter weights
# (the adapter is read from its "loftq_init" subfolder below).
MODEL_ID = "LoftQ/Mistral-7B-v0.1-4bit-64rank"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Register "[PAD]" as the padding token; the embedding table is resized
# further down so the model's vocabulary size matches the tokenizer.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# 4-bit NF4 quantisation with bfloat16 compute and double quantisation off.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": 0},  # place every module on device index 0
)
# Keep the embedding matrix in step with the tokenizer after adding "[PAD]".
base_model.resize_token_embeddings(len(tokenizer))

# Attach the adapter from the "loftq_init" subfolder in trainable mode.
peft_model = PeftModel.from_pretrained(
    base_model,
    MODEL_ID,
    subfolder="loftq_init",
    is_trainable=True,
)

In [None]:
from datamodule import datamodule

# Alternative input (disabled): a list of several CSV files instead of one.
# path = ["/home/elicer/M-LLM/data/BoolQA.csv", "/home/elicer/M-LLM/data/NLI_CB.csv", "/home/elicer/M-LLM/data/sc_amazon.csv"]
path = "/home/elicer/M-LLM/data/sc_amazon.csv"

# NOTE(review): "preprare_dataset" is the name as called here; it looks like a
# misspelling of "prepare" -- confirm against the datamodule module before renaming.
dataset_splits = datamodule.preprare_dataset(path)
train_dataset, val_dataset, test_dataset = dataset_splits

# Tokenization is not applied in this cell. To tokenize any split directly:
#   split = split.map(lambda samples: tokenizer(samples["text"]), batched=True)

In [None]:
from datasets import Dataset

def _tokenize_text(batch):
    """Run the tokenizer over the "text" column of a batch of examples."""
    return tokenizer(batch["text"])


# Rebuild the training split from its "text" column only, then tokenize it in batches.
text_data = {"text": train_dataset["text"]}
train_dataset = Dataset.from_dict(text_data).map(_tokenize_text, batched=True)

In [None]:
import wandb
import os

# Authenticate with Weights & Biases, then set the WANDB_PROJECT environment
# variable for this process.
wandb.login()
os.environ.update({"WANDB_PROJECT": "M-LLM"})

In [None]:
import transformers

# Training configuration: one epoch, batch size 1 with 4 gradient-accumulation
# steps, cosine learning-rate schedule, metrics reported to Weights & Biases.
args = transformers.TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    # Use bf16 mixed precision rather than fp16: the model is loaded with
    # torch_dtype=torch.bfloat16 and bnb_4bit_compute_dtype=torch.bfloat16, so
    # fp16 AMP would add a gradient scaler and fp16 autocast on top of a bf16
    # model (mismatched dtypes, and the scaler can fail on bf16 gradients).
    bf16=True,
    logging_steps=10,
    output_dir="outputs",
    # NOTE(review): plain SGD at lr=5e-5 is likely to make very slow progress
    # for adapter fine-tuning -- confirm this optimizer choice is intentional.
    optim="sgd",
    report_to="wandb",
    run_name="Mistral-sc",
    lr_scheduler_type="cosine",
)

trainer = transformers.Trainer(
    model=peft_model,
    train_dataset=train_dataset,
    args=args,
    # mlm=False selects causal-LM collation instead of masked-LM.
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
peft_model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()
# Close the Weights & Biases run once training has completed.
wandb.finish()

In [None]:
# Upload peft_model to the Hub under the repository id below.
HUB_REPO_ID = "JD97/SC"
peft_model.push_to_hub(HUB_REPO_ID)