# Kimi-Audio Model Fine-Tuning with FSDP and LoRA
## User-Friendly Version: Just provide your ZIP file path!

This notebook handles:
- Automatic ZIP extraction and dataset formatting
- FSDP (Fully Sharded Data Parallel) training
- LoRA (Low-Rank Adaptation) for efficient fine-tuning
- Multi-GPU support
- Automatic checkpoint saving

## 1. Install Dependencies

In [1]:
! pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
! pip install -q transformers datasets accelerate bitsandbytes peft
! pip install -q wandb tensorboard soundfile librosa
! pip install -q einops sentencepiece protobuf

print("‚úÖ All dependencies installed successfully!")

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m59.1/59.1 MB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25h‚úÖ All dependencies installed successfully!


## 2. Import Libraries

In [2]:
import os
import json
import zipfile
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Union
from dataclasses import dataclass, field

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training
)
from datasets import Dataset as HFDataset

import numpy as np
import soundfile as sf
import librosa

print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

PyTorch version: 2.9.0+cu128
Transformers version: 5.0.0
CUDA available: True
GPU count: 1
  GPU 0: Tesla T4


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## 3. Configuration
### üëâ MODIFY THESE PARAMETERS

In [4]:
@dataclass
class Config:
    # ========== USER INPUT ==========
    # Path to your ZIP file (local or Google Drive)
    zip_file_path: str = "/content/drive/MyDrive/kimi_data/kimi_audio_dataset.zip"  # CHANGE THIS!
    
    # ========== MODEL SETTINGS ==========
    model_name: str = "moonshotai/Kimi-Audio-7B-Instruct"  # or local path
    max_length: int = 2048
    
    # ========== LoRA SETTINGS ==========
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ])
    
    # ========== FSDP SETTINGS ==========
    use_fsdp: bool = True
    fsdp_sharding_strategy: str = "FULL_SHARD"  # or "SHARD_GRAD_OP", "NO_SHARD"
    fsdp_min_num_params: int = 1e6  # Minimum parameters for wrapping
    
    # ========== TRAINING SETTINGS ==========
    output_dir: str = "./kimi_audio_finetuned"
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-4
    warmup_steps: int = 100
    logging_steps: int = 10
    save_steps: int = 500
    eval_steps: int = 500
    save_total_limit: int = 3
    fp16: bool = True
    bf16: bool = False  # Use bf16 if available (A100, H100)
    
    # ========== DATASET SETTINGS ==========
    train_split_ratio: float = 0.9
    audio_sample_rate: int = 16000
    
    # ========== PATHS (Auto-generated) ==========
    extracted_dir: str = "./extracted_dataset"
    processed_dir: str = "./processed_dataset"

config = Config()
print("Configuration loaded successfully!")
print(json.dumps(config.__dict__, indent=2, default=str))

Configuration loaded successfully!
{
  "zip_file_path": "/content/drive/MyDrive/kimi_data/kimi_audio_dataset.zip",
  "model_name": "moonshotai/Kimi-Audio-7B-Instruct",
  "max_length": 2048,
  "use_lora": true,
  "lora_r": 16,
  "lora_alpha": 32,
  "lora_dropout": 0.05,
  "lora_target_modules": [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj"
  ],
  "use_fsdp": true,
  "fsdp_sharding_strategy": "FULL_SHARD",
  "fsdp_min_num_params": 1000000.0,
  "output_dir": "./kimi_audio_finetuned",
  "num_train_epochs": 3,
  "per_device_train_batch_size": 2,
  "per_device_eval_batch_size": 2,
  "gradient_accumulation_steps": 4,
  "learning_rate": 0.0002,
  "warmup_steps": 100,
  "logging_steps": 10,
  "save_steps": 500,
  "eval_steps": 500,
  "save_total_limit": 3,
  "fp16": true,
  "bf16": false,
  "train_split_ratio": 0.9,
  "audio_sample_rate": 16000,
  "extracted_dir": "./extracted_dataset",
  "processed_dir": "./processed_dataset"
}


## 4. ZIP Extraction and Dataset Preparation

In [5]:
class DatasetExtractor:
    """Handles ZIP extraction and dataset formatting"""
    
    def __init__(self, config: Config):
        self.config = config
        self.extracted_path = Path(config.extracted_dir)
        self.processed_path = Path(config.processed_dir)
        
    def extract_zip(self):
        """Extract ZIP file to the extracted directory"""
        print(f"üì¶ Extracting ZIP file: {self.config.zip_file_path}")
        
        if not os.path.exists(self.config.zip_file_path):
            raise FileNotFoundError(f"ZIP file not found: {self.config.zip_file_path}")
        
        # Clean previous extraction
        if self.extracted_path.exists():
            shutil.rmtree(self.extracted_path)
        self.extracted_path.mkdir(parents=True, exist_ok=True)
        
        # Extract
        with zipfile.ZipFile(self.config.zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(self.extracted_path)
        
        print(f"‚úÖ Extraction complete: {self.extracted_path}")
        self._print_directory_structure()
        
    def _print_directory_structure(self):
        """Print the extracted directory structure"""
        print("\nüìÅ Directory structure:")
        for root, dirs, files in os.walk(self.extracted_path):
            level = root.replace(str(self.extracted_path), '').count(os.sep)
            indent = ' ' * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            sub_indent = ' ' * 2 * (level + 1)
            for file in files[:5]:  # Show first 5 files
                print(f"{sub_indent}{file}")
            if len(files) > 5:
                print(f"{sub_indent}... and {len(files) - 5} more files")
    
    def detect_format(self) -> str:
        """Auto-detect dataset format"""
        print("\nüîç Detecting dataset format...")
        
        # Check for common formats
        formats = {
            'jsonl': list(self.extracted_path.rglob('*.jsonl')),
            'json': list(self.extracted_path.rglob('*.json')),
            'csv': list(self.extracted_path.rglob('*.csv')),
            'txt': list(self.extracted_path.rglob('*.txt')),
        }
        
        for fmt, files in formats.items():
            if files:
                print(f"‚úÖ Detected format: {fmt.upper()}")
                print(f"   Found {len(files)} {fmt} file(s)")
                return fmt
        
        # Check for audio files
        audio_files = list(self.extracted_path.rglob('*.wav')) + \
                     list(self.extracted_path.rglob('*.mp3')) + \
                     list(self.extracted_path.rglob('*.flac'))
        
        if audio_files:
            print(f"‚úÖ Detected audio dataset with {len(audio_files)} files")
            return 'audio'
        
        raise ValueError("Could not detect dataset format. Please check your ZIP structure.")
    
    def load_dataset(self) -> List[Dict]:
        """Load dataset based on detected format"""
        fmt = self.detect_format()
        data = []
        
        if fmt == 'jsonl':
            for file in self.extracted_path.rglob('*.jsonl'):
                with open(file, 'r', encoding='utf-8') as f:
                    for line in f:
                        data.append(json.loads(line.strip()))
        
        elif fmt == 'json':
            for file in self.extracted_path.rglob('*.json'):
                with open(file, 'r', encoding='utf-8') as f:
                    content = json.load(f)
                    if isinstance(content, list):
                        data.extend(content)
                    else:
                        data.append(content)
        
        elif fmt == 'audio':
            # For audio datasets, create metadata from audio files
            audio_files = list(self.extracted_path.rglob('*.wav')) + \
                         list(self.extracted_path.rglob('*.mp3')) + \
                         list(self.extracted_path.rglob('*.flac'))
            
            for audio_file in audio_files:
                # Try to find corresponding text/transcript
                txt_file = audio_file.with_suffix('.txt')
                transcript = ""
                if txt_file.exists():
                    with open(txt_file, 'r', encoding='utf-8') as f:
                        transcript = f.read().strip()
                
                data.append({
                    'audio_path': str(audio_file),
                    'text': transcript,
                    'filename': audio_file.name
                })
        
        print(f"\nüìä Loaded {len(data)} samples")
        if data:
            print(f"Sample data keys: {list(data[0].keys())}")
            print(f"First sample: {json.dumps(data[0], indent=2, default=str)[:500]}...")
        
        return data
    
    def format_for_training(self, data: List[Dict]) -> List[Dict]:
        """Format data for Kimi-Audio training"""
        print("\nüîÑ Formatting data for training...")
        formatted_data = []
        
        for item in data:
            # Flexible key mapping
            text_keys = ['text', 'transcript', 'transcription', 'label', 'target']
            audio_keys = ['audio', 'audio_path', 'file', 'path']
            
            text = None
            for key in text_keys:
                if key in item:
                    text = item[key]
                    break
            
            audio_path = None
            for key in audio_keys:
                if key in item:
                    audio_path = item[key]
                    break
            
            if text:  # At minimum we need text
                formatted_item = {
                    'text': str(text),
                }
                if audio_path:
                    formatted_item['audio_path'] = str(audio_path)
                
                formatted_data.append(formatted_item)
        
        print(f"‚úÖ Formatted {len(formatted_data)} samples")
        return formatted_data

# Execute extraction
extractor = DatasetExtractor(config)
extractor.extract_zip()
raw_data = extractor.load_dataset()
formatted_data = extractor.format_for_training(raw_data)

print(f"\n‚úÖ Dataset ready: {len(formatted_data)} samples")

üì¶ Extracting ZIP file: /content/drive/MyDrive/kimi_data/kimi_audio_dataset.zip
‚úÖ Extraction complete: extracted_dataset

üìÅ Directory structure:
extracted_dataset/
  kimi_audio_dataset/
    dataset.jsonl
    Wav/
      049.wav
      050.wav
      040.wav
      007.wav
      082.wav
      ... and 72 more files

üîç Detecting dataset format...
‚úÖ Detected format: JSONL
   Found 2 jsonl file(s)

üìä Loaded 76 samples
Sample data keys: ['audio', 'text']
First sample: {
  "audio": "/content/Wav/035.wav",
  "text": "will never mind the pies. As Mrs. Rachel says, pies they always were and pies they always will be, world without end amen."
}...

üîÑ Formatting data for training...
‚úÖ Formatted 76 samples

‚úÖ Dataset ready: 76 samples


## 5. Custom Dataset Class

In [6]:
class KimiAudioDataset(Dataset):
    """Custom dataset for Kimi-Audio fine-tuning"""
    
    def __init__(self, data: List[Dict], tokenizer, max_length: int = 2048):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': input_ids.clone()  # For causal LM
        }

print("‚úÖ Dataset class defined")

‚úÖ Dataset class defined


In [7]:
from huggingface_hub import login

# Replace with your actual token from https://huggingface.co/settings/tokens
login(token="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")

In [8]:
! pip install flash-attn --no-build-isolation
print("‚úÖ FlashAttention installed successfully!")


Collecting flash-attn
  Downloading flash_attn-2.8.3.tar.gz (8.4 MB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m8.4/8.4 MB[0m [31m130.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.8.3-cp312-cp312-linux_x86_64.whl size=253780426 sha256=4e2f9e39313266b1544b68138b15b91ee6221eccf14f7902b7c6620351340810
  Stored in directory: /root/.cache/pip/wheels/3d/59/46/f282c12c73dd4bb3c2e3fe199f1a0d0f8cec06df0cccfeee27
Successfully built flash-attn
Installing collected packages: flash-attn
Successfully installed flash-attn-2.8.3
‚úÖ FlashAttention installed successfully!


## 6. Load Model and Tokenizer

In [9]:
print("üîÑ Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    config.model_name,
    trust_remote_code=True,
    use_fast=False
)

# Set padding token if not exists
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print("‚úÖ Tokenizer loaded")
print(f"Vocab size: {tokenizer.vocab_size}")
# print(f"Special tokens: {tokenizer.special_tokens_map}")

print("\nüîÑ Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16 if config.fp16 else torch.float32,
    device_map='auto',
    low_cpu_mem_usage=True
)
# model = AutoModelForCausalLM.from_pretrained(
#     config.model_name,
#     trust_remote_code=True,
#     dtype=torch.float16 if config.fp16 else torch.float32,
#     device_map="auto",
#     # low_cpu_mem_usage=True,
#     attn_implementation="eager"   # üëà disables flash attention
# )


print("‚úÖ Model loaded")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9:.2f}B")

üîÑ Loading tokenizer...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json: 0.00B [00:00, ?B/s]

configuration_moonshot_kimia.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/moonshotai/Kimi-Audio-7B-Instruct:
- configuration_moonshot_kimia.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenization_kimia.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/moonshotai/Kimi-Audio-7B-Instruct:
- tokenization_kimia.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tiktoken.model:   0%|          | 0.00/2.56M [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

‚úÖ Tokenizer loaded
Vocab size: 152064

üîÑ Loading base model...


`torch_dtype` is deprecated! Use `dtype` instead!


modeling_moonshot_kimia.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/moonshotai/Kimi-Audio-7B-Instruct:
- modeling_moonshot_kimia.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


RuntimeError: flash attention must be installed

In [6]:
!nvidia-smi

Wed Feb 11 11:53:49 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   44C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

## 7. Apply LoRA

In [26]:
if config.use_lora:
    print("\nüîÑ Applying LoRA configuration...")
    
    # Prepare model for k-bit training (if using quantization)
    # model = prepare_model_for_kbit_training(model)
    
    lora_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        target_modules=config.lora_target_modules,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    
    print("‚úÖ LoRA applied successfully")
else:
    print("‚ö†Ô∏è LoRA is disabled - training full model")


üîÑ Applying LoRA configuration...


NameError: name 'model' is not defined

## 8. Prepare Datasets

In [None]:
# Split data
split_idx = int(len(formatted_data) * config.train_split_ratio)
train_data = formatted_data[:split_idx]
eval_data = formatted_data[split_idx:]

print(f"üìä Dataset split:")
print(f"  Training samples: {len(train_data)}")
print(f"  Evaluation samples: {len(eval_data)}")

# Create dataset objects
train_dataset = KimiAudioDataset(train_data, tokenizer, config.max_length)
eval_dataset = KimiAudioDataset(eval_data, tokenizer, config.max_length) if eval_data else None

print("‚úÖ Datasets created")

## 9. Configure FSDP Training Arguments

In [None]:
# FSDP configuration
fsdp_config = None
if config.use_fsdp and torch.cuda.device_count() > 1:
    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],  # Adjust based on model
        "fsdp_sharding_strategy": config.fsdp_sharding_strategy,
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_backward_prefetch": "BACKWARD_PRE",
        "fsdp_cpu_ram_efficient_loading": True,
        "fsdp_state_dict_type": "FULL_STATE_DICT",
    }
    print("‚úÖ FSDP configuration prepared")
else:
    print("‚ö†Ô∏è FSDP disabled (single GPU or disabled in config)")

# Training arguments
training_args = TrainingArguments(
    output_dir=config.output_dir,
    num_train_epochs=config.num_train_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    learning_rate=config.learning_rate,
    warmup_steps=config.warmup_steps,
    logging_dir=f"{config.output_dir}/logs",
    logging_steps=config.logging_steps,
    save_steps=config.save_steps,
    eval_steps=config.eval_steps,
    save_total_limit=config.save_total_limit,
    fp16=config.fp16,
    bf16=config.bf16,
    evaluation_strategy="steps" if eval_dataset else "no",
    save_strategy="steps",
    load_best_model_at_end=True if eval_dataset else False,
    metric_for_best_model="loss" if eval_dataset else None,
    greater_is_better=False,
    ddp_find_unused_parameters=False,
    report_to=["tensorboard"],
    remove_unused_columns=False,
    fsdp=fsdp_config if fsdp_config else None,
)

print("‚úÖ Training arguments configured")

## 10. Initialize Trainer

In [None]:
# Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # We're doing causal LM, not masked LM
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

print("‚úÖ Trainer initialized and ready")
print(f"\nTraining configuration:")
print(f"  Model: {config.model_name}")
print(f"  LoRA: {config.use_lora}")
print(f"  FSDP: {config.use_fsdp}")
print(f"  Epochs: {config.num_train_epochs}")
print(f"  Batch size: {config.per_device_train_batch_size}")
print(f"  Gradient accumulation: {config.gradient_accumulation_steps}")
print(f"  Effective batch size: {config.per_device_train_batch_size * config.gradient_accumulation_steps * torch.cuda.device_count()}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Output directory: {config.output_dir}")

## 11. Start Training üöÄ

In [None]:
print("\n" + "="*50)
print("üöÄ STARTING TRAINING")
print("="*50 + "\n")

# Train the model
train_result = trainer.train()

print("\n" + "="*50)
print("‚úÖ TRAINING COMPLETED")
print("="*50 + "\n")

# Save metrics
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)

print(f"Training metrics:")
for key, value in metrics.items():
    print(f"  {key}: {value}")

## 12. Save Final Model

In [None]:
print("\nüíæ Saving final model...")

# Save model
final_model_path = os.path.join(config.output_dir, "final_model")
trainer.save_model(final_model_path)

# Save tokenizer
tokenizer.save_pretrained(final_model_path)

print(f"‚úÖ Model saved to: {final_model_path}")

# If using LoRA, save adapter separately
if config.use_lora:
    lora_path = os.path.join(config.output_dir, "lora_adapter")
    model.save_pretrained(lora_path)
    print(f"‚úÖ LoRA adapter saved to: {lora_path}")

## 13. Run Evaluation (Optional)

In [None]:
if eval_dataset:
    print("\nüìä Running final evaluation...")
    
    eval_metrics = trainer.evaluate()
    
    print(f"\nEvaluation metrics:")
    for key, value in eval_metrics.items():
        print(f"  {key}: {value}")
    
    trainer.log_metrics("eval", eval_metrics)
    trainer.save_metrics("eval", eval_metrics)
else:
    print("\n‚ö†Ô∏è No evaluation dataset available")

## 14. Test Inference

In [None]:
print("\nüß™ Testing inference with fine-tuned model...")

# Test prompt
test_prompt = "Hello, how are you?"

# Tokenize
inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

# Generate
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

# Decode
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"\nPrompt: {test_prompt}")
print(f"Generated: {generated_text}")

## 15. Cleanup (Optional)

In [None]:
# Uncomment to clean up temporary files
# print("\nüßπ Cleaning up temporary files...")
# shutil.rmtree(config.extracted_dir, ignore_errors=True)
# print("‚úÖ Cleanup complete")

## 16. Summary Report

In [None]:
print("\n" + "="*60)
print("üìã TRAINING SUMMARY REPORT")
print("="*60)

print(f"\n‚úÖ Fine-tuning completed successfully!\n")

print(f"Dataset Information:")
print(f"  ZIP file: {config.zip_file_path}")
print(f"  Training samples: {len(train_data)}")
print(f"  Evaluation samples: {len(eval_data)}")

print(f"\nModel Information:")
print(f"  Base model: {config.model_name}")
print(f"  LoRA enabled: {config.use_lora}")
if config.use_lora:
    print(f"  LoRA rank: {config.lora_r}")
    print(f"  LoRA alpha: {config.lora_alpha}")
print(f"  FSDP enabled: {config.use_fsdp}")

print(f"\nTraining Configuration:")
print(f"  Epochs: {config.num_train_epochs}")
print(f"  Batch size: {config.per_device_train_batch_size}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Precision: {'FP16' if config.fp16 else 'BF16' if config.bf16 else 'FP32'}")

print(f"\nOutput Locations:")
print(f"  Model directory: {final_model_path}")
if config.use_lora:
    print(f"  LoRA adapter: {lora_path}")
print(f"  Checkpoints: {config.output_dir}")
print(f"  Logs: {config.output_dir}/logs")

print("\n" + "="*60)
print("üéâ You can now use your fine-tuned model!")
print("="*60)

print(f"\nTo load the model later:")
print(f"```python")
print(f"from transformers import AutoModelForCausalLM, AutoTokenizer")
print(f"model = AutoModelForCausalLM.from_pretrained('{final_model_path}')")
print(f"tokenizer = AutoTokenizer.from_pretrained('{final_model_path}')")
print(f"```")