# Complete Guide to train_network

This notebook provides comprehensive documentation for the `train_network` function, including the new robustness features that protect against server disconnects.

## What's New? 🎉

Three new parameters make your training robust against CoCalc disconnects:
- **`total_epochs`** - Train to a specific epoch number (not additional epochs)
- **`save_last`** - Save checkpoint after every epoch
- **`resume_last`** - Automatically resume from where you left off

**The Big Idea:** If your server disconnects, just restart the kernel and re-run all cells. Training continues automatically!

## Quick Start: Making Training Robust

### Before (Not Robust):
```python
results = train_network(
    model, loss_func, train_loader, test_loader=test_loader,
    epochs=50,
    checkpoint_file='model.pt',
    device=device
)
```
❌ If interrupted at epoch 30, re-running the cell starts over from epoch 1!

### After (Robust):
```python
results = train_network(
    model, loss_func, train_loader, test_loader=test_loader,
    total_epochs=50,      # ← Train TO epoch 50 (not 50 more)
    save_last=True,       # ← Save after every epoch
    resume_last=True,     # ← Auto-resume from last save
    checkpoint_file='model.pt',
    device=device
)
```
✅ If interrupted at epoch 30, automatically resumes and continues to epoch 50!

## The Three New Parameters Explained

### 1. `total_epochs` (int, optional)
**Purpose:** Sets the target epoch number to reach (absolute), not how many epochs to run (relative)

- `total_epochs=50` means "train until you reach epoch 50"
- If starting fresh: trains 50 epochs
- If resuming from epoch 30: trains 20 more epochs to reach 50
- If already at epoch 50+: returns immediately with message

**When to use:** Always use this for homework to ensure consistent training

### 2. `save_last` (bool, default=False)
**Purpose:** Saves a checkpoint after every epoch (not just the best model)

- Creates `model_last.pt` alongside `model.pt`
- `model.pt` = best model so far
- `model_last.pt` = most recent epoch

**When to use:** Set to `True` when you want to resume from interruptions
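
The naming rule above (`model.pt` → `model_last.pt`) can be sketched with `pathlib`. The helper name `last_checkpoint_path` is hypothetical and only illustrates the convention described here; it is not part of `train_network`.

```python
from pathlib import Path


def last_checkpoint_path(checkpoint_file) -> Path:
    """Hypothetical helper: map 'model.pt' to 'model_last.pt' in the same folder."""
    path = Path(checkpoint_file)
    # Insert '_last' between the file stem and its extension.
    return path.with_name(f"{path.stem}_last{path.suffix}")


print(last_checkpoint_path("models/attempt2.pt").name)  # → attempt2_last.pt
```

This is handy for checking by hand which file `resume_last=True` would look for with a given `checkpoint_file`.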

### 3. `resume_last` (bool, default=False)
**Purpose:** Automatically loads the most recent checkpoint when starting

- Looks for `model_last.pt` and loads it if found
- Shows "Resuming from epoch X" message
- Progress bar shows correct position (e.g., 30/50)
- **If checkpoint doesn't exist:** Starts training from scratch (no error)

**When to use:** Set to `True` when you want automatic recovery

**Why this design?** Setting `resume_last=True` allows the same code to work for:
- First training run (starts fresh when no checkpoint exists)
- Resuming after disconnect (continues from checkpoint)
- No need to change parameters between initial run and resumption!
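
To see why the same call works for both the first run and a resumed run, here is a toy sketch of the idea. It assumes nothing about the internals of `train_network`: the function `run_training` is hypothetical, and a small JSON file stands in for the real `model_last.pt` checkpoint.

```python
import json
import tempfile
from pathlib import Path


def run_training(last_file: Path, total_epochs: int, crash_after=None) -> list[int]:
    """Toy stand-in for a resumable loop; returns the epochs run in this call."""
    # Resume if the "last" checkpoint exists, otherwise start from epoch 0.
    completed = json.loads(last_file.read_text())["epoch"] if last_file.exists() else 0
    ran = []
    for epoch in range(completed + 1, total_epochs + 1):
        if crash_after is not None and len(ran) == crash_after:
            break  # simulate a server disconnect
        # ... a real loop would train for one epoch here ...
        last_file.write_text(json.dumps({"epoch": epoch}))  # save after every epoch
        ran.append(epoch)
    return ran


with tempfile.TemporaryDirectory() as tmp:
    last = Path(tmp) / "model_last.json"
    first_run = run_training(last, total_epochs=5, crash_after=3)  # interrupted
    second_run = run_training(last, total_epochs=5)                # same call, resumes
    third_run = run_training(last, total_epochs=5)                 # nothing left to do

print(first_run, second_run, third_run)  # → [1, 2, 3] [4, 5] []
```

The second and third calls use identical arguments, which mirrors the "restart the kernel and re-run all cells" workflow.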

## Common Usage Patterns

### Pattern 1: Standard Homework Training (Recommended)
```python
# Robust training for homework
results = train_network(
    model=model,
    loss_func=nn.CrossEntropyLoss(),
    train_loader=train_loader,
    val_loader=val_loader,
    score_funcs={'ACC': accuracy_score},

    # Robustness parameters
    total_epochs=50,      # Train to epoch 50
    save_last=True,       # Save every epoch
    resume_last=True,     # Auto-resume

    # Standard parameters
    checkpoint_file=MODELS_PATH / 'homework_model.pt',
    device=device
)
```

### Pattern 2: Model Iteration - Tweak and Retrain (Best Practice)
```python
# BEST PRACTICE: Use different checkpoint names for each attempt
# You can ALWAYS keep resume_last=True - it won't load old checkpoints!

# First attempt
model_v1 = SimpleNet()  # train_network moves to device automatically
results_v1 = train_network(
    model_v1, loss_func, train_loader, val_loader,
    total_epochs=30,
    save_last=True,
    resume_last=True,     # ← Keep this True! No problem!
    checkpoint_file=MODELS_PATH / 'attempt1.pt',  # ← Unique name
    score_funcs={'ACC': accuracy_score},
    device=device
)
print(f"Attempt 1: {results_v1['val ACC'].max():.4f}")

# Second attempt with improved model
model_v2 = ImprovedNet()  # No need for .to(device)
results_v2 = train_network(
    model_v2, loss_func, train_loader, val_loader,
    total_epochs=30,
    save_last=True,
    resume_last=True,     # ← Still True! attempt2_last.pt doesn't exist yet
    checkpoint_file=MODELS_PATH / 'attempt2.pt',  # ← Different name = fresh start
    score_funcs={'ACC': accuracy_score},
    device=device
)
print(f"Attempt 2: {results_v2['val ACC'].max():.4f}")

# Third attempt with different hyperparameters
model_v3 = ImprovedNet()
optimizer_v3 = torch.optim.AdamW(model_v3.parameters(), lr=0.0001)
results_v3 = train_network(
    model_v3, loss_func, train_loader, val_loader,
    optimizer=optimizer_v3,
    total_epochs=50,
    save_last=True,
    resume_last=True,     # ← Always True for robustness!
    checkpoint_file=MODELS_PATH / 'attempt3.pt',  # ← New name again
    score_funcs={'ACC': accuracy_score},
    device=device
)
print(f"Attempt 3: {results_v3['val ACC'].max():.4f}")
```

**✅ Best Practice Summary:**
- Use numbered checkpoint names: `attempt1.pt`, `attempt2.pt`, `attempt3.pt`
- Keep `resume_last=True` always - it only resumes if `attemptX_last.pt` exists
- Each new checkpoint name automatically means fresh training
- No need to toggle `resume_last` or delete files between attempts!

### Pattern 3: Reusing Same Checkpoint Name (Alternative)
```python
# If you prefer to reuse the same checkpoint name, delete old files first
import os

# Clean up previous attempt
if os.path.exists(MODELS_PATH / 'model.pt'):
    os.remove(MODELS_PATH / 'model.pt')
if os.path.exists(MODELS_PATH / 'model_last.pt'):
    os.remove(MODELS_PATH / 'model_last.pt')

# Now train fresh with the same name
model = NewModel()  # train_network handles device placement
results = train_network(
    model, loss_func, train_loader, val_loader,
    total_epochs=50,
    save_last=True,
    resume_last=True,  # Safe because we deleted the old files
    checkpoint_file=MODELS_PATH / 'model.pt',
    device=device
)
```

### Pattern 4: Quick Experiments
```python
# Traditional approach for quick tests
results = train_network(
    model=model,
    loss_func=loss_func,
    train_loader=train_loader,
    test_loader=test_loader,

    epochs=10,  # Just run 10 epochs
    device=device
)
```

### Pattern 5: Fine-tuning
```python
# Resume from best model and continue training
results = train_network(
    model=model,
    loss_func=loss_func,
    train_loader=train_loader,
    val_loader=val_loader,

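    resume_checkpoint=True,  # Start from best model (model.pt)
    total_epochs=100,        # Train to epoch 100
    save_last=True,
    # resume_last is left at its default (False): if both resume flags are True,
    # train_network shows a warning and resume_last takes priority (see Troubleshooting)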
    checkpoint_file=MODELS_PATH / 'model.pt',
    device=device
)
```

### Pattern 6: With Early Stopping (Optional)
```python
# Training with early stopping to prevent overfitting
results = train_network(
    model=model,
    loss_func=nn.CrossEntropyLoss(),
    train_loader=train_loader,
    val_loader=val_loader,
    score_funcs={'ACC': accuracy_score},

    # Robustness parameters
    total_epochs=100,     # Maximum epochs
    save_last=True,
    resume_last=True,

    # Early stopping parameters
    early_stop_metric='val ACC',  # Monitor validation accuracy
    early_stop_crit='max',         # Stop when ACC stops increasing
    patience=5,                    # Wait 5 epochs before stopping

    checkpoint_file=MODELS_PATH / 'model.pt',
    device=device
)
```
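
For intuition about `patience`, here is a generic sketch of a patience-based stopping rule. It is an illustration only: the function `early_stop_epoch` is hypothetical, and details such as whether a tie counts as an improvement are assumptions that may differ from what `train_network` actually does.

```python
def early_stop_epoch(history, crit="min", patience=4):
    """Return the 1-based epoch at which training would stop, or None if it never stops."""
    best = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(history, start=1):
        # Assumption: only a strict improvement resets the counter.
        improved = best is None or (value > best if crit == "max" else value < best)
        if improved:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


val_acc = [0.70, 0.75, 0.80, 0.79, 0.80, 0.78, 0.79, 0.80, 0.77]
print(early_stop_epoch(val_acc, crit="max", patience=5))  # → 8
```

In this made-up history the best accuracy is reached at epoch 3, and five epochs pass without beating it, so the rule stops at epoch 8 even though more epochs were allowed.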

## Parameter Quick Reference

### Essential Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `model` | nn.Module | required | Neural network to train |
| `loss_func` | callable | required | Loss function (e.g., nn.CrossEntropyLoss()) |
| `train_loader` | DataLoader | required | Training data loader |
| `device` | str | "cpu" | Device to train on ("cuda" or "cpu") |

### Data Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `val_loader` | DataLoader | None | Validation data loader |
| `test_loader` | DataLoader | None | Test data loader |
| `score_funcs` | dict | None | Metrics to track {'name': function} |

### Training Control
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `epochs` | int | 50 | Number of epochs to train (relative) |
| **`total_epochs`** | int | None | **NEW: Total epochs to reach (absolute)** |
| `optimizer` | Optimizer | None | Custom optimizer (default: AdamW) |
| `lr_schedule` | dict | None | Learning rate scheduler config |
| `grad_clip` | float | None | Gradient clipping value |

### Checkpointing
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `checkpoint_file` | str | None | Path to save best model |
| **`save_last`** | bool | False | **NEW: Save after every epoch** |
| **`resume_last`** | bool | False | **NEW: Resume from last checkpoint** |
| `resume_checkpoint` | bool | False | Resume from best checkpoint |
| `resume_file` | str | None | Resume from specific file |

### Early Stopping
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `early_stop_metric` | str | None | Metric to monitor (e.g., "val loss") |
| `early_stop_crit` | str | "min" | "min" or "max" |
| `patience` | int | 4 | Epochs without improvement before stopping |

## Understanding epochs vs total_epochs

### The Key Difference:

| Parameter | Behavior | Starting Fresh | Resuming from Epoch 30 |
|-----------|----------|----------------|------------------------|
| `epochs=20` | Train 20 MORE epochs | Trains epochs 1-20 | Trains epochs 31-50 |
| `total_epochs=20` | Train TO epoch 20 | Trains epochs 1-20 | Already past 20, stops immediately |

### Visual Example:
```
With epochs=10:
  Fresh start:    [1 =====> 10]
  Resume from 5:  [6 =====> 15]
  Resume from 20: [21 ====> 30]

With total_epochs=10:
  Fresh start:    [1 =====> 10]
  Resume from 5:  [6 =====> 10]
  Resume from 20: [Already complete!]
```
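
The same arithmetic as a short sketch. The function `epochs_to_run` is hypothetical and simply encodes the table above; it does not call `train_network`.

```python
def epochs_to_run(completed, epochs=None, total_epochs=None):
    """Return the range of epoch numbers a call would train.

    completed    -- epochs already finished (0 for a fresh start)
    epochs       -- relative count: train this many MORE epochs
    total_epochs -- absolute target: train TO this epoch number
    """
    if total_epochs is not None:
        last = total_epochs
    else:
        last = completed + epochs
    # An empty range means there is nothing left to train.
    return range(completed + 1, last + 1)


for completed in (0, 5, 20):
    relative = epochs_to_run(completed, epochs=10)
    absolute = epochs_to_run(completed, total_epochs=10)
    print(
        f"resume from {completed:>2}: "
        f"epochs=10 trains {relative[0]}..{relative[-1]}, "
        f"total_epochs=10 has {len(absolute)} epochs left"
    )
```

With `epochs`, re-running an interrupted cell keeps adding epochs; with `total_epochs`, the remaining count shrinks to zero, which is why it is the safer choice for re-runs.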

## Messages You'll See

### Starting Fresh:
```
Epoch:   0%|          | 0/50 [00:00<?, ?it/s, epoch=1, train_loss=2.302]
```

### Resuming from Checkpoint:
```
Resuming from last checkpoint: model_last.pt
Resuming from epoch 30
Epoch:  60%|██████    | 30/50 [00:00<?, ?it/s, epoch=31, train_loss=0.045]
```

### Training Already Complete:
```
Resuming from last checkpoint: model_last.pt
Resuming from epoch 50
Training already complete for 50 epochs.
Current epoch: 50. Increase total_epochs to continue training.
```

### After Training Completes:
```
Epoch: 100%|██████████| 50/50 [05:23<00:00, 6.47s/it, epoch=50, train_loss=0.015]

Best model saved at epoch 42 (val loss: 0.0234)
```

## Files Created

When using `checkpoint_file='model.pt'`:

### Without `save_last`:
- **`model.pt`** - Best model only (updated when validation improves)

### With `save_last=True`:
- **`model.pt`** - Best model (lowest validation loss or best metric)
- **`model_last.pt`** - Most recent checkpoint (updated every epoch)

### Example After Training:
```
models/
├── model.pt         # Best model from epoch 42 (245 KB)
└── model_last.pt    # Latest model from epoch 50 (245 KB)
```
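
To check which checkpoint files exist and when they were last written, a small standard-library sketch like the following may help. The helper `describe_checkpoints` is hypothetical, and the folder name `models` is taken from the example listing above.

```python
from datetime import datetime
from pathlib import Path


def describe_checkpoints(folder, pattern="*.pt"):
    """Return (name, size in KB, last-modified time) for each matching file."""
    rows = []
    for path in sorted(Path(folder).glob(pattern)):
        info = path.stat()
        rows.append((path.name, info.st_size / 1024, datetime.fromtimestamp(info.st_mtime)))
    return rows


# Prints nothing if the folder has no .pt files yet.
for name, size_kb, modified in describe_checkpoints("models"):
    print(f"{name:<20} {size_kb:8.1f} KB  {modified:%Y-%m-%d %H:%M:%S}")
```

A `model_last.pt` with a recent timestamp is a quick sign that per-epoch saving is working.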

## Complete Example with Real Code

```python
# This is a complete working example
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from introdl.idlmam import train_network
from introdl.utils import get_device

# Setup
device = get_device()
MODELS_PATH = Path('models')
MODELS_PATH.mkdir(exist_ok=True)

# Create dummy data for demonstration
X_train = torch.randn(100, 10)
y_train = torch.randint(0, 2, (100,))
X_val = torch.randn(20, 10)
y_val = torch.randint(0, 2, (20,))

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=16)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=16)

# Simple model - no need for .to(device)
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 2)
)  # train_network will handle device placement

print("Ready to train with robustness features!")
```

```python
# Train with all robustness features
results = train_network(
    model=model,
    loss_func=nn.CrossEntropyLoss(),
    train_loader=train_loader,
    val_loader=val_loader,

    # NEW: Robustness parameters
    total_epochs=10,      # Train to epoch 10
    save_last=True,       # Save every epoch
    resume_last=True,     # Auto-resume if interrupted

    # Standard parameters
    checkpoint_file=MODELS_PATH / 'demo_model.pt',
    device=device,
    disable_tqdm=False    # Show progress bar
)

print(f"\nTraining complete! Trained {len(results)} epochs.")
print(f"Final validation loss: {results['val loss'].iloc[-1]:.4f}")
```

## Tips for Homework

### 1. Standard Template for Homework
```python
# Use this template for all homework training
results = train_network(
    model, loss_func, train_loader, val_loader,
    
    # Always include these for robustness
    total_epochs=50,
    save_last=True,
    resume_last=True,
    checkpoint_file=MODELS_PATH / 'hw_model.pt',
    
    # Your specific parameters
    score_funcs={'ACC': accuracy_score},
    device=device
)
```

### 2. After a Disconnect
1. Don't panic!
2. Kernel → Restart
3. Run all cells from top to bottom
4. Training automatically continues
5. You'll see "Resuming from epoch X"

### 3. Checking Your Progress
```python
# See how many epochs you've trained
print(f"Trained {len(results)} epochs")
print(f"Best accuracy: {results['val ACC'].max():.4f}")
```

### 4. Iterating on Your Model Design (Best Practice)

**🎯 KEY INSIGHT: Always keep `resume_last=True` for robustness!**  
Just use different checkpoint names for different attempts - no need to change parameters!

```python
# ✅ BEST PRACTICE: Use numbered checkpoint names

# Attempt 1: Initial model
class ModelV1(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = ModelV1()  # No need for .to(device)
results = train_network(
    model, loss_func, train_loader, val_loader,
    total_epochs=30,
    save_last=True,
    resume_last=True,  # ← Always True! No problem!
    checkpoint_file=MODELS_PATH / 'attempt1.pt',  # ← Unique name
    score_funcs={'ACC': accuracy_score},  # needed for the 'val ACC' column below
    device=device  # train_network handles device placement
)
print(f"Attempt 1: {results['val ACC'].max():.3f}")

# Attempt 2: Deeper model (automatic fresh start!)
class ModelV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.2)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = ModelV2()  # train_network moves to device
results = train_network(
    model, loss_func, train_loader, val_loader,
    total_epochs=30,
    save_last=True,
    resume_last=True,  # ← Still True! attempt2_last.pt doesn't exist
    checkpoint_file=MODELS_PATH / 'attempt2.pt',  # ← Different name = fresh
    score_funcs={'ACC': accuracy_score},  # needed for the 'val ACC' column below
    device=device
)
print(f"Attempt 2: {results['val ACC'].max():.3f}")
```

**Why this works:**
- `resume_last=True` looks for `attempt1_last.pt`, `attempt2_last.pt`, etc.
- Each new name means those files don't exist → automatic fresh start
- Still protected against disconnects during each attempt!
- `train_network` automatically moves model to specified device

**Alternative: If you want to reuse the same checkpoint name:**
```python
# Delete old checkpoints in Homework_XX_Models directory
import os
if os.path.exists(MODELS_PATH / 'model.pt'):
    os.remove(MODELS_PATH / 'model.pt')
if os.path.exists(MODELS_PATH / 'model_last.pt'):
    os.remove(MODELS_PATH / 'model_last.pt')

# Now train with reused name
model = NewModel()  # No .to(device) needed
results = train_network(
    model, loss_func, train_loader, val_loader,
    total_epochs=50,
    save_last=True,
    resume_last=True,
    checkpoint_file=MODELS_PATH / 'model.pt',
    device=device
)
```

### 5. Continuing Training (Same Model)
```python
# Just increase total_epochs and run again
results_more = train_network(
    model, loss_func, train_loader, val_loader,
    total_epochs=75,  # Now train to 75 instead of 50
    save_last=True,
    resume_last=True,  # ← TRUE to continue
    checkpoint_file=MODELS_PATH / 'hw_model.pt',
    device=device
)
```

### 6. Optional: Adding Early Stopping
```python
# If you want to prevent overfitting
results = train_network(
    model, loss_func, train_loader, val_loader,
    total_epochs=100,
    save_last=True,
    resume_last=True,
    checkpoint_file=MODELS_PATH / 'hw_model.pt',
    
    # Add early stopping
    early_stop_metric='val ACC',  # or 'val loss'
    early_stop_crit='max',         # 'max' for ACC, 'min' for loss
    patience=5,                    # Stop after 5 epochs without improvement
    
    score_funcs={'ACC': accuracy_score},
    device=device
)
```

## Troubleshooting

### Issue: "Training already complete for X epochs"
**Solution:** Increase `total_epochs` to continue training

### Issue: Not resuming from checkpoint
**Check:**
- Is `resume_last=True` set?
- Does the checkpoint file exist?
- Is the path correct?

### Issue: What if resume_last=True but no checkpoint exists?
**Answer:** Training starts from scratch (no error)
- If `model_last.pt` doesn't exist, training begins fresh
- This is intentional - allows same code for first run and resumption
- You'll see normal output (no "Resuming from..." message)

### Issue: Wrong checkpoint being loaded
**Priority order:**
1. `resume_last=True` loads `model_last.pt`
2. `resume_checkpoint=True` loads `model.pt`
3. `resume_file='path.pt'` loads specific file

### Issue: Want to start fresh
**Solution:** Either:
- Delete the checkpoint files, or
- Set `resume_last=False, resume_checkpoint=False`

### Issue: Conflicting resume parameters
**What happens:**
- If both `resume_last=True` and `resume_checkpoint=True`: Shows warning, uses `resume_last`
- If `resume_last=True` but file missing: Silently starts fresh
- If `resume_checkpoint=True` but file missing: Silently starts fresh
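
The priority and fallback rules above can be summarized in a short sketch. The function `choose_resume_file` is hypothetical and only restates the documented behavior. Two details are assumptions, because the text above does not specify them: a missing `resume_file` is treated like the other missing-file cases, and there is no fallback from a missing `_last` file to the best checkpoint.

```python
import warnings
from pathlib import Path


def choose_resume_file(checkpoint_file, resume_last=False,
                       resume_checkpoint=False, resume_file=None):
    """Return the Path that would be loaded, or None to start fresh."""
    best = Path(checkpoint_file)
    last = best.with_name(f"{best.stem}_last{best.suffix}")

    if resume_last and resume_checkpoint:
        warnings.warn("Both resume_last and resume_checkpoint are set; using resume_last.")

    if resume_last:                  # priority 1: most recent epoch
        candidate = last
    elif resume_checkpoint:          # priority 2: best model so far
        candidate = best
    elif resume_file is not None:    # priority 3: explicit file
        candidate = Path(resume_file)
    else:
        return None                  # no resume requested

    # A missing file means a silent fresh start (assumed for resume_file too).
    return candidate if candidate.exists() else None


print(choose_resume_file("models/model.pt", resume_last=True))
```

Printing the result before a long run is a quick way to confirm which file, if any, a given set of flags would pick up.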

## Summary

### The Simple Recipe for Robust Training:

```python
# Just add these 3 lines to make any training robust:
total_epochs=N,      # Train to epoch N
save_last=True,      # Save progress
resume_last=True     # Auto-resume
```

### Benefits:
- ✅ **No lost work** - Automatically continues from interruptions
- ✅ **Simple** - Just 3 parameters to add
- ✅ **Clear** - Progress bar shows exactly where you are
- ✅ **Reliable** - Same results whether interrupted or not
- ✅ **Backward compatible** - Old code still works

### Remember:
- Use `total_epochs` (absolute) instead of `epochs` (relative) for homework
- After a disconnect, just re-run all cells
- The best model is always saved in `checkpoint_file`
- Check the final message to see which epoch was best

Happy training! 🚀