# Install unsloth

In [1]:
%%capture
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes

# Import Necessary Libraries

In [6]:
from unsloth import FastLanguageModel, PatchDPOTrainer, is_bfloat16_supported
from trl import DPOTrainer, DPOConfig
import torch
import ast
import re
import textwrap
import gc
import os
from datasets import load_dataset

# Load and Configure the Model

In [5]:
max_seq_length = 2048
dtype = None
load_in_4bit = True
model_name = "unsloth/Qwen2.5-Coder-1.5B-Instruct"

PatchDPOTrainer()

print(f"Loading {model_name}...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 42,
    use_rslora = False,
    loftq_config = None,
)

Loading unsloth/Qwen2.5-Coder-1.5B-Instruct...
==((====))==  Unsloth 2025.12.7: Fast Qwen2 patching. Transformers: 4.57.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/1.14G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/265 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/632 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Unsloth 2025.12.7 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


# Load  Dataset

In [8]:
dataset_file = "/content/groq_dpo_dataset .jsonl"

print(f"Loading Dataset from: {dataset_file}...")

def formatting_func(example):
    return {
        "prompt": example["prompt"],
        "chosen": example["chosen"],
        "rejected": example["rejected"],
    }

dataset = load_dataset("json", data_files=dataset_file, split="train")

original_columns = dataset.column_names
dataset = dataset.map(formatting_func, remove_columns=original_columns)
print(" Dataset loaded successfully!")

Loading Dataset from: /content/groq_dpo_dataset .jsonl...


Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/584 [00:00<?, ? examples/s]

 Dataset loaded successfully!


# Train the Model

In [9]:
os.environ["WANDB_DISABLED"] = "true"

gc.collect()
torch.cuda.empty_cache()

training_args = DPOConfig(
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 8,

    warmup_ratio = 0.1,
    num_train_epochs = 3,
    learning_rate = 5e-6,
    fp16 = not is_bfloat16_supported(),
    bf16 = is_bfloat16_supported(),
    logging_steps = 10,
    optim = "adamw_8bit",
    weight_decay = 0.01,
    lr_scheduler_type = "linear",
    seed = 42,
    output_dir = "dpo_outputs",

    gradient_checkpointing = True,

    beta = 0.1,
    max_length = max_seq_length,
    max_prompt_length = 1024,
)

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = training_args,
)

print("Starting Training (Low Memory Mode)...")
dpo_trainer.train()

Extracting prompt in train dataset (num_proc=6):   0%|          | 0/584 [00:00<?, ? examples/s]

Applying chat template to train dataset (num_proc=6):   0%|          | 0/584 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=6):   0%|          | 0/584 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.


Starting Training (Low Memory Mode)...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 584 | Num Epochs = 3 | Total steps = 219
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 18,464,768 of 1,562,179,072 (1.18% trained)


Step,Training Loss,rewards / chosen,rewards / rejected,rewards / accuracies,rewards / margins,logps / chosen,logps / rejected,logits / chosen,logits / rejected,eval_logits / chosen,eval_logits / rejected,nll_loss
10,0.6924,0.000392,-0.001077,0.4625,0.001469,-65.476486,-130.995941,-3.365234,-2.852645,0,0,0
20,0.6832,0.00479,-0.015228,0.8875,0.020018,-64.342285,-131.785019,-3.367497,-2.865988,No Log,No Log,No Log
30,0.6311,0.036644,-0.093169,0.925,0.129812,-59.659016,-128.410431,-3.362586,-2.86691,No Log,No Log,No Log
40,0.5598,0.069242,-0.222786,0.9375,0.292028,-62.984581,-128.510757,-3.29875,-2.814882,No Log,No Log,No Log
50,0.4712,0.118849,-0.403243,0.9625,0.522092,-64.856018,-135.986618,-3.469383,-2.877916,No Log,No Log,No Log
60,0.4062,0.147057,-0.573671,0.9625,0.720729,-70.332489,-142.748779,-3.476149,-2.87746,No Log,No Log,No Log
70,0.3137,0.209382,-0.82559,1.0,1.034972,-65.857788,-144.850555,-3.548933,-2.939309,No Log,No Log,No Log
80,0.3097,0.217565,-0.891142,0.925,1.108707,-63.911713,-139.254425,-3.426116,-2.880621,No Log,No Log,No Log
90,0.2512,0.269299,-1.120925,0.925,1.390224,-63.967224,-140.326126,-3.439044,-2.879664,No Log,No Log,No Log
100,0.1908,0.320691,-1.40169,0.975,1.722381,-65.849205,-158.010757,-3.476383,-2.932626,No Log,No Log,No Log


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


TrainOutput(global_step=219, training_loss=0.26168618218539513, metrics={'train_runtime': 1543.5255, 'train_samples_per_second': 1.135, 'train_steps_per_second': 0.142, 'total_flos': 0.0, 'train_loss': 0.26168618218539513, 'epoch': 3.0})

# Test the Model

In [11]:
def extract_definitions(code_str):
    names = []
    try:
        tree = ast.parse(code_str)
        for node in tree.body:
            if isinstance(node, ast.ClassDef):
                names.append(node.name)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if not node.name.startswith('_'):
                    names.append(node.name)
    except SyntaxError:
        pass
    return names

In [12]:
def check_google_style_strict(docstring, node_type=None, function_name=""):
    errors = []
    if not docstring:
        return ["Docstring is empty or missing."]

    if "Args:" in docstring:
        pattern = r"\s+[\w\*]+\s*\(.*\):\s+"
        if not re.search(pattern, docstring):
            errors.append("Bad 'Args' format. Expected indentation + 'name (type): description'.")

    is_class = (node_type == ast.ClassDef)
    is_init = (function_name == "__init__")

    if not is_class and not is_init:
        if "Returns:" not in docstring and "Yields:" not in docstring:
            errors.append("Missing 'Returns:' (or 'Yields:') section.")

    return errors

In [13]:
def extract_code_block(text):
    pattern = r"```(?:python)?\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return text.strip()

In [14]:
FastLanguageModel.for_inference(model)
def generate_verified_docstring(model, tokenizer, code_snippet, max_retries=3):
    real_names = extract_definitions(code_snippet)

    system_instruction = """You are a strict Python Code Editor.
    Your goal: Rewrite the provided code to insert Google-Style docstrings.

    ### CRITICAL RULES:
    1. Output the **FULL COMPLETE CODE** (imports, classes, functions). Do NOT summarize.
    2. Output ONLY valid Python code inside ```python``` blocks.
    3. Format Args: `param_name (type): description`.
    4. **Classes** do NOT need a 'Returns' section.
    5. `__init__` does NOT need a 'Returns' section.
    """

    messages = [
        {"role": "system", "content": system_instruction},
        {"role": "user", "content": f"Here is the code to document:\n\n{code_snippet}"}
    ]

    for attempt in range(max_retries):
        print(f"  Attempt {attempt + 1}...", end=" ")

        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to("cuda")

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs,
                max_new_tokens=2048,
                do_sample=True,
                temperature=0.2,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id
            )

        generated_tokens = outputs[0][inputs.shape[-1]:]
        full_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        generated_code = extract_code_block(full_output)

        if not generated_code:
            generated_code = full_output.strip()

        errors = []

        try:
            module = ast.parse(generated_code)
        except SyntaxError as e:
            errors.append(f"Syntax Error: {e}")

        if not errors:
            found_names = []
            for node in module.body:
                if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)):
                    found_names.append(node.name)
                    if node.name.startswith('_') and node.name != "__init__": continue

                    docstring_content = ast.get_docstring(node)

                    style_errors = check_google_style_strict(
                        docstring_content,
                        node_type=type(node),
                        function_name=node.name
                    )

                    if style_errors:
                        errors.append(f"In '{node.name}': {', '.join(style_errors)}")

            for name in real_names:
                if name not in found_names:
                    errors.append(f"Missing original definition '{name}'. YOU MUST RETURN THE FULL CODE.")

        if not errors:
            print("Success!")
            return generated_code
        else:
            print(f"Failed: {errors[0]}...")

            error_msg = f"Your code has errors: {errors}. \nCRITICAL: You must rewrite the **ENTIRE ORIGINAL CODE** with fixes. Do not output snippet only."

            messages.append({"role": "assistant", "content": full_output})
            messages.append({"role": "user", "content": error_msg})

    print("Max retries reached.")
    return generated_code

In [15]:
complex_test_code = """
def calculate_velocity(distance, time):
    if time == 0: raise ValueError("Time cannot be zero")
    return distance / time
"""
meta_code = """
from typing import Any, Dict

class SingletonMeta(type):
    _instances: Dict[Any, Any] = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
        return cls._instances[cls]

class Database(metaclass=SingletonMeta):
    def __init__(self, connection_string: str):
        self.connection_string = connection_string
"""
async_code = """
import asyncio

class AsyncDatabaseConnection:
    def __init__(self, db_url: str):
        self.db_url = db_url
        self.connection = None

    async def __aenter__(self):
        print(f"Connecting to {self.db_url}...")
        await asyncio.sleep(0.1)  # Simulate connection
        self.connection = "Connected"
        return self.connection

    async def __aexit__(self, exc_type, exc, tb):
        print("Closing connection...")
        await asyncio.sleep(0.1)  # Simulate closing
        self.connection = None
"""
dynamic_code = """
from typing import Any

class LazyProxy:
    def __init__(self, target: Any):
        self._target = target

    def __getattr__(self, name: str) -> Any:
        print(f"Intercepting access to {name}")
        attr = getattr(self._target, name)

        if callable(attr):
            def wrapper(*args, **kwargs):
                print(f"Calling method: {name}")
                return attr(*args, **kwargs)
            return wrapper
        return attr
"""
chatgpt = """
def normalize_text(text, lower=True, remove_punctuation=False):
    if lower:
        text = text.lower()

    if remove_punctuation:
        import string
        text = text.translate(str.maketrans("", "", string.punctuation))

    return text.strip()
    """
print("Starting Strict Verification...\n")
print(generate_verified_docstring(model, tokenizer, complex_test_code))
print(generate_verified_docstring(model, tokenizer, meta_code))
print(generate_verified_docstring(model, tokenizer, async_code))
print(generate_verified_docstring(model, tokenizer, chatgpt))

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Starting Strict Verification...

  Attempt 1... Failed: In 'calculate_velocity': Docstring is empty or missing....
  Attempt 2... Failed: In 'calculate_velocity': Docstring is empty or missing....
  Attempt 3... Failed: In 'calculate_velocity': Docstring is empty or missing....
Max retries reached.
"""
Calculate the velocity given the distance and time.

Args:
    distance (float): The distance traveled in meters.
    time (float): The time taken in seconds.

Returns:
    float: The calculated velocity in meters per second.
"""

def calculate_velocity(distance, time):
    if time == 0:
        raise ValueError("Time cannot be zero")
    return distance / time
  Attempt 1... Failed: In 'SingletonMeta': Docstring is empty or missing....
  Attempt 2... Success!
from typing import Any, Dict

class SingletonMeta(type):
    """A metaclass that ensures only one instance of a class exists."""

    _instances: Dict[Any, Any] = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cl