In [None]:
#!/usr/bin/env python3
# ============================================================
# KIMI-AUDIO-7B SFT - COLAB T4 FIXED VERSION
# Tested & Working - Handles EncoderDecoderCache Error
# ============================================================

# RUN THIS IN COLAB - Cell 1
# COPY THIS ENTIRE CONTENT INTO ONE COLAB CELL

import subprocess
import sys
import os

print("="*70)
print("STEP 0: Nuclear Clean Install (Fixing EncoderDecoderCache Error)")
print("="*70)

# ============================================================
# AGGRESSIVE CLEANUP - Remove ALL old packages
# ============================================================

print("\nüî• REMOVING OLD PACKAGES (Nuclear Clean)...")

# Distributions to uninstall before the fresh install in Cell 2.
# Each name is listed exactly once so pip is not invoked twice for the
# same package.
packages_to_remove = [
    "transformers",
    "tokenizers",
    "huggingface-hub",
    "accelerate",
    "peft",
    "bitsandbytes",
    "datasets",
    "trl",
]

for pkg in packages_to_remove:
    result = subprocess.run(
        [sys.executable, "-m", "pip", "uninstall", "-y", pkg],
        capture_output=True,
        text=True
    )
    # Only report a package when pip's own output confirms the uninstall;
    # packages that were not installed are skipped silently.
    if "Successfully uninstalled" in result.stdout:
        print(f"   ‚úì Removed {pkg}")

print("\n‚úÖ Old packages removed\n")

# ============================================================
# RESTART KERNEL NOTICE
# ============================================================

# Already-imported modules stay loaded in the running kernel, so the user
# is told to restart before running the next cell.
print("="*70)
print("‚ö†Ô∏è  IMPORTANT: Restart the Colab kernel!")
print("="*70)
print("""
After this cell finishes:
1. Go to: Runtime ‚Üí Restart runtime (top menu)
2. Wait 30 seconds for restart
3. Then run CELL 2 below
""")

print("\n" + "="*70)
print("Now follow the instructions above ‚òùÔ∏è")
print("="*70)

STEP 0: Nuclear Clean Install (Fixing EncoderDecoderCache Error)

üî• REMOVING OLD PACKAGES (Nuclear Clean)...
   ‚úì Removed transformers
   ‚úì Removed tokenizers
   ‚úì Removed huggingface-hub
   ‚úì Removed accelerate
   ‚úì Removed peft
   ‚úì Removed bitsandbytes
   ‚úì Removed datasets

‚úÖ Old packages removed

‚ö†Ô∏è  IMPORTANT: Restart the Colab kernel!

After this cell finishes:
1. Go to: Runtime ‚Üí Restart runtime (top menu)
2. Wait 30 seconds for restart
3. Then run CELL 2 below


Now follow the instructions above ‚òùÔ∏è


In [None]:
#!/usr/bin/env python3
# ============================================================
# KIMI-AUDIO-7B SFT - COLAB CELL 2 (FULLY ROBUST FIX)
# Run this AFTER restarting kernel from Cell 1
# ============================================================

import subprocess
import sys
import time

print("="*70)
print("STEP 1: Fresh Installation (ROBUST VERSION)")
print("="*70)

print("\nüîß Fixing dependencies with robust error handling...")

# Upgrade pip first
print("\nüì¶ Upgrading pip, setuptools, and wheel...")
try:
    # check=True turns a non-zero pip exit status into CalledProcessError,
    # so the success line below is only printed when the upgrade worked.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "--upgrade", "pip", "setuptools", "wheel"],
        timeout=120,
        check=True
    )
    print("   ‚úì pip, setuptools, wheel upgraded")
except Exception as e:
    print(f"   ‚ö†Ô∏è  Upgrade partial: {str(e)[:50]}...")

# Uninstall problematic packages
print("\nüóëÔ∏è  Removing conflicting packages...")
problem_packages = ["torch", "numpy", "accelerate", "transformers", "torchvision", "torchaudio", "peft"]
for pkg in problem_packages:
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "uninstall", "-y", pkg],
            capture_output=True,
            text=True,
            timeout=60
        )
    except (subprocess.SubprocessError, OSError):
        # Best effort: a timeout or a pip launch failure for one package must
        # not stop the cleanup of the others. Unlike a bare ``except``, this
        # still lets KeyboardInterrupt stop the cell.
        pass
print("   ‚úì Cleanup complete")

# Install packages in correct order with NO sys.exit() on failure
print("\nüì• Installing packages (with graceful fallbacks)...")

# 1. NUMPY FIRST (foundation)
print("\n   1Ô∏è‚É£  Installing NumPy...")
try:
    # check=True so a failed install is reported instead of being printed
    # as a success.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "numpy<2.0"],
        timeout=120,
        check=True
    )
    print("      ‚úì NumPy < 2.0 installed")
except Exception as e:
    print(f"      ‚ö†Ô∏è  NumPy install warning: {str(e)[:40]}...")

time.sleep(2)

# 2. TORCH (most critical - with multiple fallback attempts)
print("\n   2Ô∏è‚É£  Installing PyTorch...")

# Fallback chain, tried in order until one pip invocation succeeds:
# (label shown to the user, requirement specs, success message).
# NOTE(review): the first label says "CPU", but no CPU-only index URL is
# passed, so pip resolves whatever default wheel matches the pins.
torch_attempts = [
    (
        "PyTorch CPU",
        ["torch==2.1.2", "torchvision==0.16.2", "torchaudio==2.1.2"],
        "PyTorch 2.1.2 (CPU) installed",
    ),
    ("PyTorch core only", ["torch==2.1.2"], "PyTorch 2.1.2 core installed"),
    ("PyTorch latest", ["torch"], "PyTorch (latest) installed"),
]

torch_install_success = False

for attempt_no, (attempt_label, torch_specs, done_message) in enumerate(torch_attempts, start=1):
    print(f"      Attempt {attempt_no}: {attempt_label}...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", *torch_specs],
            timeout=300,
            check=True
        )
    except Exception as e:
        # Covers a non-zero exit (CalledProcessError) and a timeout.
        print(f"      ‚ö†Ô∏è  Attempt {attempt_no} failed: {str(e)[:40]}...")
    else:
        print(f"      ‚úì {done_message}")
        torch_install_success = True
        break

if not torch_install_success:
    print("      ‚ùå PyTorch installation failed - proceeding anyway (might cause issues later)")
else:
    print("      ‚úÖ PyTorch installation successful")

time.sleep(2)

# 3. TRANSFORMERS & HUGGINGFACE
hf_packages = [
    "transformers==4.41.2",
    "huggingface-hub>=0.20.0",
    "tokenizers>=0.14.1",
]

# 4. ACCELERATE & PEFT
accel_packages = [
    "accelerate==0.31.0",
    "peft>=0.11.1",
]

# 5. OTHER DEPENDENCIES
other_packages = [
    "bitsandbytes>=0.43.0",
    "datasets>=2.19.0",
    "safetensors>=0.4.0",
    "scipy>=1.10.0",
    "tqdm>=4.66.0",
]

# (banner, requirement specs, per-package pip timeout in seconds)
install_groups = [
    ("\n   3Ô∏è‚É£  Installing Transformers & HuggingFace...", hf_packages, 180),
    ("\n   4Ô∏è‚É£  Installing Accelerate & PEFT...", accel_packages, 120),
    ("\n   5Ô∏è‚É£  Installing other dependencies...", other_packages, 120),
]

for group_index, (group_banner, group_specs, pip_timeout) in enumerate(install_groups):
    if group_index:
        # Short pause between groups, as in the original sequence.
        time.sleep(2)
    print(group_banner)
    for pkg in group_specs:
        # Strip the version constraint ("==..." or ">=...") for display.
        pkg_name = pkg.split('==')[0].split('>')[0]
        try:
            # check=True: without it pip failures were reported as installed.
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-q", pkg],
                timeout=pip_timeout,
                check=True
            )
        except Exception as e:
            print(f"      ‚ö†Ô∏è  {pkg_name}: {str(e)[:40]}...")
        else:
            print(f"      ‚úì {pkg_name}")

print("\n‚úÖ Package installation phase complete!")

# Clear module cache
# Drop already-imported copies of the reinstalled packages from sys.modules
# so that the verification imports below load them again.
print("\nüßπ Clearing module cache...")
modules_to_clear = [
    'torch', 'torchvision', 'torchaudio', 'numpy', 'transformers', 'accelerate', 'peft',
    'huggingface_hub', 'datasets', 'safetensors', 'bitsandbytes'
]
# "pkg." prefixes match submodules without also matching unrelated packages
# that merely share a name prefix.
submodule_prefixes = tuple(root + '.' for root in modules_to_clear)
for mod_name in list(sys.modules):
    if mod_name in modules_to_clear or mod_name.startswith(submodule_prefixes):
        # pop() with a default never raises, so no blanket except is needed.
        sys.modules.pop(mod_name, None)
print("   ‚úì Cache cleared")

# ============================================================
# STEP 2: Verify Installations
# ============================================================

import importlib

print("\n" + "="*70)
print("STEP 2: Verifying Installations")
print("="*70 + "\n")

# List of (display name, import succeeded) pairs for the critical packages.
verification_results = []

# (display name, importable module name)
core_checks = [
    ("PyTorch", "torch"),
    ("NumPy", "numpy"),
    ("Transformers", "transformers"),
    ("Accelerate", "accelerate"),
    ("PEFT", "peft"),
]

for display_name, module_name in core_checks:
    try:
        version = importlib.import_module(module_name).__version__
    except Exception as e:
        print(f"‚úó {display_name}: {str(e)[:50]}...")
        verification_results.append((display_name, False))
        continue

    if display_name == "NumPy":
        # The install step pins numpy<2.0; flag anything that is not 1.x.
        status = "‚úÖ" if version.startswith('1.') else "‚ö†Ô∏è"
        print(f"‚úì NumPy: {version} {status}")
    else:
        print(f"‚úì {display_name}: {version}")
    verification_results.append((display_name, True))

# Check critical imports
print("\nüîç Checking critical imports...")

# (label on success, label on failure, module, attributes that must resolve).
# These are informational only and do not count towards the summary.
import_checks = [
    ("AutoModel & AutoTokenizer", "AutoModel/AutoTokenizer", "transformers", ("AutoModel", "AutoTokenizer")),
    ("PEFT get_peft_model", "PEFT get_peft_model", "peft", ("get_peft_model",)),
    ("BitsAndBytes", "BitsAndBytes", "bitsandbytes", ()),
]

for ok_label, fail_label, module_name, required_attrs in import_checks:
    try:
        module = importlib.import_module(module_name)
        for attr_name in required_attrs:
            getattr(module, attr_name)
    except Exception as e:
        print(f"‚ö†Ô∏è  {fail_label}: {str(e)[:50]}...")
    else:
        print(f"‚úì {ok_label}")

# Summary
print("\n" + "="*70)
critical_packages = [item[0] for item in verification_results if item[1]]
print(f"‚úÖ SETUP COMPLETE - {len(critical_packages)}/{len(verification_results)} critical packages ready")
print("="*70)

# Every critical package must import. The previous ">= 4" threshold reported
# "Ready" even when one of them (for example PyTorch) was missing.
if len(critical_packages) == len(verification_results):
    print("\nüöÄ Ready to proceed to CELL 3!")
    print("\n‚ö° NEXT STEPS:")
    print("   1. Restart the kernel: Runtime ‚Üí Restart runtime")
    print("   2. Run CELL 3 for SFT training")
else:
    print("\n‚ö†Ô∏è  Some packages may be missing - try restarting kernel before Cell 3")
    print("\nPackages ready:", ", ".join(critical_packages))

STEP 1: Fresh Installation (ROBUST VERSION)

üîß Fixing dependencies with robust error handling...

üì¶ Upgrading pip, setuptools, and wheel...
   ‚úì pip, setuptools, wheel upgraded

üóëÔ∏è  Removing conflicting packages...
   ‚úì Cleanup complete

üì• Installing packages (with graceful fallbacks)...

   1Ô∏è‚É£  Installing NumPy...
      ‚úì NumPy < 2.0 installed

   2Ô∏è‚É£  Installing PyTorch...
      Attempt 1: PyTorch CPU...
      ‚ö†Ô∏è  Attempt 1 failed: Command '['/usr/bin/python3', '-m', 'pip...
      Attempt 2: PyTorch core only...
      ‚ö†Ô∏è  Attempt 2 failed: Command '['/usr/bin/python3', '-m', 'pip...
      Attempt 3: PyTorch latest...
      ‚ö†Ô∏è  Attempt 3 failed: Command '['/usr/bin/python3', '-m', 'pip...
      ‚ùå PyTorch installation failed - proceeding anyway (might cause issues later)

   3Ô∏è‚É£  Installing Transformers & HuggingFace...
      ‚úì transformers
      ‚úì huggingface-hub
      ‚úì tokenizers

   4Ô∏è‚É£  Installing Accelerate & PEFT...
      

In [None]:
#!/usr/bin/env python3
# ============================================================
# ULTIMATE NUCLEAR RESET - Completely Fixed
# ============================================================

import subprocess
import sys

print("=" * 70)
print("ULTIMATE FIX: Complete Package Reset")
print("=" * 70 + "\n")

# ------------------------------------------------------------
# Phase 1: uninstall every package that will be re-pinned below
# ------------------------------------------------------------

print("üóëÔ∏è  STEP 1: Nuclear removal of all conflicting packages...\n")

all_packages = [
    "transformers",
    "accelerate",
    "peft",
    "bitsandbytes",
    "huggingface-hub",
    "huggingface_hub",
    "tokenizers",
    "datasets",
]

for distribution in all_packages:
    uninstall_cmd = [sys.executable, "-m", "pip", "uninstall", "-y", distribution]
    # Output is captured and discarded; the exit status is not inspected.
    subprocess.run(uninstall_cmd, capture_output=True, text=True, timeout=60)

print("   ‚úì All packages removed\n")

# ------------------------------------------------------------
# Phase 2: empty pip's download/wheel cache
# ------------------------------------------------------------

print("üßπ STEP 2: Clearing all caches...\n")

purge_cmd = [sys.executable, "-m", "pip", "cache", "purge"]
subprocess.run(purge_cmd, capture_output=True, timeout=60)
print("   ‚úì Pip cache cleared\n")

# ============================================================
# STEP 3: Install with LATEST COMPATIBLE versions
# ============================================================

print("="*70)
print("üì• STEP 3: Installing Latest Compatible Versions")
print("="*70 + "\n")

# Pinned set intended to work together, installed in this order.
# - tokenizers is pinned to 0.19.1, the same version this notebook installs
#   in its later "quick fix" and training cells; the previous 0.15.1 pin was
#   installed after transformers and replaced the tokenizers release that
#   transformers had pulled in.
# - NOTE(review): the captured run of the verification step shows
#   `clear_device_cache` failing to import from accelerate, so do not assume
#   this accelerate pin provides it - confirm before relying on it.
compatible_packages = [
    ("huggingface-hub", "0.21.4"),
    ("transformers", "4.40.2"),
    ("accelerate", "0.27.2"),
    ("peft", "0.11.1"),
    ("bitsandbytes", "0.43.0"),
    ("tokenizers", "0.19.1"),
    ("datasets", "2.19.0"),
]

installed = []
failed = []

for pkg_name, version in compatible_packages:
    pkg_spec = f"{pkg_name}=={version}"
    print(f"   Installing {pkg_spec}...")

    try:
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", pkg_spec],
            timeout=300,
            capture_output=True,
            text=True,
            check=True
        )
        print(f"      ‚úì Success")
        installed.append(pkg_name)
    except subprocess.CalledProcessError as e:
        # The pinned version was rejected by pip: retry once without the pin.
        print(f"      ‚ö†Ô∏è  Failed - trying without version lock...")
        try:
            result = subprocess.run(
                [sys.executable, "-m", "pip", "install", "-q", pkg_name],
                timeout=300,
                capture_output=True,
                text=True,
                check=True
            )
            print(f"      ‚úì Installed (latest version)")
            installed.append(pkg_name)
        except Exception:
            # Exception (not a bare except) so Ctrl-C still interrupts.
            print(f"      ‚úó Failed")
            failed.append(pkg_name)
    except Exception as e:
        # Anything else, e.g. a timeout: record the failure, no retry.
        print(f"      ‚úó Error: {str(e)[:40]}")
        failed.append(pkg_name)

print(f"\n‚úÖ Installed: {len(installed)} packages")
if failed:
    print(f"‚ö†Ô∏è  Failed: {', '.join(failed)}")

print()

# ============================================================
# STEP 4: Verify Installation
# ============================================================

import importlib
from importlib import metadata as importlib_metadata

print("="*70)
print("üîç STEP 4: Verifying Installation")
print("="*70 + "\n")

# Maps a check name to True/False; consumed by the final status section.
verification_results = {}

# Sentinel: report the imported module's own ``__version__``.
_FROM_MODULE = object()

# Each entry: (result key, module to import, attribute that must resolve or
# None, label on success, label on failure, failure marker, report).
# ``report`` is a distribution name (print the version recorded in the
# installed package metadata), _FROM_MODULE, or None (print "Available").
#
# Versions of the packages reinstalled by this cell are read from package
# metadata rather than from ``module.__version__``: a module imported earlier
# in the same kernel keeps reporting the version that was loaded, not the one
# now installed.
verification_checks = [
    ("huggingface_hub", "huggingface_hub.constants", "HF_HUB_CACHE",
     "huggingface_hub.constants.HF_HUB_CACHE", "huggingface_hub", "‚úó", None),
    ("transformers", "transformers", None,
     "Transformers", "Transformers", "‚úó", "transformers"),
    ("accelerate", "accelerate", None,
     "Accelerate", "Accelerate", "‚úó", "accelerate"),
    ("clear_device_cache", "accelerate.utils.memory", "clear_device_cache",
     "clear_device_cache", "clear_device_cache", "‚úó", None),
    ("peft", "peft", None,
     "PEFT", "PEFT", "‚úó", "peft"),
    ("trainer", "transformers", "Trainer",
     "Trainer", "Trainer", "‚úó", None),
    ("bitsandbytes", "bitsandbytes", None,
     "BitsAndBytes", "BitsAndBytes", "‚ö†Ô∏è ", None),
    ("datasets", "datasets", "load_dataset",
     "Datasets", "Datasets", "‚úó", None),
    # torch is not reinstalled by this cell, so the loaded module is accurate.
    ("torch", "torch", None,
     "PyTorch", "PyTorch", "‚úó", _FROM_MODULE),
]

for result_key, module_name, required_attr, ok_label, fail_label, fail_mark, report in verification_checks:
    try:
        module = importlib.import_module(module_name)
        if required_attr is not None:
            getattr(module, required_attr)
        if report is None:
            detail = "Available"
        elif report is _FROM_MODULE:
            detail = module.__version__
        else:
            detail = importlib_metadata.version(report)
    except Exception as e:
        print(f"{fail_mark} {fail_label}: {str(e)[:60]}")
        verification_results[result_key] = False
    else:
        print(f"‚úì {ok_label}: {detail}")
        verification_results[result_key] = True

print()

# ============================================================
# Final Status
# ============================================================

# Tally the checks recorded in verification_results.
success_count = sum(map(bool, verification_results.values()))
total_count = len(verification_results)

print("=" * 70)

if success_count < 7:  # Fewer than 7 checks passed: show troubleshooting
    print(f"‚ö†Ô∏è  PARTIAL SUCCESS: {success_count}/{total_count} packages verified")
    print("=" * 70)
    print("\nüîß TROUBLESHOOTING:")

    failed_items = [name for name, passed in verification_results.items() if not passed]
    print(f"   Failed: {', '.join(failed_items)}")

    if "huggingface_hub" in failed_items:
        for hint in (
            "\n   Issue: huggingface_hub HF_HUB_CACHE not found",
            "   Solution: Run this command in new cell:",
            "      !pip install --upgrade --force-reinstall huggingface-hub",
        ):
            print(hint)

    for step in ("\n   Then:", "   1. Restart kernel", "   2. Run this reset cell again"):
        print(step)
    print()
else:
    print(f"‚úÖ SUCCESS! {success_count}/{total_count} packages verified")
    print("=" * 70)
    for step in (
        "\nüöÄ NEXT STEPS:",
        "   1. Restart kernel: Runtime ‚Üí Restart runtime",
        "   2. Wait 10 seconds for kernel to fully restart",
        "   3. Run CELL 3 (the final training cell)",
    ):
        print(step)
    print()

ULTIMATE FIX: Complete Package Reset

üóëÔ∏è  STEP 1: Nuclear removal of all conflicting packages...

   ‚úì All packages removed

üßπ STEP 2: Clearing all caches...

   ‚úì Pip cache cleared

üì• STEP 3: Installing Latest Compatible Versions

   Installing huggingface-hub==0.21.4...
      ‚úì Success
   Installing transformers==4.40.2...
      ‚úì Success
   Installing accelerate==0.27.2...
      ‚úì Success
   Installing peft==0.11.1...
      ‚úì Success
   Installing bitsandbytes==0.43.0...
      ‚úì Success
   Installing tokenizers==0.15.1...
      ‚úì Success
   Installing datasets==2.19.0...
      ‚úì Success

‚úÖ Installed: 7 packages

üîç STEP 4: Verifying Installation

‚úì huggingface_hub.constants.HF_HUB_CACHE: Available
‚úì Transformers: 4.41.2
‚úì Accelerate: 0.31.0
‚úó clear_device_cache: cannot import name 'clear_device_cache' from 'accelerate.uti
‚úì PEFT: 0.11.1
‚úì Trainer: Available
‚úì BitsAndBytes: Available
‚úì Datasets: Available
‚úì PyTorch: 2.10.0+cu128

‚úÖ

In [None]:
#!/usr/bin/env python3
# ============================================================
# QUICK FIX: Tokenizers Version Fix
# Run this BEFORE running Cell 3
# ============================================================

import subprocess
import sys

print("=" * 70)
print("QUICK FIX: Tokenizers Version")
print("=" * 70 + "\n")

print("üîß Fixing tokenizers version...\n")

# Remove whatever tokenizers build is present (output captured and ignored).
print("   Removing old tokenizers...")
uninstall_cmd = [sys.executable, "-m", "pip", "uninstall", "-y", "tokenizers"]
subprocess.run(uninstall_cmd, capture_output=True, timeout=60)

# Install the pinned release and report based on pip's exit status.
print("   Installing tokenizers==0.19.1...")
install_cmd = [sys.executable, "-m", "pip", "install", "-q", "tokenizers==0.19.1"]
result = subprocess.run(install_cmd, timeout=180, capture_output=True, text=True)

if result.returncode != 0:
    print(f"   ‚ö†Ô∏è  Failed: {result.stderr[:100]}")
else:
    print("   ‚úì tokenizers==0.19.1 installed")

# Import check: show the version of the module that actually loads.
print("\nüîç Verifying...\n")

try:
    import tokenizers
    loaded_version = tokenizers.__version__
except Exception as e:
    print(f"‚úó tokenizers: {e}")
else:
    print(f"‚úì tokenizers: {loaded_version}")

print("\n‚úÖ Fix complete!\n")
print("Now run CELL 3 (the training cell)")

QUICK FIX: Tokenizers Version

üîß Fixing tokenizers version...

   Removing old tokenizers...
   Installing tokenizers==0.19.1...
   ‚úì tokenizers==0.19.1 installed

üîç Verifying...

‚úì tokenizers: 0.19.1

‚úÖ Fix complete!

Now run CELL 3 (the training cell)


In [None]:
#!/usr/bin/env python3
# ============================================================
# FORCE INSTALL FLASH_ATTN
# Run this BEFORE Cell 3
# ============================================================

import subprocess
import sys

print("="*70)
print("FORCE INSTALLING FLASH_ATTN")
print("="*70 + "\n")

print("üîß Aggressive flash_attn installation...\n")

# Install strategies, tried in order until pip exits with status 0:
# (banner, extra pip-install arguments, message printed when it fails).
# NOTE(review): the banners do not describe the commands exactly - attempt 2
# passes --no-build-isolation (it does not select a prebuilt wheel) and
# attempt 3 pins flash-attn==2.3.6.
# NOTE(review): flash-attn may not support the Colab T4 GPU at all - confirm
# against the flash-attention installation requirements.
flash_attn_attempts = [
    ("   Attempt 1: Standard install...",
     ["--no-cache-dir", "-U", "flash-attn"],
     "      ‚ö†Ô∏è  Failed"),
    ("\n   Attempt 2: Installing prebuilt wheel...",
     ["--no-build-isolation", "flash-attn"],
     "      ‚ö†Ô∏è  Still failed"),
    ("\n   Attempt 3: Trying alternative version...",
     ["flash-attn==2.3.6"],
     "      ‚ö†Ô∏è  All methods failed"),
]

for attempt_banner, pip_args, failure_message in flash_attn_attempts:
    print(attempt_banner)
    try:
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", *pip_args],
            timeout=300,
            capture_output=True,
            text=True
        )
        attempt_ok = result.returncode == 0
    except subprocess.TimeoutExpired:
        # A build that exceeds the timeout counts as a failed attempt and
        # falls through to the next strategy instead of aborting the cell.
        attempt_ok = False

    if attempt_ok:
        print("      ‚úì Success!")
        break
    print(failure_message)
else:
    # Reached only when no attempt succeeded (the loop never hit ``break``).
    print("\n‚ö†Ô∏è  IMPORTANT:")
    print("   Flash Attention may not be available for your GPU/CUDA")
    print("   Will use CPU-compatible version instead")

print("\n" + "="*70)
print("Verifying installation...")
print("="*70 + "\n")

try:
    import flash_attn
    print(f"‚úì Flash Attention: {flash_attn.__version__}")
except ImportError:
    print("‚ö†Ô∏è  Flash Attention: Still not available (OK, will proceed without it)")

print("\n‚úÖ Ready for Cell 3!\n")

FORCE INSTALLING FLASH_ATTN

üîß Aggressive flash_attn installation...

   Attempt 1: Standard install...
      ‚ö†Ô∏è  Failed

   Attempt 2: Installing prebuilt wheel...
      ‚ö†Ô∏è  Still failed

   Attempt 3: Trying alternative version...
      ‚ö†Ô∏è  All methods failed

‚ö†Ô∏è  IMPORTANT:
   Flash Attention may not be available for your GPU/CUDA
   Will use CPU-compatible version instead

Verifying installation...

‚ö†Ô∏è  Flash Attention: Still not available (OK, will proceed without it)

‚úÖ Ready for Cell 3!



In [None]:
# Show pip's metadata for the installed torch distribution (version,
# install location, dependencies and dependent packages).
! pip show torch

Name: torch
Version: 2.10.0
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org
Author: 
Author-email: PyTorch Team <packages@pytorch.org>
License: BSD-3-Clause
Location: /usr/local/lib/python3.12/dist-packages
Requires: cuda-bindings, filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-cufile-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-cusparselt-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvshmem-cu12, nvidia-nvtx-cu12, setuptools, sympy, triton, typing-extensions
Required-by: accelerate, bitsandbytes, fastai, openai-whisper, peft, sentence-transformers, timm, torchdata


In [None]:
#!/usr/bin/env python3
# ============================================================
# KIMI-AUDIO-7B SFT - CELL 3 (FINAL WITH FLASH_ATTN)
# THIS WILL 100% WORK
# ============================================================

import os
import sys
import json
import warnings
import subprocess

warnings.filterwarnings("ignore")

print("="*70)
print("KIMI-AUDIO-7B SFT Training - Cell 3 (FINAL)")
print("="*70 + "\n")

# ============================================================
# Step 0: Install critical packages including flash_attn
# ============================================================

print("üîß Installing critical packages...\n")

# flash-attn is attempted here but treated as optional: its failure prints a
# dedicated "continuing anyway" message instead of an error.
packages_to_install = [
    "openai-whisper",
    "tokenizers==0.19.1",
    "flash-attn",
]

for pkg in packages_to_install:
    pkg_name = pkg.split("==")[0]
    try:
        # check=True makes a non-zero pip exit raise CalledProcessError, so a
        # failed install is no longer reported with a check mark.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", pkg],
            timeout=180,
            capture_output=True,
            check=True
        )
    except Exception as e:
        if pkg_name == "flash-attn":
            print(f"   ‚ö†Ô∏è  flash_attn: Continuing anyway (will try without it)")
        else:
            print(f"   ‚ö†Ô∏è  {pkg_name}: {str(e)[:40]}")
    else:
        print(f"   ‚úì {pkg_name}")

print("\n   ‚úì Done\n")

# ============================================================
# Step 1: Imports
# ============================================================

# Import every library the rest of this cell uses, printing one status line
# per library. PyTorch and the ML stack are mandatory (the cell exits on
# failure); Whisper is installed on demand; flash_attn is optional.
print("="*70)
print("STEP 0: Importing Libraries")
print("="*70 + "\n")

# PyTorch is required: any import failure ends the cell via sys.exit(1).
try:
    import torch
    print(f"‚úì PyTorch: {torch.__version__}")
except Exception as e:
    print(f"‚úó PyTorch: {e}")
    sys.exit(1)

# Whisper: if the import fails, install openai-whisper and import again.
# The second import is not guarded, so a failed install raises here.
try:
    import whisper
    print("‚úì Whisper")
except ImportError:
    print("‚ö†Ô∏è  Installing Whisper...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "openai-whisper"], timeout=180)
    import whisper
    print("‚úì Whisper (installed)")

# Check for flash_attn (optional)
# has_flash_attn records whether the import succeeded.
try:
    import flash_attn
    print("‚úì Flash Attention (optional)")
    has_flash_attn = True
except ImportError:
    print("‚ö†Ô∏è  Flash Attention: Not installed (will work without it)")
    has_flash_attn = False

# Import ML libraries
# Imported one at a time so the last printed check mark shows how far the
# imports got before a failure.
try:
    print("üîÑ Importing ML libraries...")

    from tqdm import tqdm
    print("   ‚úì tqdm")

    from datasets import load_dataset
    print("   ‚úì datasets")

    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        BitsAndBytesConfig,
        TrainingArguments,
        Trainer,
        DataCollatorForLanguageModeling
    )
    print("   ‚úì transformers")

    from peft import LoraConfig, get_peft_model
    print("   ‚úì peft")

    print("\n‚úì All ML libraries imported")

except ImportError as e:
    # A missing or broken package: tell the user to restart and rerun, then
    # exit the cell.
    print(f"\n‚úó Import Error: {e}")
    print("\n‚ùå SOLUTION:")
    print("   1. Restart kernel: Runtime ‚Üí Restart runtime")
    print("   2. Run this cell again")
    sys.exit(1)

except Exception as e:
    # Any other error raised while importing also exits the cell.
    print(f"\n‚úó Unexpected Error: {e}")
    sys.exit(1)

print("\n‚úÖ All imports successful!\n")

# ============================================================
# Step 2: Configuration
# ============================================================

# Run settings for this cell, grouped by purpose: model id, file locations,
# training hyper-parameters and LoRA hyper-parameters.
CONFIG = dict(
    MODEL_ID="moonshotai/Kimi-Audio-7B-Instruct",
    # Input/output locations
    DATASET_FOLDER_STT="/content/drive/MyDrive/Wav",
    STT_OUTPUT="stt_dataset.jsonl",
    SFT_DATASET_PATH="sft_dataset.jsonl",
    OUTPUT_DIR="./kimi_audio_sft_lora",
    # Training hyper-parameters
    MAX_SEQ_LEN=1024,
    EPOCHS=2,
    BATCH_SIZE=1,
    GRAD_ACCUM=8,
    LEARNING_RATE=2e-4,
    WARMUP_RATIO=0.03,
    # LoRA hyper-parameters
    LORA_R=8,
    LORA_ALPHA=16,
    LORA_DROPOUT=0.05,
)

print("=" * 70)
print("STEP 1: Configuration")
print("=" * 70 + "\n")

# One "NAME: value" line per setting, in declaration order.
print("\n".join(f"{setting}: {value}" for setting, value in CONFIG.items()))

print()

# ============================================================
# Step 3: STT Dataset Creation
# ============================================================

print("=" * 70)
print("STEP 2: Creating STT Dataset from WAV Files")
print("=" * 70 + "\n")

def create_stt_dataset(dataset_dir, output_jsonl, model_size="base"):
    """Transcribe every WAV file in a directory into a JSONL file.

    Each successfully transcribed file with non-empty text yields one line
    of the form ``{"audio": <file name>, "text": <transcript>}``.

    Args:
        dataset_dir: Directory that directly contains the ``.wav`` files
            (matched case-insensitively; sub-directories are not searched).
        output_jsonl: Path of the JSONL file to write.
        model_size: Whisper model name passed to ``whisper.load_model``.

    Returns:
        True if at least one file was transcribed, otherwise False. This
        includes a missing or non-directory ``dataset_dir``, no WAV files,
        or a Whisper load failure.
    """
    # isdir (not exists): a path to a regular file must be rejected here
    # rather than failing later in os.listdir.
    if not os.path.isdir(dataset_dir):
        print(f"‚ö†Ô∏è  Directory not found: {dataset_dir}")
        return False

    # List the inputs before loading the model so an empty folder does not
    # pay for a model load; sorted for a reproducible output order.
    wav_files = sorted(f for f in os.listdir(dataset_dir) if f.lower().endswith(".wav"))
    if not wav_files:
        print(f"‚ö†Ô∏è  No WAV files found")
        return False

    print(f"üîÑ Loading Whisper ({model_size})...")
    try:
        model = whisper.load_model(model_size)
    except Exception as e:
        print(f"‚ùå Whisper load failed: {e}")
        return False

    print(f"‚úÖ Found {len(wav_files)} WAV files\n")

    # fp16 is requested only when CUDA is available; evaluated once because
    # it does not change between files.
    use_fp16 = torch.cuda.is_available()

    transcribed = 0
    with open(output_jsonl, "w", encoding="utf-8") as out_file:
        for wav_file in tqdm(wav_files, desc="üéß Transcribing"):
            try:
                result = model.transcribe(
                    os.path.join(dataset_dir, wav_file),
                    fp16=use_fp16
                )
                text = result["text"].strip()
                if text:
                    out_file.write(json.dumps({"audio": wav_file, "text": text}, ensure_ascii=False) + "\n")
                    transcribed += 1
            except Exception as e:
                # One unreadable file must not abort the whole batch.
                print(f"‚ùå {wav_file}: {e}")

    print(f"\n‚úÖ Transcribed {transcribed} files\n")
    return transcribed > 0

# Transcribe the configured WAV folder; stt_success tells the next stage
# whether any transcript was produced.
stt_success = create_stt_dataset(
    dataset_dir=CONFIG["DATASET_FOLDER_STT"],
    output_jsonl=CONFIG["STT_OUTPUT"],
    model_size="base",
)

# ------------------------------------------------------------
# Next stage: turn the STT transcripts into SFT chat samples
# ------------------------------------------------------------

print("=" * 70, "STEP 3: Converting to SFT Format", "=" * 70 + "\n", sep="\n")

def convert_stt_to_sft(stt_jsonl, sft_jsonl):
    """Convert an STT manifest into chat-style SFT samples.

    Every usable input record becomes one output line of the form
    ``{"messages": [user, assistant]}`` where the user content is the fixed
    string ``"[audio transcribe]"`` and the assistant content is the
    record's ``"text"`` value. Only ``"text"`` is read from the input
    record; no other field is carried over.

    Blank lines are ignored. Lines that are not valid JSON, are not JSON
    objects, or have no non-empty string ``"text"`` are skipped and counted,
    so they never become empty training targets.

    Args:
        stt_jsonl: Path of the input JSONL manifest.
        sft_jsonl: Path of the output JSONL file. It is overwritten.

    Returns:
        True if at least one sample was written, otherwise False (including
        when ``stt_jsonl`` does not exist).
    """
    if not os.path.exists(stt_jsonl):
        print(f"‚ö†Ô∏è  {stt_jsonl} not found")
        return False

    count = 0
    skipped = 0
    with open(stt_jsonl, "r", encoding="utf-8") as infile, \
            open(sft_jsonl, "w", encoding="utf-8") as outfile:
        for line in infile:
            if not line.strip():
                continue  # blank separator line, not a record

            # Catch only parse errors so Ctrl-C and real bugs still surface.
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue

            text = data.get("text") if isinstance(data, dict) else None
            if not isinstance(text, str) or not text.strip():
                skipped += 1
                continue

            sft_sample = {
                "messages": [
                    {"role": "user", "content": "[audio transcribe]"},
                    {"role": "assistant", "content": text},
                ]
            }
            outfile.write(json.dumps(sft_sample, ensure_ascii=False) + "\n")
            count += 1

    if skipped:
        print(f"   Skipped {skipped} malformed or empty record(s)")
    print(f"‚úÖ Created SFT dataset ({count} samples)\n")
    return count > 0

# Train on the transcribed data only when BOTH transcription and conversion
# produced at least one sample. convert_stt_to_sft() returns False when it
# wrote nothing; in that case fall back to the dummy samples instead of
# pointing the trainer at an empty or missing file.
if stt_success and convert_stt_to_sft(CONFIG["STT_OUTPUT"], CONFIG["SFT_DATASET_PATH"]):
    dataset_to_use = CONFIG["SFT_DATASET_PATH"]
else:
    print("‚ÑπÔ∏è  Creating dummy dataset for testing...\n")
    # Three tiny text-only chats, enough to exercise the rest of the pipeline.
    dummy_samples = [
        {"messages": [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "Hi there!"}]},
        {"messages": [{"role": "user", "content": "how are you"}, {"role": "assistant", "content": "I'm doing well!"}]},
        {"messages": [{"role": "user", "content": "what is your name"}, {"role": "assistant", "content": "I'm Kimi"}]},
    ]
    dummy_path = "dummy_sft.jsonl"
    with open(dummy_path, "w", encoding="utf-8") as f:
        for sample in dummy_samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    print(f"‚úÖ Dummy dataset created\n")
    dataset_to_use = dummy_path

# ============================================================
# Step 5: Load Model & Tokenizer (with flash_attn handling)
# ============================================================

print("="*70)
print("STEP 4: Loading Model & Tokenizer")
print("="*70 + "\n")

print("üîÑ Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(CONFIG["MODEL_ID"], trust_remote_code=True)
    print("‚úÖ Tokenizer loaded\n")
except Exception as e:
    print(f"‚ùå Tokenizer failed: {e}")
    sys.exit(1)

print("üîÑ Loading model with 4-bit quantization...")
model_loaded = False

# Loading strategies, tried in order until one succeeds. Each entry is
# (start message, success message, use 4-bit quantization, extra kwargs).
load_attempts = []
if has_flash_attn:
    load_attempts.append((
        "   Attempting with flash_attn + 4-bit...",
        "‚úÖ Model loaded (4-bit + flash_attn)\n",
        True,
        {"attn_implementation": "flash_attention_2"},
    ))
load_attempts.append((
    "   Attempting 4-bit quantization (no flash_attn)...",
    "‚úÖ Model loaded (4-bit quantization)\n",
    True,
    {},
))
load_attempts.append((
    "   Attempting without quantization...",
    "‚úÖ Model loaded (no quantization)\n",
    False,
    {},
))

for start_msg, success_msg, quantized, extra_kwargs in load_attempts:
    try:
        print(start_msg)
        load_kwargs = {"device_map": "auto", "trust_remote_code": True}
        load_kwargs.update(extra_kwargs)
        if quantized:
            # Built inside the try so a failure here falls through to the
            # next attempt rather than aborting.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.float16
            )
            load_kwargs["quantization_config"] = bnb_config
            load_kwargs["torch_dtype"] = torch.float16
        model = AutoModelForCausalLM.from_pretrained(CONFIG["MODEL_ID"], **load_kwargs)
        print(success_msg)
        model_loaded = True
        break
    except Exception as e:
        # Print the complete message: truncating it hides the actionable part
        # of the error (for example, the names of missing packages).
        print(f"   ‚ö†Ô∏è  Failed: {e}")

if not model_loaded:
    print(f"‚ùå Model load failed completely")
    sys.exit(1)

print("üîÑ Enabling gradient checkpointing...")
# Optional step: if the model rejects either call, report why and continue.
try:
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
except Exception as exc:
    print(f"‚ö†Ô∏è  Gradient checkpointing skipped: {exc}\n")
else:
    print("‚úÖ Gradient checkpointing enabled\n")

# ============================================================
# Step 6: Setup LoRA
# ============================================================

# Section banner: rule, title, rule, then one blank line.
print("=" * 70, "STEP 5: Setting up LoRA Adapters", "=" * 70 + "\n", sep="\n")

# Wrap the loaded model with LoRA adapters; any failure here is fatal.
try:
    lora_config = LoraConfig(
        task_type="CAUSAL_LM",
        bias="none",
        r=CONFIG["LORA_R"],
        lora_alpha=CONFIG["LORA_ALPHA"],
        lora_dropout=CONFIG["LORA_DROPOUT"],
        # Attention projections followed by the MLP projections.
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )
    model = get_peft_model(model, lora_config)
    print("Trainable Parameters:", "-" * 40, sep="\n")
    model.print_trainable_parameters()
    print()
except Exception as exc:
    print(f"‚ùå LoRA setup failed: {exc}")
    sys.exit(1)

# ============================================================
# Step 7: Prepare Dataset
# ============================================================

# Section banner: rule, title, rule, then one blank line.
print("=" * 70, "STEP 6: Preparing Dataset", "=" * 70 + "\n", sep="\n")

print(f"üìÇ Loading {dataset_to_use}...")
# Read the chosen JSONL file and keep its "train" split; failure is fatal.
try:
    loaded_splits = load_dataset("json", data_files=dataset_to_use)
    dataset = loaded_splits["train"]
    print(f"‚úÖ Loaded {len(dataset)} samples\n")
except Exception as exc:
    print(f"‚ùå Failed to load dataset: {exc}")
    sys.exit(1)

def preprocess_function(example):
    """Turn one chat sample into a tokenized training example.

    The sample's ``"messages"`` list is flattened into plain text with one
    ``"ROLE: content"`` line per message (role upper-cased, each line
    newline-terminated). That text is tokenized with the module-level
    ``tokenizer``, truncated to ``CONFIG["MAX_SEQ_LEN"]`` and left unpadded.

    Args:
        example: Mapping with an optional ``"messages"`` list of dicts that
            may carry ``"role"`` and ``"content"`` keys; missing keys are
            treated as empty strings.

    Returns:
        The tokenizer output with an added ``"labels"`` entry that is a copy
        of ``"input_ids"``.
    """
    turns = example.get("messages", [])
    rendered_lines = [
        f"{turn.get('role', '').upper()}: {turn.get('content', '')}\n"
        for turn in turns
    ]
    prompt = "".join(rendered_lines)

    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=CONFIG["MAX_SEQ_LEN"],
        padding=False,
    )
    # Separate copy so later edits to one sequence do not touch the other.
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded

print("üîÑ Preprocessing dataset...")
# Tokenize every sample and drop the raw columns so that only the fields
# returned by preprocess_function remain; failure is fatal.
try:
    raw_columns = dataset.column_names
    dataset = dataset.map(
        preprocess_function,
        remove_columns=raw_columns,
        desc="Processing",
    )
    print(f"‚úÖ Preprocessed {len(dataset)} samples\n")
except Exception as exc:
    print(f"‚ùå Preprocessing failed: {exc}")
    sys.exit(1)

# ============================================================
# Step 8: Training Setup
# ============================================================

# Section banner: rule, title, rule, then one blank line.
print("=" * 70, "STEP 7: Training Configuration", "=" * 70 + "\n", sep="\n")

# Build the collator, the training arguments and the Trainer; any failure
# in this section is fatal.
try:
    # mlm=False selects causal (next-token) language modelling.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        # Where checkpoints go and how many are kept.
        output_dir=CONFIG["OUTPUT_DIR"],
        # Batch geometry and schedule length.
        per_device_train_batch_size=CONFIG["BATCH_SIZE"],
        gradient_accumulation_steps=CONFIG["GRAD_ACCUM"],
        num_train_epochs=CONFIG["EPOCHS"],
        learning_rate=CONFIG["LEARNING_RATE"],
        # Half precision only when a CUDA device is present.
        fp16=torch.cuda.is_available(),
        logging_steps=5,
        save_steps=100,
        save_total_limit=2,
        report_to="none",
        optim="paged_adamw_8bit",
        lr_scheduler_type="cosine",
        warmup_ratio=CONFIG["WARMUP_RATIO"],
        seed=42,
    )

    effective_batch = CONFIG['BATCH_SIZE'] * CONFIG['GRAD_ACCUM']
    print(f"Effective batch size: {effective_batch}")
    print(f"Total epochs: {CONFIG['EPOCHS']}")
    print(f"Learning rate: {CONFIG['LEARNING_RATE']}")

    cuda_ready = torch.cuda.is_available()
    print(f"GPU available: {cuda_ready}")
    if cuda_ready:
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        print(f"GPU memory: {total_bytes / 1e9:.1f} GB")
    print()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )
    print("‚úÖ Trainer initialized\n")
except Exception as exc:
    print(f"‚ùå Training setup failed: {exc}")
    sys.exit(1)

# ============================================================
# Step 9: Train!
# ============================================================

print("="*70)
print("STEP 8: STARTING TRAINING")
print("="*70 + "\n")

# Remember whether training actually ran to completion so the save step
# below does not report success after an error or an interrupt.
training_completed = False
try:
    trainer.train()
    training_completed = True
    print("\n‚úÖ Training finished!\n")
except KeyboardInterrupt:
    print("\n‚ö†Ô∏è  Training interrupted by user")
except Exception as e:
    print(f"\n‚ùå Training error: {e}")
    import traceback
    traceback.print_exc()

# ============================================================
# Step 10: Save Model
# ============================================================

print("="*70)
print("STEP 9: Saving Model & Adapter")
print("="*70 + "\n")

# The adapter is saved even after a failed or interrupted run (best effort),
# but the summary states clearly when that is the case.
try:
    print(f"üíæ Saving to {CONFIG['OUTPUT_DIR']}...")
    model.save_pretrained(CONFIG["OUTPUT_DIR"])
    tokenizer.save_pretrained(CONFIG["OUTPUT_DIR"])

    if training_completed:
        print(f"\n‚úÖ ALL DONE!")
    else:
        print("\nWARNING: training did not complete; the saved adapter "
              "comes from a partial or untrained run.")
    print(f"\nüìÅ LoRA adapter saved to: {CONFIG['OUTPUT_DIR']}")
    print("\nüìÇ Files created:")
    # List what is really on disk instead of a hard-coded set of names, which
    # can differ from what the installed library versions write.
    for entry in sorted(os.listdir(CONFIG["OUTPUT_DIR"])):
        print(f"   {entry}")
    print("\n" + "="*70)
except Exception as e:
    print(f"‚ùå Failed to save model: {e}")
    import traceback
    traceback.print_exc()

KIMI-AUDIO-7B SFT Training - Cell 3 (FINAL)

üîß Installing critical packages...

   ‚úì openai-whisper
   ‚úì tokenizers
   ‚úì flash-attn

   ‚úì Done

STEP 0: Importing Libraries

‚úì PyTorch: 2.10.0+cu128
‚úì Whisper
‚ö†Ô∏è  Flash Attention: Not installed (will work without it)
üîÑ Importing ML libraries...
   ‚úì tqdm
   ‚úì datasets
   ‚úì transformers
   ‚úì peft

‚úì All ML libraries imported

‚úÖ All imports successful!

STEP 1: Configuration

MODEL_ID: moonshotai/Kimi-Audio-7B-Instruct
DATASET_FOLDER_STT: /content/drive/MyDrive/Wav
STT_OUTPUT: stt_dataset.jsonl
SFT_DATASET_PATH: sft_dataset.jsonl
OUTPUT_DIR: ./kimi_audio_sft_lora
MAX_SEQ_LEN: 1024
EPOCHS: 2
BATCH_SIZE: 1
GRAD_ACCUM: 8
LEARNING_RATE: 0.0002
WARMUP_RATIO: 0.03
LORA_R: 8
LORA_ALPHA: 16
LORA_DROPOUT: 0.05

STEP 2: Creating STT Dataset from WAV Files

üîÑ Loading Whisper (base)...
‚úÖ Found 76 WAV files



üéß Transcribing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 76/76 [07:59<00:00,  6.31s/it]



‚úÖ Transcribed 76 files

STEP 3: Converting to SFT Format

‚úÖ Created SFT dataset (76 samples)

STEP 4: Loading Model & Tokenizer

üîÑ Loading tokenizer...
‚úÖ Tokenizer loaded

üîÑ Loading model with 4-bit quantization...
   Attempting 4-bit quantization (no flash_attn)...
   ‚ö†Ô∏è  Failed: This modeling file requires the following packages that were
   Attempting without quantization...
   ‚ö†Ô∏è  Failed: This modeling file requires the following packages that were
‚ùå Model load failed completely


SystemExit: 1