# Training

In [1]:
from typing import Optional
from nebu.processors.decorate import processor
from nebu import Message
from nebu.processors.models import (
    V1Scale,
    V1ScaleDown,
    V1ScaleUp,
    V1ScaleZero,
)
from pydantic import BaseModel

In [2]:
class TrainingRequest(BaseModel):
    """Parameters for one supervised fine-tuning run of a LoRA adapter."""

    # Name of the adapter; used to build the cache key and the storage URI.
    adapter_name: str
    # URL fetched with an HTTP GET and parsed as JSON Lines, one example per
    # non-blank line (presumably OpenAI chat format, given the converter name).
    dataset: str
    # Base model loaded when no previously trained adapter is found.
    model: str = "unsloth/Qwen2.5-VL-7B-Instruct"
    # Forwarded to the trainer as the maximum sequence length.
    max_length: int = 65536
    # Epochs to train in this run, added on top of epochs already trained.
    epochs: int = 1
    # Per-device train batch size.
    batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-4
    weight_decay: float = 0.01
    warmup_steps: int = 5
    logging_steps: int = 1
    # Steps between checkpoint saves (intended for the trainer's save_steps).
    save_steps: int = 5
    # LoRA hyper-parameters; applied only when a new adapter is created.
    lora_alpha: int = 128
    lora_rank: int = 64
    # Float literal so the default matches the declared type (pydantic does
    # not validate defaults, so a bare 0 would stay an int in dumps).
    lora_dropout: float = 0.0
    # Optimizer name forwarded to the trainer.
    optimizer: str = "adamw_8bit"
    # Owner recorded on the adapter; the sender's user id is used when unset.
    owner: Optional[str] = None


class TrainingResponse(BaseModel):
    """Summary of a finished training run returned by the training processor."""

    # Final training loss reported by the trainer.
    loss: float
    # Throughput and runtime figures taken from the trainer's metrics dict.
    train_steps_per_second: float
    train_samples_per_second: float
    train_runtime: float
    # Name of the adapter that was trained, echoed from the request.
    adapter_name: str
    # Storage location the adapter checkpoint is copied to.
    adapter_uri: str


# TODO: add default scale
# Autoscaling policy handed to the training processor below.
# NOTE(review): the exact meaning of "pressure" is defined by nebu — confirm.
scale = V1Scale(
    # Presumably scale to zero after 10 minutes of inactivity — confirm.
    zero=V1ScaleZero(duration="10m"),
    # Presumably scale down once pressure stays below 2 for 10 minutes.
    down=V1ScaleDown(duration="10m", below_pressure=2),
    # Presumably scale up once pressure stays above 10 for 5 minutes.
    up=V1ScaleUp(duration="5m", above_pressure=10),
)

# Shell commands passed to the training processor as its setup script.
# Each command sits on its own line; the value starts and ends with a newline.
setup_script = (
    "\n"
    "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n"
    "pip uninstall -y xformers\n"
    "pip install -U xformers --index-url https://download.pytorch.org/whl/cu126\n"
    "pip install unsloth trl\n"
)


@processor(
    image="pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel",
    setup_script=setup_script,
    scale=scale,
    accelerators=["1:A100_SXM"],
    platform="runpod",
)
def train_unsloth_sft(message: Message[TrainingRequest]) -> TrainingResponse:
    """Train (or continue training) a LoRA adapter with Unsloth SFT.

    Workflow:
      1. Look up ``adapters:<adapter_name>`` in the cache. If an adapter is
         found, check that the caller may use it, sync it to a local directory
         and load the model from there; otherwise load the requested base
         model and attach a new LoRA adapter.
      2. Download the JSONL dataset and convert every non-blank line with
         ``oai_to_unsloth``.
      3. Train with ``SFTTrainer``, upload the latest checkpoint to the
         adapter URI and record the adapter metadata in the cache.

    All imports are kept inside the function body, as in the original cell.

    Args:
        message: Message whose ``content`` is the ``TrainingRequest``.

    Returns:
        A ``TrainingResponse`` with the final loss, the trainer's throughput
        metrics and the adapter name and URI.

    Raises:
        ValueError: If the message has no content, or the caller is not
            allowed to train the existing adapter.
        requests.HTTPError: If downloading the dataset fails.
    """
    import time
    from unsloth import FastVisionModel, is_bf16_supported
    from unsloth.trainer import UnslothVisionDataCollator
    from trl import SFTTrainer, SFTConfig
    from nebu import (
        Bucket,
        ContainerConfig,
        Cache,
        Adapter,
        find_latest_checkpoint,
        is_allowed,
        oai_to_unsloth,
    )
    import requests
    import json

    # Local directory an existing adapter checkpoint is synced into.
    latest_dir = "/latest"
    # Directory the trainer writes its checkpoints to.
    output_dir = "outputs"

    print("message", message)
    training_request: TrainingRequest = message.content
    if not training_request:
        raise ValueError("No training request provided")

    container_config = ContainerConfig.from_env()
    print("container_config", container_config)

    cache = Cache()
    bucket = Bucket()

    print("loading model...")
    adapter_uri = f"{container_config.namespace_volume_uri}/adapters/{training_request.adapter_name}"
    time_start_load = time.time()
    model = None

    cache_key = f"adapters:{training_request.adapter_name}"
    print("checking cache for adapter", cache_key)
    val_raw = cache.get(cache_key)

    is_continue = False
    epochs_trained = 0
    if val_raw:
        adapter = Adapter.model_validate_json(val_raw)
        print("Found adapter: ", adapter)

        epochs_trained = adapter.epochs_trained

        if not is_allowed(adapter.owner, message.user_id, message.orgs):
            raise ValueError("You are not allowed to train this existing adapter")

        time_start = time.time()
        bucket.sync(adapter.uri, latest_dir)
        print(f"Synced in {time.time() - time_start} seconds")

        model, tokenizer = FastVisionModel.from_pretrained(
            latest_dir,
            load_in_4bit=False,
            use_gradient_checkpointing="unsloth",
        )
        is_continue = True

    # Explicit None check: truthiness of a model object is not a reliable
    # "was it loaded" signal.
    if model is None:
        print("Loading model from scratch")
        model, tokenizer = FastVisionModel.from_pretrained(
            training_request.model,
            load_in_4bit=False,  # Use 4bit to reduce memory use. False for 16bit LoRA.
            use_gradient_checkpointing="unsloth",  # True or "unsloth" for long context
        )

        print("getting peft model...")
        model = FastVisionModel.get_peft_model(
            model,
            finetune_vision_layers=True,  # False if not finetuning vision layers
            finetune_language_layers=True,  # False if not finetuning language layers
            finetune_attention_modules=True,  # False if not finetuning attention layers
            finetune_mlp_modules=True,  # False if not finetuning MLP layers
            r=training_request.lora_rank,  # The larger, the higher the accuracy, but might overfit
            lora_alpha=training_request.lora_alpha,  # Recommended alpha == r at least
            lora_dropout=training_request.lora_dropout,
            bias="none",
            random_state=3407,
            use_rslora=False,  # We support rank stabilized LoRA
            loftq_config=None,  # And LoftQ
            use_fast=True,
            # target_modules = "all-linear", # Optional now! Can specify a list if needed
        )
    print(f"Loaded model in {time.time() - time_start_load} seconds")

    print("Downloading dataset")
    time_start_download = time.time()
    response = requests.get(training_request.dataset)
    response.raise_for_status()  # raises requests.HTTPError on a failed download
    print(f"Downloaded dataset in {time.time() - time_start_download} seconds")

    # Decode and split into lines
    lines = response.content.decode("utf-8").splitlines()

    # Parse and convert each JSON line, skipping blank lines
    time_start_convert = time.time()
    converted_dataset = [
        oai_to_unsloth(json.loads(line)) for line in lines if line.strip()
    ]
    print(f"Converted dataset in {time.time() - time_start_convert} seconds")

    print(converted_dataset)

    FastVisionModel.for_training(model)  # Enable for training!

    # The epoch target is cumulative: epochs already recorded on the adapter
    # plus the epochs requested for this run.
    train_epochs = epochs_trained + training_request.epochs

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=UnslothVisionDataCollator(model, tokenizer),  # Must use!
        train_dataset=converted_dataset,
        args=SFTConfig(
            per_device_train_batch_size=training_request.batch_size,
            gradient_accumulation_steps=training_request.gradient_accumulation_steps,
            warmup_steps=training_request.warmup_steps,
            num_train_epochs=train_epochs,
            learning_rate=training_request.learning_rate,
            fp16=not is_bf16_supported(),
            bf16=is_bf16_supported(),
            logging_steps=training_request.logging_steps,
            # Honour the request's checkpoint interval. Without it the
            # trainer default (500 steps) applies and a short run may leave
            # no checkpoint to upload.
            save_steps=training_request.save_steps,
            optim=training_request.optimizer,
            weight_decay=training_request.weight_decay,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir=output_dir,
            report_to="none",  # For Weights and Biases
            # You MUST put the below items for vision finetuning:
            remove_unused_columns=False,
            dataset_text_field="",
            dataset_kwargs={"skip_prepare_dataset": True},
            dataset_num_proc=4,
            max_seq_length=training_request.max_length,
        ),
    )

    # Resume from the directory the existing adapter was synced into.
    # Passing True would make the trainer search output_dir instead, which
    # holds no checkpoint on a fresh container.
    # NOTE(review): assumes the synced directory is a trainer checkpoint,
    # since it is uploaded from find_latest_checkpoint() below — confirm.
    resume_checkpoint = latest_dir if is_continue else None

    time_start_train = time.time()
    trainer_stats = trainer.train(resume_from_checkpoint=resume_checkpoint)
    print(trainer_stats)
    print(f"Trained in {time.time() - time_start_train} seconds")

    latest_checkpoint = find_latest_checkpoint(output_dir)
    print("latest checkpoint")
    if latest_checkpoint:
        print("Copying latest checkpoint to bucket")
        bucket.copy(
            latest_checkpoint,
            adapter_uri,
        )

    # Record the adapter so later training or inference runs can find it.
    adapter = Adapter(
        name=training_request.adapter_name,
        uri=adapter_uri,
        owner=training_request.owner if training_request.owner else message.user_id,  # type: ignore
        base_model=training_request.model,
        epochs_trained=train_epochs,
        last_trained=int(time.time()),
        lora_rank=training_request.lora_rank,
        lora_alpha=training_request.lora_alpha,
        lora_dropout=training_request.lora_dropout,
    )
    cache.set(cache_key, adapter.model_dump_json())

    metrics = trainer_stats.metrics
    return TrainingResponse(
        loss=trainer_stats.training_loss,
        train_steps_per_second=metrics["train_steps_per_second"],
        train_samples_per_second=metrics["train_samples_per_second"],
        train_runtime=metrics["train_runtime"],
        adapter_name=training_request.adapter_name,
        adapter_uri=adapter_uri,
    )

[DEBUG Decorator Init] @processor decorating function 'train_unsloth_sft'
[DEBUG Decorator] Determining execution environment...
[DEBUG Helper] Checking if running in Jupyter...
[DEBUG Helper] is_jupyter_notebook: IPython class name: <class 'ipykernel.zmqshell.ZMQInteractiveShell'>
[DEBUG Helper] is_jupyter_notebook: Jupyter detected (ZMQInteractiveShell).
[DEBUG Decorator] Jupyter environment detected.
[DEBUG Helper] Attempting to get notebook execution history...
[DEBUG Helper] get_notebook_executed_code: Retrieved 2 history entries.
[DEBUG Helper] get_notebook_executed_code: Total history source length: 7889
[DEBUG Decorator] Retrieved notebook history (length: 7889).
[DEBUG Decorator] No manually included objects specified.
[DEBUG Decorator] Validating signature and type hints for train_unsloth_sft...
[DEBUG Decorator] Raw type hints: {'message': <class 'nebu.processors.models.Message[TrainingRequest]'>, 'return': <class '__main__.TrainingResponse'>}
[DEBUG Decorator] Parameter 'me

In [6]:
# Example request: train adapter "foo1" on a hosted JSONL dataset, leaving
# every other hyper-parameter at its TrainingRequest default.
training_req = TrainingRequest(
    adapter_name="foo1",
    dataset="https://storage.googleapis.com/orign/testdata/nebu/clinton.jsonl",
)

In [7]:
# Send the request to the processor as a plain dict; the recorded output
# below shows a dict with 'success', 'stream_id' and 'message_id'.
train_unsloth_sft.send(training_req.model_dump())

{'success': True,
 'stream_id': '1744395616208-0',
 'message_id': '12r7LQkPqTy6h3LYcJgKWd'}

In [8]:
# NOTE(review): presumably tears down the deployed processor — confirm
# against the nebu documentation before running.
train_unsloth_sft.delete()

# Inference

In [5]:
from typing import List
from nebu.processors.decorate import processor
from nebu import Message
from nebu.chatx.openai import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionChoice,
    ChatCompletionResponseMessage,
    Logprobs,
)

In [None]:
# Shell commands passed to the inference processor as its setup script.
# qwen-vl-utils is installed only by the last command: the first command is
# restricted to the PyTorch wheel index via --index-url, which does not host
# that package, so listing it there made the command fail.
setup_script = """
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip uninstall -y xformers
pip install -U xformers --index-url https://download.pytorch.org/whl/cu126
pip install tiktoken unsloth qwen-vl-utils transformers
"""


@processor(
    image="pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel",
    setup_script=setup_script,
    accelerators=["1:A100_SXM"],
    platform="runpod",
)
def infer_qwen_vl(
    message: Message[ChatCompletionRequest],
) -> ChatCompletionResponse:
    """Answer a chat completion request with Qwen2.5-VL plus a LoRA adapter.

    The request's ``model`` field names the adapter. The base model, its
    processor, the cache client and the list of already loaded adapters are
    stored as attributes on this function so that later calls in the same
    process reuse them.

    All imports are kept inside the function body, as in the original cell.

    Args:
        message: Message whose ``content`` is the ``ChatCompletionRequest``.

    Returns:
        A ``ChatCompletionResponse`` with a single assistant choice.

    Raises:
        ValueError: If the message has no content, the adapter is not in the
            cache, the caller may not use it, or it was trained on a
            different base model.
    """
    import time

    full_time = time.time()

    import uuid
    from unsloth import FastVisionModel
    from qwen_vl_utils import process_vision_info
    from nebu import (
        Bucket,
        ContainerConfig,
        Adapter,
        Cache,
        is_allowed,
    )
    from transformers import AutoProcessor

    base_model_id = "unsloth/Qwen2.5-VL-7B-Instruct"

    # Check the attribute that is actually assigned below; checking a
    # different name would reload the model on every call.
    if not hasattr(infer_qwen_vl, "base_model"):
        print("loading model...")
        time_start_load = time.time()
        # from_pretrained returns a (model, tokenizer) pair; only the model is
        # needed because the processor is loaded separately below.
        base_model, _ = FastVisionModel.from_pretrained(
            base_model_id,
            load_in_4bit=False,
        )
        print(f"Loaded model in {time.time() - time_start_load} seconds")
        model_processor = AutoProcessor.from_pretrained(base_model_id)
        FastVisionModel.for_inference(base_model)

        # we can just store our model on the function
        infer_qwen_vl.base_model = base_model  # type: ignore
        infer_qwen_vl.model_processor = model_processor  # type: ignore
    else:
        base_model = infer_qwen_vl.base_model  # type: ignore
        model_processor = infer_qwen_vl.model_processor  # type: ignore

    if not hasattr(infer_qwen_vl, "cache"):
        print("Creating cache")
        cache = Cache()
        infer_qwen_vl.cache = cache  # type: ignore
    else:
        cache = infer_qwen_vl.cache  # type: ignore

    if not hasattr(infer_qwen_vl, "adapters"):
        print("Creating adapters")
        adapters: List[Adapter] = []
        infer_qwen_vl.adapters = adapters  # type: ignore
    else:
        adapters = infer_qwen_vl.adapters  # type: ignore

    print("message", message)
    content = message.content
    if not content:
        raise ValueError("No content provided")
    print("content", content)

    container_config = ContainerConfig.from_env()
    print("container_config", container_config)

    adapter_hot_start = time.time()
    val_raw = cache.get(f"adapters:{content.model}")
    if not val_raw:
        raise ValueError(f"Adapter '{content.model}' not found")

    print("val_raw", val_raw)
    val = Adapter.model_validate_json(val_raw)

    if not is_allowed(val.owner, message.user_id, message.orgs):
        raise ValueError("You are not allowed to use this adapter")

    if val.base_model != base_model_id:
        raise ValueError(
            "The base model of the adapter does not match the model you are trying to use"
        )

    # An adapter counts as loaded only if an entry with the same name and the
    # same creation timestamp is already tracked.
    loaded = any(
        adapter.name == val.name and adapter.created_at == val.created_at
        for adapter in adapters
    )
    if loaded:
        print("adapter already loaded", content.model)
    print(f"Adapter hot start: {time.time() - adapter_hot_start} seconds")

    if not loaded:
        bucket = Bucket()
        adapter_dir = f"./adapters/{content.model}"
        print("copying adapter", val.uri, adapter_dir)
        time_start = time.time()
        bucket.copy(val.uri, adapter_dir)
        print(f"Copied in {time.time() - time_start} seconds")

        print("loading adapter", content.model)
        # NOTE(review): if an older version of this adapter is already
        # registered under the same name, load_adapter may refuse to overwrite
        # it — confirm whether it must be removed from the model first.
        base_model.load_adapter(adapter_dir, adapter_name=content.model)
        # Track the Adapter record itself (not just its name) so the
        # created_at comparison above works, replacing any older entry with
        # the same name. Slice assignment keeps the shared list object.
        adapters[:] = [known for known in adapters if known.name != val.name]
        adapters.append(val)
        print("loaded adapter", content.model)

    base_model.set_adapter(content.model)

    content_dict = content.model_dump()
    messages = content_dict["messages"]

    # Preparation for inference
    print("preparing inputs")
    inputs_start = time.time()
    text = model_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = model_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    print("inputs", inputs)
    print(f"Inputs prepared in {time.time() - inputs_start} seconds")

    # Inference: Generation of the output. The timer starts before generate()
    # so the reported duration covers generation, not only decoding.
    generation_start = time.time()
    generated_ids = base_model.generate(**inputs, max_new_tokens=content.max_tokens)
    # Drop the prompt tokens from each generated sequence.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = model_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    print("output_text", output_text)
    print(f"Generation took {time.time() - generation_start} seconds")

    # Build the Pydantic model, referencing your enumerations and classes
    response = ChatCompletionResponse(
        id=str(uuid.uuid4()),
        created=int(time.time()),
        model=content.model,
        object="chat.completion",
        choices=[
            ChatCompletionChoice(
                index=0,
                finish_reason="stop",  # or another appropriate reason
                message=ChatCompletionResponseMessage(
                    role="assistant", content=output_text[0]
                ),
                # Stub logprobs; in reality, you'd fill these from your model if you have them
                logprobs=Logprobs(content=[]),
            )
        ],
        service_tier=None,
        system_fingerprint=None,
        usage=None,
    )
    print(f"Total time: {time.time() - full_time} seconds")

    return response