# SFT of Hisoka-1B on Swissprot QA

Author: [Khairi Abidi](https://github.com/abidikhairi)

This notebook demonstrates supervised fine-tuning (SFT) on a protein question-answering dataset.

Key Features:

- **Memory efficient**: 4-bit loading with LoRA adapters, so training fits on consumer GPUs.
- **SFT**: supervised fine-tuning on instruction/response pairs.

The model learns to generate functional protein sequences from textual design requirements.

## Installation and Setup
Install the required packages for supervised fine-tuning with memory-efficient techniques, and set the W&B project name used for experiment tracking.

In [2]:
%env WANDB_PROJECT=Unsloth-SFT

env: WANDB_PROJECT=Unsloth-SFT
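`%env` is an IPython magic. Outside a notebook (for example in a standalone training script), the same effect can be had by writing the variable into `os.environ` before training starts. A minimal sketch:

```python
import os

# W&B reads WANDB_PROJECT when a run is created, so set it before the trainer starts.
os.environ["WANDB_PROJECT"] = "Unsloth-SFT"
print(os.environ["WANDB_PROJECT"])
```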


## Connect to 3rd-party services

- **WandB**: for experiment tracking.
- **Hugging Face Hub**: for uploading model checkpoints.

In [3]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HUGGING_FACE_TOKEN")
wandb_token = user_secrets.get_secret("WANDB_API_KEY")

In [4]:
!wandb login {wandb_token}
!huggingface-cli login --token {hf_token}

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `hf`CLI if you want to set the git credential as well.
Token is valid (permission: write).
The token `KAGGLE_TOKEN` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `KAGGLE_TOKEN`


## GPU Environment Detection
Verify GPU availability and display hardware specifications for optimal training configuration.

In [5]:
import torch

# Verify CUDA availability and display GPU specifications
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

if torch.cuda.is_available():
    # Display current GPU details for training optimization
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name()}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    # Provide guidance for enabling GPU in Colab
    print("⚠️  No GPU available. This notebook requires a GPU for efficient training.")
    print("In Colab: Runtime → Change runtime type → Hardware accelerator → GPU")

CUDA available: True
Number of GPUs: 2
Current GPU: 0
GPU name: Tesla T4
GPU memory: 15.8 GB


## Core Library Imports
Import the libraries used for model loading, fine-tuning, and dataset handling.

In [52]:
# Model and tokenization
from unsloth import FastLanguageModel

# Training and Setup
from unsloth import (
    is_bfloat16_supported
)
from trl import (
    SFTTrainer,
    SFTConfig
)

# Dataset handling
from datasets import load_dataset

# Utils
import ast
import re

In [12]:
model_name = 'khairi/Hisoka-1B'
max_seq_len = 1024
dtype = torch.float16
load_in_4bit = True

print(f'Loading model: {model_name}')
print(f'Max input length: {max_seq_len}')
print(f'Model dtype: {dtype}')
print(f'Load in 4-bit: {load_in_4bit}')

Loading model: khairi/Hisoka-1B
Max input length: 1024
Model dtype: torch.float16
Load in 4-bit: True


In [13]:
# Load model with automatic device mapping
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_len,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

# Ensure tokenizer has proper padding token for batch processing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

==((====))==  Unsloth 2025.9.9: Fast Qwen3 patching. Transformers: 4.56.2.
   \\   /|    Tesla T4. Num GPUs = 2. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
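The banner reports compute capability 7.5 (`CUDA: 7.5`) and `Bfloat16 = FALSE`, which is why `dtype` is set to `torch.float16` above. Native bfloat16 support starts at compute capability 8.0 (Ampere). The helper below is a hypothetical, simplified sketch of that rule, not the exact check Unsloth or PyTorch performs:

```python
def pick_dtype_name(major: int, minor: int) -> str:
    """Hypothetical helper: choose a half-precision dtype from a GPU's compute capability.

    Simplified rule: bfloat16 needs compute capability >= 8.0 (Ampere or newer);
    older GPUs such as the Tesla T4 (7.5) fall back to float16.
    """
    return "bfloat16" if (major, minor) >= (8, 0) else "float16"

print(pick_dtype_name(7, 5))  # Tesla T4
print(pick_dtype_name(8, 0))  # A100
```

On a live GPU, `torch.cuda.get_device_capability()` returns the `(major, minor)` pair to pass in.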


In [14]:
print(f"✅ Model loaded successfully!")
print(f"📊 Model parameters: ~{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
print(f"🧮 Quantized parameters: ~{sum(p.numel() for p in model.parameters() if hasattr(p, 'quant_type')) / 1e6:.1f}M")
model.print_trainable_parameters()

✅ Model loaded successfully!
📊 Model parameters: ~1796.6M
🧮 Quantized parameters: ~685.8M
trainable params: 139,460,608 || all params: 2,482,365,440 || trainable%: 5.6181
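The two totals differ because bitsandbytes packs 4-bit weights two per byte, so `numel()` on a quantized parameter reports half its logical size, while `print_trainable_parameters()` compensates for the packing. A back-of-the-envelope check, assuming two 4-bit values per stored `uint8` element, reconciles the printed figures:

```python
# Figures printed above, in millions of numel() elements.
counted_total = 1796.6      # sum of p.numel() over all parameters
counted_quantized = 685.8   # numel() of 4-bit parameters (packed storage)

# Each packed uint8 element holds two 4-bit weights, so double the quantized count.
logical_total = (counted_total - counted_quantized) + 2 * counted_quantized
print(f"{logical_total:.1f}M")  # → 2482.4M, vs 'all params: 2,482,365,440'
```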


In [15]:
def compute_model_size(model):
    """Return the memory held by the model's parameters and buffers, in GiB (1024**3 bytes)."""
    n_bytes = 0
    for p in model.parameters():
        n_bytes += p.nelement() * p.element_size()
    for b in model.buffers():
        n_bytes += b.nelement() * b.element_size()

    return n_bytes / (1024 ** 3)

print(f"📊 Model size : {compute_model_size(model):.2f} GB")

📊 Model size : 2.97 GB
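For intuition, weight storage scales with bits per parameter. The sketch below estimates weight-only memory for the ~2.48B logical parameters reported above at uniform precision. This is a simplifying assumption: it ignores quantization constants, buffers, and activations, and the real model mixes 4-bit and 16-bit weights, which is why the measured 2.97 GB falls between the 4-bit and 16-bit estimates.

```python
def weight_memory_gib(n_params: float, bits_per_param: int) -> float:
    """Rough weight-only memory in GiB: params * bits / 8 bytes, divided by 1024**3."""
    return n_params * bits_per_param / 8 / 1024**3

n_params = 2_482_365_440  # 'all params' from print_trainable_parameters()
for bits in (32, 16, 4):
    print(f"{bits:>2}-bit: {weight_memory_gib(n_params, bits):.2f} GiB")
```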


## Dataset Setup

🧬 Protein-oriented Instructions (Mol-Instructions): this component contains 505K instructions across five task categories, covering the prediction of protein structure, function, and activity as well as protein design from textual directives. This notebook uses only the `protein_design` subset.

In [28]:
protein_start = '<protein_start>'
protein_end = '<protein_end>'
eos_token = tokenizer.eos_token

print("✅ Protein start token: {}".format(protein_start))
print("✅ Protein end token: {}".format(protein_end))
print("✅ Protein eos token: {}".format(eos_token))

✅ Protein start token: <protein_start>
✅ Protein end token: <protein_end>
✅ Protein eos token: <|im_end|>


In [53]:
def filter_dataset_example(example):
    return ast.literal_eval(example['metadata'])['seq_len'] <= 256

print("✅ Dataset filtering function defined")

✅ Dataset filtering function defined
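The `metadata` column holds a string, so `ast.literal_eval` parses it into a dict (it accepts only Python literals, which makes it safer than `eval`). The toy rows below are hypothetical and contain only the `seq_len` key the filter reads:

```python
import ast

def filter_dataset_example(example, max_len=256):
    # `metadata` is a string containing a Python dict literal; parse it safely.
    return ast.literal_eval(example['metadata'])['seq_len'] <= max_len

# Hypothetical rows shaped like the dataset (only the field used by the filter).
short_row = {'metadata': "{'seq_len': 180}"}
long_row = {'metadata': "{'seq_len': 512}"}
print(filter_dataset_example(short_row), filter_dataset_example(long_row))
```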


In [71]:
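def extract_protein(text):
    """Return the first tagged protein sequence, with residues space-separated and re-wrapped in tags."""
    pattern = rf"{re.escape(protein_start)}\s*(.*?)\s*{re.escape(protein_end)}"
    match = re.search(pattern, text, flags=re.DOTALL)
    if match is None:
        raise ValueError(f"No {protein_start}...{protein_end} span found in text")
    # Drop any internal whitespace/newlines, then put one space between residues
    seq = ' '.join(''.join(match.group(1).split()))

    return f'{protein_start} {seq} {protein_end}'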

print("✅ Protein extraction function defined")

✅ Protein extraction function defined
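A quick check on a made-up answer shows the residue-level spacing, presumably intended so the tokenizer sees roughly one amino acid per token. The sequence `MKTAYIAK` is invented for illustration:

```python
import re

protein_start = '<protein_start>'
protein_end = '<protein_end>'

def extract_protein(text):
    match = re.search(r"<protein_start>\s*(.*?)\s*<protein_end>", text, flags=re.DOTALL)
    if match is None:
        raise ValueError("no <protein_start>...<protein_end> span found")
    # Space-separate every residue of the tagged sequence.
    seq = ' '.join(''.join(match.group(1).split()))
    return f'{protein_start} {seq} {protein_end}'

# Hypothetical answer text.
answer = "The designed sequence is\n<protein_start> MKTAYIAK <protein_end>"
print(extract_protein(answer))  # → <protein_start> M K T A Y I A K <protein_end>
```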


In [76]:
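def format_dataset_example(example):
    """Convert a Mol-Instructions entry into a one-round conversation (ChatML-style messages)."""
    instruction = example['instruction']
    inputs = example['input']
    output = example['output']

    user_input = f'{instruction}\n{inputs}'
    # Replace the Markdown code fence around the sequence with protein tags
    output = output.replace('```\n', f'{protein_start} ')
    output = output.replace('\n```', f' {protein_end}')
    # Re-insert the residue-spaced sequence; a lambda avoids backslash escapes in the replacement
    seq = extract_protein(output)
    pattern = rf"{re.escape(protein_start)}.*?{re.escape(protein_end)}"
    output = re.sub(pattern, lambda _: seq, output, flags=re.DOTALL)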

    # one-round conversation
    messages = [
        {'role': "user", 'content': user_input},
        {'role': "assistant", 'content': output}
    ]
    
    return {
        "messages": messages,
    }

print("✅ Dataset formatting functions defined")

✅ Dataset formatting functions defined
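To see the whole transformation end to end, here is a self-contained sketch applied to a hypothetical record. It uses the field names the notebook reads (`instruction`, `input`, `output`), but the text and sequence are invented:

```python
import re

protein_start, protein_end = '<protein_start>', '<protein_end>'
SPAN = re.compile(r"<protein_start>\s*(.*?)\s*<protein_end>", re.DOTALL)

def format_dataset_example(example):
    output = example['output']
    # Markdown code fences around the sequence become protein tags.
    output = output.replace('```\n', f'{protein_start} ').replace('\n```', f' {protein_end}')
    # Space-separate residues inside the tags (assumes one tagged sequence is present).
    seq = ' '.join(''.join(SPAN.search(output).group(1).split()))
    output = SPAN.sub(lambda _: f'{protein_start} {seq} {protein_end}', output)
    return {"messages": [
        {'role': 'user', 'content': f"{example['instruction']}\n{example['input']}"},
        {'role': 'assistant', 'content': output},
    ]}

# Hypothetical record.
record = {
    'instruction': 'Design a protein sequence that fulfils the following requirements.',
    'input': '1. The protein should bind Mg(2+).',
    'output': 'Here is a candidate sequence:\n```\nMKTAYIAK\n```',
}
print(format_dataset_example(record)['messages'][1]['content'])
```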


In [98]:
# Load and preprocess the Mol-Instructions protein design subset
print("🔄 Loading Protein Question/Answer dataset...")
dataset = load_dataset("zjunlp/Mol-Instructions", "Protein-oriented Instructions", trust_remote_code=True)

dataset = dataset['protein_design']

# Keep sequences of at most 256 residues, then convert each entry to chat messages
dataset = dataset.filter(filter_dataset_example) \
    .map(format_dataset_example) \
    .select_columns('messages')

# Split dataset into train/test
dataset = dataset.train_test_split(test_size=128)

train_data = dataset['train']
valid_data = dataset['test']  # 128 held-out proteins for evaluation

print(f"✅ Dataset loaded and processed!")
print(f"📊 Training examples: {len(train_data):,}")
print(f"📊 Validation examples: {len(valid_data):,}")
print(f"🎯 Sample protein: {train_data[0]['messages']}")
print(f"🎯 Sample protein (chat template): {tokenizer.apply_chat_template(train_data[0]['messages'], tokenize=False)}")

🔄 Loading Protein Question/Answer dataset...
✅ Dataset loaded and processed!
📊 Training examples: 57,111
📊 Validation examples: 128
🎯 Sample protein: [{'content': 'Generate a protein sequence that meets the functional requirements while minimizing unwanted side effects.\n1. The stability of the protein-Mg(2+) complex should be optimized to maximize enzymatic activity.\n2. For general function, the protein need meet that Seems to function as a house-cleaning enzyme that removes non-canonical purine nucleotides from the nucleotide pool, thus preventing their incorporation into DNA/RNA and avoiding chromosomal lesions.; Pyrophosphatase that catalyzes the hydrolysis of nucleoside triphosphates to their monophosphate derivatives, with a high preference for the non-canonical purine nucleotides XTP (xanthosine triphosphate), dITP (deoxyinosine triphosphate) and ITP\n3. The protein with the reaction dITP + H2O = dIMP + diphosphate + H(+), the Proton acceptor must be optimized to ensure efficie

## Training Setup
Configure training parameters optimized for instruction following with memory constraints.

In [103]:
training_args = SFTConfig(
    assistant_only_loss=True,  # loss on assistant tokens only; the chat template must mark them with {% generation %}
    
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 8,

    # Use warmup_ratio and num_train_epochs for longer runs!
    max_steps = 120,
    # warmup_steps = 10,
    warmup_ratio = 0.1,
    # num_train_epochs = 1,

    learning_rate = 6e-5,

    fp16 = not is_bfloat16_supported(),
    bf16 = is_bfloat16_supported(),
    logging_steps = 20,
    eval_steps = 20,
    save_steps = 20,
    eval_strategy = 'steps',
    save_total_limit = 3,
    load_best_model_at_end = True,
    optim = "adamw_8bit",
    weight_decay = 0.01,
    lr_scheduler_type = "cosine",
    
    output_dir = "/tmp/outputs",
    run_name = 'hisoka-1b-sft-mol-instruct',
    report_to = "wandb", # Log to the W&B project set via WANDB_PROJECT; use "none" to disable

    # Push checkpoints to the Hub; set to False for local experiments
    push_to_hub=True,
    hub_model_id='khairi/Hisoka-1B-Instruct'
)
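The schedule above implies a few derived quantities. The sketch computes them, assuming a single training GPU (`num_devices = 1` is an assumption) and the Transformers convention that warmup steps are `ceil(max_steps * warmup_ratio)` when `warmup_ratio` is used:

```python
import math

per_device_train_batch_size = 2
gradient_accumulation_steps = 8
max_steps = 120
warmup_ratio = 0.1
num_devices = 1          # assumption: training runs on a single GPU
train_examples = 57_111  # 'Training examples' printed above

# Examples contributing to each optimizer step.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_devices
warmup_steps = math.ceil(max_steps * warmup_ratio)
examples_seen = effective_batch * max_steps
print(effective_batch, warmup_steps, examples_seen, f"{examples_seen / train_examples:.3f} epochs")
```

With these settings the 120-step run covers only a small fraction of one epoch over the full training split, which fits its role as a short trial run.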

In [111]:
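# The dataset already stores chat turns in a "messages" column, so SFTTrainer
# applies the tokenizer's chat template itself and no formatting_func is needed.
# A formatting_func would receive the whole example dict rather than the message list.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_data.select(range(5)),  # small debug subset; pass train_data for a full run
    eval_dataset=valid_data,
    args=training_args,
)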
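`assistant_only_loss=True` restricts the loss to assistant tokens. Conceptually, every token outside an assistant turn gets the label `-100`, the default `ignore_index` of PyTorch's cross-entropy loss. The sketch below shows that masking on made-up token ids. In TRL the mask is derived from the chat template, which this toy example does not model, and `mask_non_assistant` is a hypothetical helper:

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_non_assistant(input_ids, assistant_mask):
    """Hypothetical helper: copy input_ids into labels, ignoring tokens outside assistant turns."""
    if len(input_ids) != len(assistant_mask):
        raise ValueError("input_ids and assistant_mask must have the same length")
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, assistant_mask)]

# Made-up ids: 4 prompt tokens followed by 3 assistant tokens.
input_ids = [101, 7, 8, 9, 42, 43, 44]
assistant_mask = [0, 0, 0, 0, 1, 1, 1]
print(mask_non_assistant(input_ids, assistant_mask))  # → [-100, -100, -100, -100, 42, 43, 44]
```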

In [112]:
SFTTrainer??

[0;31mInit signature:[0m
[0mSFTTrainer[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mmodel[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0margs[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdata_collator[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtrain_dataset[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0meval_dataset[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mprocessing_class[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcompute_loss_func[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcompute_metrics[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcallbacks[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0moptimizer_cls_and_kwargs[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpreprocess_logits_for_metrics[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0