# 🌧️ Timer-XL Peru Rainfall Prediction - Google Colab

This notebook demonstrates the complete pipeline for training Timer-XL on Peru rainfall data.

**Steps:**
1. Set up the environment
2. Verify the ERA5 data
3. Preprocess the data
4. Train Timer-XL with transfer learning
5. Save the checkpoint to Google Drive
6. Evaluate the results

## 1. Setup Environment

In [None]:
# Check GPU
!nvidia-smi

In [None]:
# Clone repository
!git clone https://github.com/ChristianPE1/test-openltm-code.git
%cd test-openltm-code

In [None]:
# Install dependencies
!pip install -r requirements.txt

In [None]:
# Mount Google Drive (to download checkpoint.pth and save training results)
from google.colab import drive
drive.mount('/content/drive')

print("✅ Google Drive mounted")

## 2. Verify ERA5 Data

**The .nc files are already in the repository** (datasets/raw_era5/)  
You only need to verify that they were cloned correctly.

In [None]:
# Verify ERA5 files are in the repository
!ls -lh datasets/raw_era5/
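
For a programmatic check, here is a minimal stdlib sketch. The helper `check_era5_files` is hypothetical, and it assumes that each file name contains its 4-digit year (e.g. `era5_2022.nc`). Adjust `expected_years` and the matching rule to the real file names.

```python
from pathlib import Path

def check_era5_files(raw_dir="datasets/raw_era5", expected_years=(2022, 2023, 2024)):
    """Return (found_files, missing_years) for the .nc files in raw_dir.

    Hypothetical naming assumption: every file name contains its 4-digit year.
    """
    found = sorted(Path(raw_dir).glob("*.nc"))
    missing = [year for year in expected_years
               if not any(str(year) in f.name for f in found)]
    return found, missing

found, missing = check_era5_files()
for f in found:
    # If the repo stores the data with Git LFS, a file of only a few hundred
    # bytes is probably an LFS pointer rather than the actual NetCDF data.
    print(f"{f.name}: {f.stat().st_size / 1e6:.1f} MB")
print("✅ All years present" if not missing else f"⚠️ Missing years: {missing}")
```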


## 3. Preprocess Data

In [None]:
# Run preprocessing script
# ⚠️ IMPORTANT: ERA5 precipitation is in METERS, not millimeters!
# Use threshold in METERS: 0.1 mm = 0.0001 m

!python preprocessing/preprocess_era5_peru.py \
    --input_dir datasets/raw_era5 \
    --output_dir datasets/processed \
    --years 2022,2023,2024 \
    --target_horizon 24 \
    --threshold 0.0001

print("\n✅ Preprocessing complete!")
print("📊 Output files saved to: datasets/processed/")
print("💡 Threshold: 0.0001 m = 0.1 mm")
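
The `--target_horizon 24` and `--threshold` flags suggest that each row is labelled by whether precipitation `target_horizon` steps ahead reaches the threshold. The actual rule lives in `preprocessing/preprocess_era5_peru.py`. The pandas sketch below shows only one plausible definition: the value exactly `horizon` rows later, compared with `>=`. The helper name `make_rain_labels` is hypothetical.

```python
import pandas as pd

def make_rain_labels(precip_m: pd.Series, horizon: int = 24,
                     threshold_m: float = 0.0001) -> pd.Series:
    """Hypothetical labelling rule: 1 if precipitation `horizon` rows ahead >= threshold.

    Precipitation and threshold are both in METERS (ERA5 units): 0.1 mm = 0.0001 m.
    """
    future = precip_m.shift(-horizon)                 # value `horizon` rows later
    labels = (future >= threshold_m).astype("Int64")  # nullable integer labels
    # The last `horizon` rows have no future value, so their label is unknown (<NA>)
    return labels.where(future.notna())

precip = pd.Series([0.0, 0.0002, 0.0, 0.00005, 0.0003])
print(make_rain_labels(precip, horizon=2).tolist())
```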

In [None]:
# Load processed data for quick inspection
import pandas as pd
import json

df = pd.read_csv('datasets/processed/peru_rainfall.csv')
print(f"Dataset shape: {df.shape}")
print(f"\nFirst few rows:")
print(df.head())

# Load statistics
with open('datasets/processed/preprocessing_stats.json') as f:
    stats = json.load(f)
print(f"\nStatistics:")
print(json.dumps(stats, indent=2))

## 🚨 CRITICAL: Verify Class Balance

**Before training, we MUST check that both classes exist!**

In [None]:
# CRITICAL: Check class distribution
import pandas as pd
import numpy as np

df = pd.read_csv('datasets/processed/peru_rainfall.csv')

print("📊 Class Distribution Analysis:")
print(f"   Total samples: {len(df)}")
print(f"   rain_24h column:")
print(df['rain_24h'].value_counts())
print(f"\n   Percentage:")
print(df['rain_24h'].value_counts(normalize=True) * 100)

# Check precipitation values
print(f"\n🌧️ Precipitation Statistics (in METERS from ERA5):")
print(f"   Min: {df['precipitation'].min():.6f} m = {df['precipitation'].min()*1000:.3f} mm")
print(f"   Max: {df['precipitation'].max():.6f} m = {df['precipitation'].max()*1000:.3f} mm")
print(f"   Mean: {df['precipitation'].mean():.6f} m = {df['precipitation'].mean()*1000:.3f} mm")
print(f"   Median: {df['precipitation'].median():.6f} m = {df['precipitation'].median()*1000:.3f} mm")
print(f"   95th percentile: {df['precipitation'].quantile(0.95):.6f} m = {df['precipitation'].quantile(0.95)*1000:.3f} mm")

# ⚠️ IMPORTANT: ERA5 precipitation is in METERS, not millimeters!
# Threshold must be in meters too
threshold_mm = 0.1  # Target in mm
threshold_m = threshold_mm / 1000.0  # Convert to meters

samples_above_threshold = (df['precipitation'] >= threshold_m).sum()
print(f"\n⚠️  Samples with precipitation >= {threshold_mm} mm ({threshold_m:.6f} m): {samples_above_threshold} ({samples_above_threshold/len(df)*100:.2f}%)")

if samples_above_threshold < len(df) * 0.1:
    print(f"\n⚠️ Class imbalance detected!")
    print(f"   Only {samples_above_threshold} rain events ({samples_above_threshold/len(df)*100:.1f}%)")
    print(f"\n💡 SOLUTION:")
    # Calculate threshold for 30-35% rain events
    suggested_threshold_m = df['precipitation'].quantile(0.65)
    suggested_threshold_mm = suggested_threshold_m * 1000
    print(f"   Suggested threshold for 35% rain events:")
    print(f"   - In meters: {suggested_threshold_m:.6f} m")
    print(f"   - In mm: {suggested_threshold_mm:.4f} mm")
else:
    print(f"\n✅ Good class balance: {samples_above_threshold/len(df)*100:.1f}% rain events")
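
The suggested threshold relies on a property of quantiles: roughly a fraction 1 - q of the samples lie at or above the q-quantile, so `quantile(0.65)` targets about 35% rain. The NumPy sketch below checks the rain fraction that a candidate threshold would produce (both helper names are hypothetical). It also shows a pitfall with hourly precipitation: if more than 65% of hours are dry, the 0.65 quantile is 0, and a `>=` rule then labels every hour as rain.

```python
import numpy as np

def threshold_for_fraction(precip_m, target_fraction):
    """Quantile threshold: about `target_fraction` of samples lie at or above it."""
    return float(np.quantile(np.asarray(precip_m, dtype=float), 1.0 - target_fraction))

def rain_fraction(precip_m, threshold_m):
    """Fraction of samples labelled rain under a `>=` rule."""
    return float(np.mean(np.asarray(precip_m, dtype=float) >= threshold_m))

# Continuous values: the 0.65 quantile leaves ~35% of samples at or above it
smooth = np.arange(1, 101) * 1e-5
t = threshold_for_fraction(smooth, 0.35)
print(f"threshold={t:.7f} m -> rain fraction {rain_fraction(smooth, t):.2f}")

# Zero-inflated values (80% dry hours): the quantile collapses to 0 m
dry_heavy = np.concatenate([np.zeros(80), np.full(20, 0.001)])
t0 = threshold_for_fraction(dry_heavy, 0.35)
print(f"threshold={t0} m -> rain fraction {rain_fraction(dry_heavy, t0):.2f}")
```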

## 🔧 Optional: Re-preprocess with Adjusted Threshold

**Run this ONLY if the class distribution check above shows imbalanced data (< 10% rain events)**

In [None]:
# Re-run preprocessing with adjusted threshold for better class balance
# This creates a more balanced dataset by adjusting the rain threshold

# Calculate appropriate threshold (aiming for ~30-40% rain events)
df_temp = pd.read_csv('datasets/processed/peru_rainfall.csv')

# ERA5 precipitation is in METERS
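suggested_threshold_m = df_temp['precipitation'].quantile(0.65)  # 35% will be "rain"
if suggested_threshold_m <= 0:
    # More than 65% of hours are dry, so the quantile is 0 and a ">= 0" rule would mark every hour as rain.
    # Fall back to the smallest positive value (the rain share will then be below 35%).
    suggested_threshold_m = df_temp.loc[df_temp['precipitation'] > 0, 'precipitation'].min()
suggested_threshold_mm = suggested_threshold_m * 1000  # Convert to mm for display

print(f"🎯 Suggested threshold:")
print(f"   {suggested_threshold_m:.10f} m = {suggested_threshold_mm:.4f} mm")
print(f"   This should give ~35% rain events (fewer if the fallback above was used)\n")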

# Re-run preprocessing with threshold in METERS
!python preprocessing/preprocess_era5_peru.py \
    --input_dir datasets/raw_era5 \
    --output_dir datasets/processed \
    --years 2022,2023,2024 \
    --target_horizon 24 \
    --threshold {suggested_threshold_m:.10f}

print(f"\n✅ Data re-processed with adjusted threshold!")
print(f"💡 Used: {suggested_threshold_m:.10f} m ({suggested_threshold_mm:.4f} mm)")
print("📊 Re-run the class distribution check above to confirm the new balance.")
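
IPython expands `{...}` in a `!` line using Python format specs, so the script receives the threshold as a rounded string. With six decimals, any threshold below roughly 0.0000005 m (0.0005 mm) becomes `0.000000`, which silently changes the labels. The stdlib sketch below (the helper `format_threshold` is hypothetical) formats the value with more decimals and refuses to round a positive threshold down to zero.

```python
def format_threshold(threshold_m: float, decimals: int = 10) -> str:
    """Format a threshold in meters for the CLI without rounding it to zero."""
    text = f"{threshold_m:.{decimals}f}"
    if threshold_m > 0 and float(text) == 0.0:
        raise ValueError(f"{threshold_m!r} m rounds to {text}; use more decimals")
    return text

print(format_threshold(4e-7))            # ten decimals keep the value
try:
    format_threshold(4e-7, decimals=6)   # six decimals would give 0.000000
except ValueError as err:
    print("⚠️", err)
```

Because IPython evaluates the expression inside the braces, the `!` cell could then pass `--threshold {format_threshold(suggested_threshold_m)}`.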

## 4. Train Timer-XL
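
Timer-XL splits the input series into non-overlapping tokens of `input_token_len` steps. This is our reading of the OpenLTM design, so check the model file in this repository. Under that reading, `seq_len` should be a multiple of `input_token_len`. The quick arithmetic check below uses the settings of the training cell and is also useful when trying other context lengths. The hourly-resolution default is an assumption; confirm it against the processed CSV.

```python
def context_summary(seq_len: int = 1440, input_token_len: int = 96,
                    hours_per_step: float = 1.0):
    """Number of tokens and days covered by the context window (assumes hourly steps by default)."""
    if seq_len % input_token_len != 0:
        raise ValueError(f"seq_len={seq_len} is not a multiple of input_token_len={input_token_len}")
    n_tokens = seq_len // input_token_len
    days = seq_len * hours_per_step / 24
    return n_tokens, days

n_tokens, days = context_summary()
print(f"{n_tokens} tokens covering {days:.0f} days of context")
```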

In [None]:
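# Copy the pre-trained Timer-XL checkpoint from Google Drive
import os
import shutil

drive_checkpoint = '/content/drive/MyDrive/timer_xl_peru/checkpoints/checkpoint.pth'
checkpoint_dir = 'checkpoints/timer_xl'
checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')  # passed to --pretrain_model_path below

os.makedirs(checkpoint_dir, exist_ok=True)
shutil.copy(drive_checkpoint, checkpoint_path)  # raises FileNotFoundError if the Drive file is missing
print(f"✅ Checkpoint copied to {checkpoint_path} ({os.path.getsize(checkpoint_path) / 1e6:.1f} MB)")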
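
Before starting a multi-hour run, you can confirm that the local copy is complete by comparing SHA-256 digests of the Drive file and the local copy. Here is a stdlib sketch. The helper name `sha256_of` is hypothetical, and the paths match the copy cell.

```python
import hashlib
import os

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """SHA-256 of a file, read in 1 MiB chunks to keep memory use low."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

src = "/content/drive/MyDrive/timer_xl_peru/checkpoints/checkpoint.pth"
dst = "checkpoints/timer_xl/checkpoint.pth"
if os.path.exists(src) and os.path.exists(dst):
    same = sha256_of(src) == sha256_of(dst)
    print("✅ Checkpoint copy verified" if same else "⚠️ Checkpoint copy differs from Drive")
```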

In [None]:
# Set PyTorch memory configuration to reduce fragmentation
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Train Timer-XL with transfer learning
# This will take 4-6 hours on T4 GPU

# NOTE: If you get OOM (Out of Memory) error, reduce batch_size:
# --batch_size 16  (for T4 with 16GB VRAM - current setting)
# --batch_size 8   (if still OOM)

!python run.py \
  --task_name classification \
  --is_training 1 \
  --model_id peru_rainfall_timerxl \
  --model timer_xl_classifier \
  --data PeruRainfall \
  --root_path datasets/processed/ \
  --data_path peru_rainfall.csv \
  --checkpoints checkpoints/ \
  --seq_len 1440 \
  --input_token_len 96 \
  --output_token_len 96 \
  --test_seq_len 1440 \
  --test_pred_len 2 \
  --e_layers 8 \
  --d_model 1024 \
  --d_ff 2048 \
  --n_heads 8 \
  --dropout 0.1 \
  --activation relu \
  --batch_size 16 \
  --learning_rate 1e-5 \
  --train_epochs 50 \
  --patience 10 \
  --n_classes 2 \
  --gpu 0 \
  --cosine \
  --tmax 50 \
  --adaptation \
  --pretrain_model_path checkpoints/timer_xl/checkpoint.pth \
  --use_focal_loss \
  --loss CE \
  --itr 1 \
  --des 'Peru_Rainfall_Transfer_Learning'

print("\n✅ Training complete!")
print("📊 Results saved under checkpoints/ (look for a folder whose name contains 'peru_rainfall_timerxl')")
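
`--use_focal_loss` addresses the class imbalance checked earlier. This notebook does not show how `run.py` implements it or how it interacts with `--loss CE`. For illustration only, the NumPy sketch below implements the standard binary focal loss of Lin et al. (2017), FL = -alpha_t (1 - p_t)^gamma log(p_t). With `gamma=0` and no `alpha`, it reduces to ordinary cross-entropy.

```python
import numpy as np

def binary_focal_loss(p_rain, y, gamma=2.0, alpha=None, eps=1e-7):
    """Mean binary focal loss.

    p_rain: predicted probability of the "rain" class (1); y: labels in {0, 1}.
    gamma=0 and alpha=None reduce it to binary cross-entropy.
    """
    p = np.clip(np.asarray(p_rain, dtype=float), eps, 1.0 - eps)
    y = np.asarray(y, dtype=float)
    p_t = np.where(y == 1, p, 1.0 - p)              # probability given to the true class
    loss = -((1.0 - p_t) ** gamma) * np.log(p_t)    # down-weights easy, confident examples
    if alpha is not None:
        loss *= np.where(y == 1, alpha, 1.0 - alpha)  # optional per-class weighting
    return float(loss.mean())

p = np.array([0.9, 0.2, 0.6, 0.05])
y = np.array([1, 0, 1, 1])
print("CE   :", binary_focal_loss(p, y, gamma=0.0))
print("focal:", binary_focal_loss(p, y, gamma=2.0))
```

The `(1 - p_t)^gamma` factor shrinks the loss on examples the model already gets right. Training therefore concentrates on the hard, often minority-class, examples.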

## 5. Save Checkpoint to Drive

Colab's local disk is wiped when the runtime disconnects, so copy the trained model to Drive before the session ends.

In [None]:
# Copy training results to Google Drive (Colab's local disk is wiped when the runtime disconnects)
import glob
import os
import shutil

# Find run folders whose name contains the model_id, at any depth under checkpoints/
checkpoint_base = 'checkpoints'
results_pattern = os.path.join(checkpoint_base, '**', '*peru_rainfall_timerxl*/')
matching_dirs = glob.glob(results_pattern, recursive=True)

if matching_dirs:
    # If several runs exist, keep the most recently modified one
    results_path = max(matching_dirs, key=os.path.getmtime).rstrip('/')

    drive_results = '/content/drive/MyDrive/timer_xl_peru/results/'
    destination = os.path.join(drive_results, os.path.basename(results_path))
    os.makedirs(drive_results, exist_ok=True)

    print("💾 Copying results to Google Drive...")
    print(f"   From: {results_path}")
    print(f"   To: {destination}")

    try:
        # dirs_exist_ok=True (Python 3.8+) lets this cell be re-run safely
        shutil.copytree(results_path, destination, dirs_exist_ok=True)
        print("✅ Checkpoint and results saved to Google Drive!")
        print(f"📁 Location: {destination}")
    except Exception as e:
        print(f"⚠️ Error copying to Drive: {e}")
        print("   You can manually copy from:", results_path)
else:
    print("⚠️ No results found. Training may have failed or is still in progress.")
    print("   Expected pattern:", results_pattern)

## 6. Quick Evaluation

In [None]:
# Load and display test results
import glob
import json
import os

# Search for run folders whose name contains the model_id, at any depth under checkpoints/
checkpoint_base = 'checkpoints'
results_pattern = os.path.join(checkpoint_base, '**', '*peru_rainfall_timerxl*/')
matching_dirs = glob.glob(results_pattern, recursive=True)

if matching_dirs:
    # If several runs exist, use the most recently modified one
    results_dir = max(matching_dirs, key=os.path.getmtime)
    print(f"📂 Results directory: {results_dir}\n")

    # List all files
    print("📄 Files in results:")
    for file in sorted(os.listdir(results_dir)):
        print(f"   - {file}")

    # Try to load metrics (any JSON file with "metrics" in its name)
    metrics_files = glob.glob(os.path.join(results_dir, '*metrics*.json'))

    if metrics_files:
        print(f"\n📊 Loading metrics from: {metrics_files[0]}")
        with open(metrics_files[0]) as f:
            metrics = json.load(f)
        print("\n✅ Test Metrics:")
        print(json.dumps(metrics, indent=2))
    else:
        print("\n⚠️ No metrics file found yet. Training may still be in progress.")
else:
    print("⚠️ No results directory found. Training may have failed.")
    print(f"   Expected pattern: {results_pattern}")

## 🎉 Training Complete!

**Next steps:**
1. Download results from Google Drive (`MyDrive/timer_xl_peru/results/`)
2. Analyze confusion matrix and classification report
3. Try different context lengths (seq_len)
4. Experiment with different hyperparameters
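
For step 2, suppose you have the test-set labels and predictions as sequences. This notebook does not cover how to export them from `run.py`. You can then compute the confusion matrix and the rain-class precision, recall and F1 with a few lines of stdlib Python. The helper `binary_report` is hypothetical, and rain is treated as the positive class (1).

```python
from collections import Counter

def binary_report(y_true, y_pred, positive=1):
    """Confusion-matrix counts plus precision, recall and F1 for the positive class."""
    counts = Counter(zip(y_true, y_pred))
    tp = counts[(positive, positive)]
    fp = sum(n for (t, p), n in counts.items() if t != positive and p == positive)
    fn = sum(n for (t, p), n in counts.items() if t == positive and p != positive)
    tn = len(y_true) - tp - fp - fn  # remaining pairs (binary labels assumed)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"tp": tp, "fp": fp, "fn": fn, "tn": tn,
            "precision": precision, "recall": recall, "f1": f1}

print(binary_report([1, 0, 1, 1, 0, 0], [1, 0, 0, 1, 1, 0]))
```

With imbalanced classes, accuracy alone can look good while the rain class is missed. The rain-class recall and F1 are therefore the more informative numbers here.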