In [None]:
import os

import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from pymongo import MongoClient
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

In [None]:
# QLoRA-style 4-bit loading: NF4 weights, with the quantization constants
# themselves quantized again ("double quant"); matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
# Canonical Hub id (the organisation name is lower-case "deepseek-ai").
MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"

# The checkpoint's config declares the built-in Qwen2 architecture, so no custom
# modelling code is needed. `trust_remote_code=True` would only allow the Hub repo
# to execute arbitrary Python here. Re-enable it only if a revision adds custom code.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="sdpa",
)

# Load the tokenizer that matches the checkpoint so training and inference share
# the same vocabulary and special tokens.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Batched training needs a pad token; reuse EOS if the checkpoint defines none.
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# LoRA adapters are attached to the attention projections only; the MLP
# projections get no adapter. alpha == r gives a LoRA scaling factor of 1.
peft_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# PEFT's documented preparation step for training on a k-bit quantized model:
# - freezes the base weights;
# - upcasts the remaining fp16/bf16 parameters to fp32 for stability;
# - makes the input embeddings require grads;
# - enables gradient checkpointing.
# Without the latter, a 32B model is likely to run out of activation memory.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
# The KV cache is not used during training and conflicts with gradient checkpointing.
model.config.use_cache = False

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

In [None]:
def format_dataset(example, eos_token=""):
    """Render one instruction-tuning record as a single Alpaca-style prompt string.

    Args:
        example: Mapping with required keys ``"instruction"`` and ``"output"`` and
            an optional ``"input"`` key. A missing, ``None`` or whitespace-only
            input omits the ``### Input:`` section. Non-string inputs (e.g.
            numbers stored in MongoDB) are rendered via ``str`` instead of
            crashing.
        eos_token: Text appended directly after the response. The default ``""``
            appends nothing. Pass ``tokenizer.eos_token`` if the trainer does not
            add an end-of-sequence token itself, so the model learns where to stop.

    Returns:
        ``{"text": prompt}``. ``"text"`` is the column name TRL's ``SFTTrainer``
        reads by default.
    """
    instruction = example["instruction"]
    input_text = example.get("input")
    output = example["output"]

    sections = [f"### Instruction:\n{instruction}"]
    # str() guards against non-string values, which have no .strip().
    if input_text is not None and str(input_text).strip():
        sections.append(f"### Input:\n{input_text}")
    sections.append(f"### Response:\n{output}{eos_token}")

    return {"text": "\n\n".join(sections)}

In [None]:
# Connection details come from the environment so credentials never end up in the
# notebook, its outputs, or its version history. For example:
#   export MONGO_URI="mongodb://<user>:<password>@<host>:27017/"
MONGO_URI = os.environ.get("MONGO_URI")
if not MONGO_URI:
    raise RuntimeError(
        "Set the MONGO_URI environment variable to the MongoDB connection string."
    )

client = MongoClient(MONGO_URI)
# Database / collection names default to the previous hard-coded values.
db = client[os.environ.get("MONGO_DB", "db")]
collection = db[os.environ.get("MONGO_COLLECTION", "data")]

In [None]:
# Fetch only the three fields used below. This keeps `_id` and any unrelated
# (possibly large) fields off the wire.
MONGO_PROJECTION = {"_id": 0, "instruction": 1, "user": 1, "assistent": 1}

# Map the collection's field names onto the instruction/input/output schema that
# format_dataset expects. A missing "instruction" or "assistent" field raises KeyError.
mongo_data = [
    {
        "instruction": doc["instruction"],
        "input": doc.get("user", ""),
        # NOTE(review): the field is spelled "assistent" here — confirm it matches the stored documents.
        "output": doc["assistent"],
    }
    for doc in collection.find({}, MONGO_PROJECTION)
]

In [None]:
# Fail fast: an empty query result would otherwise yield an empty dataset with no
# "text" column and only break later, far from the cause.
if not mongo_data:
    raise ValueError("No documents were loaded from MongoDB; nothing to train on.")

# Build the dataset and replace the raw columns with the single formatted "text" column.
dataset = Dataset.from_list(mongo_data)
dataset = dataset.map(format_dataset, remove_columns=dataset.column_names)

In [None]:
# Effective batch size per process:
# per_device_train_batch_size * gradient_accumulation_steps = 1 * 16 = 16 sequences.
args = TrainingArguments(
    output_dir="outputs",
    seed=42,
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=3e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    optim="paged_adamw_8bit",  # bitsandbytes paged 8-bit AdamW
    bf16=True,  # requires a GPU with bfloat16 support
    # Bookkeeping
    logging_steps=10,
    save_steps=500,
)