In [1]:
# =============================================================================
# CELL 1: Environment Setup and Dependencies
# Install required packages and set up environment variables for memory optimization
# =============================================================================

import os
import torch
import json
from datasets import Dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    BitsAndBytesConfig,
    TrainingArguments
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
import gc

# Memory optimization environment variables
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("üöÄ APOLLO-2B MEDICAL FINE-TUNING PIPELINE")
print("="*50)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

üöÄ APOLLO-2B MEDICAL FINE-TUNING PIPELINE
PyTorch version: 2.5.0+cu118
CUDA available: True
GPU: NVIDIA GeForce RTX 4070
GPU Memory: 12.9 GB


In [2]:
# =============================================================================
# CELL 2: Load and Validate Processed Medical Q&A Dataset
# Load the structured medical Q&A data and validate format
# =============================================================================

def load_medical_qa_dataset(file_path="medical_qa_training_data.jsonl", max_samples=10000):
    """Load and validate medical Q&A dataset"""
    
    print("üìä Loading Medical Q&A Dataset...")
    
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= max_samples:  # Limit for 6GB VRAM
                break
            record = json.loads(line.strip())
            data.append(record)
    
    print(f"‚úÖ Loaded {len(data)} Q&A records")
    return data

# Load dataset
qa_data = load_medical_qa_dataset(max_samples=25000)  # Adjust based on your VRAM

# Validation: Check dataset structure
print("\nüîç DATASET VALIDATION:")
print("="*30)
sample = qa_data[0]
required_fields = ['question', 'answer', 'training_text']
for field in required_fields:
    status = "‚úÖ" if field in sample else "‚ùå"
    print(f"{status} {field}: {'Present' if field in sample else 'Missing'}")

print(f"\nüìù SAMPLE Q&A PAIR:")
print(f"Question: {sample['question'][:100]}...")
print(f"Answer: {sample['answer'][:200]}...")
print(f"Training format length: {len(sample['training_text'])} characters")

print(f"\nüìà DATASET STATISTICS:")
avg_question_len = sum(len(item['question']) for item in qa_data) / len(qa_data)
avg_answer_len = sum(len(item['answer']) for item in qa_data) / len(qa_data)
print(f"Average question length: {avg_question_len:.0f} characters")
print(f"Average answer length: {avg_answer_len:.0f} characters")
print(f"Total training samples: {len(qa_data)}")

üìä Loading Medical Q&A Dataset...
‚úÖ Loaded 25000 Q&A records

üîç DATASET VALIDATION:
‚úÖ question: Present
‚úÖ answer: Present
‚úÖ training_text: Present

üìù SAMPLE Q&A PAIR:
Question: What is A simplified scleral reinforcement technique. and what should I know about it?...
Answer: Explanation: A simplified scleral reinforcement technique. A simplified scleral reinforcement technique performed on 52 eyes with myopic degeneration prevented further visual loss by strengthening of ...
Training format length: 928 characters

üìà DATASET STATISTICS:
Average question length: 128 characters
Average answer length: 1095 characters
Total training samples: 25000


In [3]:
# =============================================================================
# CELL 3: Convert to HuggingFace Dataset Format
# Prepare the Q&A data in the format required for SFTTrainer
# =============================================================================

def prepare_training_dataset(qa_data):
    """Convert Q&A data to HuggingFace Dataset format"""
    
    print("üîÑ Preparing Training Dataset...")
    
    # Extract training texts for the model
    training_data = []
    for item in qa_data:
        training_data.append({
            "text": item["training_text"],
            "question": item["question"],
            "answer": item["answer"]
        })
    
    # Create HuggingFace Dataset
    dataset = Dataset.from_list(training_data)
    
    print(f"‚úÖ Created training dataset with {len(dataset)} samples")
    return dataset

# Prepare dataset
train_dataset = prepare_training_dataset(qa_data)

# Validation: Check dataset format
print("\nüîç TRAINING DATASET VALIDATION:")
print("="*35)
print(f"Dataset size: {len(train_dataset)}")
print(f"Dataset features: {train_dataset.features}")

# Show sample training text
sample_text = train_dataset[0]['text']
print(f"\nüìù SAMPLE TRAINING TEXT:")
print(sample_text[:300] + "..." if len(sample_text) > 300 else sample_text)

# Check text lengths for sequence optimization
text_lengths = [len(item['text']) for item in train_dataset]
print(f"\nüìä TEXT LENGTH STATISTICS:")
print(f"Min length: {min(text_lengths)}")
print(f"Max length: {max(text_lengths)}")
print(f"Average length: {sum(text_lengths)/len(text_lengths):.0f}")

üîÑ Preparing Training Dataset...
‚úÖ Created training dataset with 25000 samples

üîç TRAINING DATASET VALIDATION:
Dataset size: 25000
Dataset features: {'text': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'answer': Value(dtype='string', id=None)}

üìù SAMPLE TRAINING TEXT:
Question: What is A simplified scleral reinforcement technique. and what should I know about it?
Answer: Explanation: A simplified scleral reinforcement technique. A simplified scleral reinforcement technique performed on 52 eyes with myopic degeneration prevented further visual loss by strengthenin...

üìä TEXT LENGTH STATISTICS:
Min length: 516
Max length: 3599
Average length: 1242


In [4]:
# =============================================================================
# CELL 4: Load Apollo-2B Model with 4-bit Quantization
# Setup the model for efficient training on 6GB VRAM
# =============================================================================

def setup_apollo_model():
    """Load Apollo-2B with QLoRA configuration for 6GB VRAM"""
    
    print("ü§ñ Setting up Apollo-2B Model...")
    
    # Quantization config for memory efficiency
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_storage=torch.uint8
    )
    
    print("Loading model from HuggingFace...")
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        "FreedomIntelligence/Apollo-2B",
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        "FreedomIntelligence/Apollo-2B",
        trust_remote_code=True
    )
    
    # Setup tokenizer
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    
    print("‚úÖ Model and tokenizer loaded successfully")
    return model, tokenizer

# Load model and tokenizer
model, tokenizer = setup_apollo_model()

# Validation: Check model loading
print("\nüîç MODEL VALIDATION:")
print("="*25)
print(f"Model type: {type(model).__name__}")
print(f"Model device: {next(model.parameters()).device}")
print(f"Model dtype: {next(model.parameters()).dtype}")
print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
print(f"Pad token: {tokenizer.pad_token}")

# Check GPU memory usage
if torch.cuda.is_available():
    memory_allocated = torch.cuda.memory_allocated() / 1e9
    print(f"GPU memory allocated: {memory_allocated:.2f} GB")

ü§ñ Setting up Apollo-2B Model...
Loading model from HuggingFace...


config.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/134M [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.91G [00:00<?, ?B/s]

Error while downloading from https://huggingface.co/FreedomIntelligence/Apollo-2B/resolve/4049fc07e95649c91e86be4458faf935bdc106b8/model-00002-of-00003.safetensors: HTTPSConnectionPool(host='cas-bridge.xethub.hf.co', port=443): Read timed out.
Trying to resume download...
Error while downloading from https://huggingface.co/FreedomIntelligence/Apollo-2B/resolve/4049fc07e95649c91e86be4458faf935bdc106b8/model-00001-of-00003.safetensors: HTTPSConnectionPool(host='cas-bridge.xethub.hf.co', port=443): Read timed out.
Trying to resume download...


model-00001-of-00003.safetensors:   4%|3         | 189M/4.91G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   5%|4         | 231M/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/555 [00:00<?, ?B/s]

‚úÖ Model and tokenizer loaded successfully

üîç MODEL VALIDATION:
Model type: GemmaForCausalLM
Model device: cuda:0
Model dtype: torch.float16
Tokenizer vocab size: 256000
Pad token: <eos>
GPU memory allocated: 2.07 GB


In [5]:
# =============================================================================
# CELL 5: Setup LoRA Configuration
# Configure parameter-efficient fine-tuning to reduce memory usage
# =============================================================================

def setup_lora(model):
    """Configure LoRA for parameter-efficient fine-tuning"""
    
    print("üîß Setting up LoRA Configuration...")
    
    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)
    
    # LoRA configuration optimized for 6GB VRAM
    lora_config = LoraConfig(
        r=16,                    # Rank - balance between performance and memory
        lora_alpha=32,          # Alpha parameter
        target_modules=[        # Target modules for LoRA
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    
    # Apply LoRA to model
    model = get_peft_model(model, lora_config)
    
    print("‚úÖ LoRA configuration applied")
    return model

# Apply LoRA configuration
model = setup_lora(model)

# Validation: Check LoRA parameters
print("\nüîç LORA VALIDATION:")
print("="*22)
total_params = model.num_parameters()
trainable_params = model.num_parameters(only_trainable=True)
trainable_percentage = (trainable_params / total_params) * 100

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Trainable percentage: {trainable_percentage:.2f}%")
print(f"Memory reduction: ~{100 - trainable_percentage:.1f}%")

# Print trainable modules
model.print_trainable_parameters()

üîß Setting up LoRA Configuration...
‚úÖ LoRA configuration applied

üîç LORA VALIDATION:
Total parameters: 2,525,784,064
Trainable parameters: 19,611,648
Trainable percentage: 0.78%
Memory reduction: ~99.2%
trainable params: 19,611,648 || all params: 2,525,784,064 || trainable%: 0.7765


In [6]:
# =============================================================================
# CELL 6: Training Configuration
# Setup training parameters optimized for 6GB VRAM and medical data
# =============================================================================

def setup_training_args(output_dir="./apollo2b-medical-qa"):
    """Configure training arguments for medical fine-tuning"""
    
    print("‚öôÔ∏è Configuring Training Parameters...")
    
    training_args = TrainingArguments(
        output_dir=output_dir,
        
        # Memory optimization for 6GB VRAM
        per_device_train_batch_size=2,      # Small batch size
        gradient_accumulation_steps=8,       # Effective batch size = 8
        gradient_checkpointing=True,         # Trade compute for memory
        dataloader_pin_memory=False,         # Reduce memory transfer
        
        # Training parameters
        num_train_epochs=2,                  # Number of epochs
        learning_rate=1e-4,                  # Learning rate
        warmup_steps=100,                    # Warmup steps
        logging_steps=25,                    # Log every 25 steps
        save_steps=500,                      # Save every 500 steps
        eval_steps=500,                      # Evaluation frequency
        
        # Optimizer settings
        optim="paged_adamw_8bit",           # Memory-efficient optimizer
        lr_scheduler_type="cosine",         # Learning rate scheduler
        weight_decay=0.01,                  # Weight decay
        max_grad_norm=1.0,                  # Gradient clipping
        
        # Precision and efficiency
        fp16=True,                          # Mixed precision training
        remove_unused_columns=False,
        
        # Monitoring and saving
        report_to=None,                     # Disable wandb for now
        save_total_limit=3,                 # Keep only 3 checkpoints
        load_best_model_at_end=False,
    )
    
    print("‚úÖ Training arguments configured")
    return training_args

# Setup training arguments
training_args = setup_training_args()

# Validation: Check training configuration
print("\nüîç TRAINING CONFIGURATION:")
print("="*30)
print(f"Batch size: {training_args.per_device_train_batch_size}")
print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Learning rate: {training_args.learning_rate}")
print(f"Number of epochs: {training_args.num_train_epochs}")
print(f"Total training steps: {(len(train_dataset) // training_args.per_device_train_batch_size // training_args.gradient_accumulation_steps) * training_args.num_train_epochs}")
print(f"Optimizer: {training_args.optim}")
print(f"Mixed precision: {training_args.fp16}")

‚öôÔ∏è Configuring Training Parameters...
‚úÖ Training arguments configured

üîç TRAINING CONFIGURATION:
Batch size: 2
Gradient accumulation: 8
Effective batch size: 16
Learning rate: 0.0001
Number of epochs: 2
Total training steps: 3124
Optimizer: OptimizerNames.PAGED_ADAMW_8BIT
Mixed precision: True


In [7]:
# =============================================================================
# CELL 7 (Fix): Initialize transformers.Trainer with labels for causal LM
# =============================================================================

from transformers import Trainer, default_data_collator

def tokenize_and_add_labels(examples):
    """Tokenize and add labels for causal LM"""
    # Tokenize text
    outputs = tokenizer(
        examples["text"],
        truncation=True,
        padding="longest",
        max_length=512
    )
    # For causal LM, labels are identical to input_ids
    outputs["labels"] = outputs["input_ids"].copy()
    return outputs

def start_medical_training_v2(model, tokenizer, train_dataset, training_args):
    """Initialize Trainer for causal LM and start fine-tuning"""
    
    print("üèãÔ∏è Initializing Trainer with labels for causal LM...")
    
    # Tokenize and add labels
    tokenized_ds = train_dataset.map(
        tokenize_and_add_labels,
        batched=True,
        remove_columns=["text", "question", "answer"]
    )
    print(f"‚úÖ Tokenized dataset with columns: {tokenized_ds.column_names}")
    
    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=default_data_collator,
        tokenizer=tokenizer
    )
    
    print("‚úÖ Trainer initialized")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"GPU memory before training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    
    print("\nüöÄ Starting Medical Fine-tuning...")
    trainer.train()
    print("‚úÖ Training completed!")
    
    return trainer

# Start training
trainer = start_medical_training_v2(model, tokenizer, train_dataset, training_args)

if torch.cuda.is_available():
    print(f"GPU memory after training: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

üèãÔ∏è Initializing Trainer with labels for causal LM...


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 1}.


‚úÖ Tokenized dataset with columns: ['input_ids', 'attention_mask', 'labels']
‚úÖ Trainer initialized
GPU memory before training: 3.20 GB

üöÄ Starting Medical Fine-tuning...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
25,1.9221
50,1.1296
75,1.0053
100,0.9337
125,0.9461
150,0.9094
175,0.9046
200,0.9113
225,0.8991
250,0.9156


‚úÖ Training completed!
GPU memory after training: 3.23 GB


In [8]:
# =============================================================================
# CELL 8: Save Fine-tuned Medical Model
# Save the trained model and tokenizer for future use
# =============================================================================

def save_medical_model(trainer, tokenizer, save_path="./apollo2b-medical-qa-final"):
    """Save the fine-tuned medical model"""
    
    print("üíæ Saving Fine-tuned Medical Model...")
    
    # Save model
    trainer.save_model(save_path)
    
    # Save tokenizer
    tokenizer.save_pretrained(save_path)
    
    print(f"‚úÖ Model saved to: {save_path}")
    
    # Validation: Check saved files
    import os
    saved_files = os.listdir(save_path)
    print(f"\nüìÅ Saved files: {saved_files}")
    
    return save_path

# Save the model
model_path = save_medical_model(trainer, tokenizer)

print("\nüéâ MEDICAL MODEL TRAINING COMPLETE!")
print("="*40)
print(f"‚úÖ Model successfully fine-tuned on {len(train_dataset)} medical Q&A pairs")
print(f"‚úÖ Model saved to: {model_path}")
print(f"‚úÖ Ready for medical question answering!")

üíæ Saving Fine-tuned Medical Model...
‚úÖ Model saved to: ./apollo2b-medical-qa-final

üìÅ Saved files: ['adapter_config.json', 'adapter_model.safetensors', 'README.md', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer.model', 'tokenizer_config.json', 'training_args.bin']

üéâ MEDICAL MODEL TRAINING COMPLETE!
‚úÖ Model successfully fine-tuned on 25000 medical Q&A pairs
‚úÖ Model saved to: ./apollo2b-medical-qa-final
‚úÖ Ready for medical question answering!
