# Model Training
## Multi-Task Endangered Species Classifier

This notebook trains the multi-task CNN model for conservation status and geographic region prediction.

In [None]:
# Import required libraries
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from src.multi_task_model import MultiTaskModel, save_multi_task_model
from src.data_loader import create_dataloaders, load_species_data
from config.model_config import MODEL_CONFIG, TRAINING_CONFIG, MODEL_PATHS

## Setup Device and Configuration

In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Echo the hyperparameters read from MODEL_CONFIG so the run is self-describing
print("\nModel Configuration:")
for label, key in (
    ("Input size", 'input_size'),
    ("Batch size", 'batch_size'),
    ("Epochs", 'epochs'),
    ("Learning rate", 'learning_rate'),
    ("Dropout rate", 'dropout_rate'),
):
    print(f"  {label}: {MODEL_CONFIG[key]}")

## Load Data

In [None]:
# Load the dataset and build the train/validation/test dataloaders
print("Loading dataset...")

try:
    dataset = load_species_data()
    train_loader, val_loader, test_loader = create_dataloaders(
        dataset,
        batch_size=MODEL_CONFIG['batch_size'],
    )
except Exception as e:
    print(f"Error loading data: {e}")
    print("Please ensure dataset is available")
    # Re-raise: every later cell needs train_loader/val_loader, so continuing
    # would only surface as a confusing NameError further down the notebook.
    raise
else:
    # Report how many batches each split produced
    print(f"Train batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")

## Create Model

In [None]:
# Build the multi-task network (constructed with pretrained=True) and move it to the device
print("Creating multi-task model...")
model = MultiTaskModel(pretrained=True)
model = model.to(device)

# Show the layer structure
print("\nModel Architecture:")
print(model)

# Tally parameter counts from a single traversal of the parameter list:
# each entry is (element count, whether the parameter receives gradients)
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, trainable in param_sizes if trainable)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Define Loss Functions and Optimizer

In [None]:
# Loss for the conservation-status head (applied to `conservation_pred` and the
# conservation labels inside train_epoch / validate)
conservation_criterion = nn.CrossEntropyLoss()
# Loss for the geographic-region head (applied to `geographic_pred` and float labels).
# NOTE(review): nn.BCELoss expects inputs that are already probabilities in [0, 1];
# confirm MultiTaskModel applies a sigmoid to this head. If it returns raw logits,
# nn.BCEWithLogitsLoss is the criterion to use instead.
geographic_criterion = nn.BCELoss()

# Adam over every model parameter, with the learning rate taken from MODEL_CONFIG
optimizer = optim.Adam(
    model.parameters(),
    lr=MODEL_CONFIG['learning_rate'],
)

# Reduce the learning rate when the monitored value (the validation loss passed
# to scheduler.step) stops decreasing for `reduce_lr_patience` epochs.
# The `verbose` argument is intentionally not passed: it is deprecated in recent
# PyTorch releases and newer versions no longer accept it. Read the current rate
# from optimizer.param_groups[0]['lr'] if it needs to be logged.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=TRAINING_CONFIG['reduce_lr_patience'],
)

print("Loss functions and optimizer configured!")

## Training Loop

In [None]:
# Training function
def train_epoch(model, train_loader, optimizer, device):
    """Run one optimization pass over ``train_loader`` and return the mean batch loss.

    Uses the notebook-level ``conservation_criterion`` and
    ``geographic_criterion`` defined in the loss-setup cell.

    Args:
        model: Network called as ``model(images)`` and returning a
            ``(conservation_pred, geographic_pred)`` pair.
        train_loader: Iterable yielding
            ``(images, conservation_labels, geographic_labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        The combined loss averaged over the batches that were processed.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    pbar = tqdm(train_loader, desc="Training")
    for images, conservation_labels, geographic_labels in pbar:
        images = images.to(device)
        conservation_labels = conservation_labels.to(device)
        # The geographic labels are cast to float before being fed to the criterion
        geographic_labels = geographic_labels.to(device).float()

        # Forward pass (gradients from the previous batch are cleared first)
        optimizer.zero_grad()
        conservation_pred, geographic_pred = model(images)

        # Per-task losses
        loss_conservation = conservation_criterion(conservation_pred, conservation_labels)
        loss_geographic = geographic_criterion(geographic_pred, geographic_labels)

        # Combined loss: the geographic term is weighted at half the conservation term
        loss = loss_conservation + 0.5 * loss_geographic

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running total and the progress bar
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix({'loss': batch_loss})

    if num_batches == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError
        raise ValueError("train_loader yielded no batches; cannot compute an average loss")

    return total_loss / num_batches

# Validation function
def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Uses the notebook-level ``conservation_criterion`` and
    ``geographic_criterion`` defined in the loss-setup cell.

    Args:
        model: Network called as ``model(images)`` and returning a
            ``(conservation_pred, geographic_pred)`` pair.
        val_loader: Iterable yielding
            ``(images, conservation_labels, geographic_labels)`` batches.
        device: Device the batch tensors are moved to.

    Returns:
        A ``(mean_loss, accuracy)`` tuple: the combined loss averaged over
        batches, and the fraction of samples whose arg-max conservation
        prediction equals the label.

    Raises:
        ValueError: If ``val_loader`` yields no batches or no samples.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, conservation_labels, geographic_labels in val_loader:
            images = images.to(device)
            conservation_labels = conservation_labels.to(device)
            geographic_labels = geographic_labels.to(device).float()

            conservation_pred, geographic_pred = model(images)

            # Same loss weighting as the training pass
            loss_conservation = conservation_criterion(conservation_pred, conservation_labels)
            loss_geographic = geographic_criterion(geographic_pred, geographic_labels)
            loss = loss_conservation + 0.5 * loss_geographic

            total_loss += loss.item()
            num_batches += 1

            # Accuracy is measured on the conservation head only
            _, predicted = torch.max(conservation_pred, 1)
            correct += (predicted == conservation_labels).sum().item()
            total += conservation_labels.size(0)

    if num_batches == 0 or total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError
        raise ValueError("val_loader yielded no samples; cannot compute validation metrics")

    accuracy = correct / total
    return total_loss / num_batches, accuracy

print("Training functions defined!")

In [None]:
# Main training loop: train, validate, checkpoint on improvement, stop early on stagnation
num_epochs = MODEL_CONFIG['epochs']
best_val_loss = float('inf')
patience_counter = 0  # consecutive epochs without a validation-loss improvement

# Per-epoch metrics, plotted in the next section
history = {
    'train_loss': [],
    'val_loss': [],
    'val_accuracy': []
}

print(f"Starting training for {num_epochs} epochs...\n")

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, device)

    # Validate
    val_loss, val_accuracy = validate(model, val_loader, device)

    # The plateau scheduler monitors the validation loss
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_accuracy'].append(val_accuracy)

    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

    # Checkpoint whenever the validation loss reaches a new minimum
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        save_multi_task_model(model, MODEL_PATHS['multi_task_model'])
        # The check mark was previously stored as mojibake ("âœ“")
        print("✓ Model saved!")
    else:
        patience_counter += 1

    # Early stopping once the loss has not improved for the configured number of epochs
    if patience_counter >= TRAINING_CONFIG['early_stopping_patience']:
        print(f"\nEarly stopping triggered after {epoch+1} epochs")
        break

    print("-" * 50)

print("\nTraining completed!")

## Plot Training History

In [None]:
# Plot the per-epoch loss and accuracy curves recorded in `history`
import os

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot: training and validation loss on shared axes
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Validation Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True)

# Accuracy plot: validation accuracy only (training accuracy is not recorded)
axes[1].plot(history['val_accuracy'], label='Validation Accuracy', marker='o', color='green')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Validation Accuracy')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()

# Make sure the output directory exists so savefig cannot fail with
# FileNotFoundError after a long training run
history_plot_path = '../visualizations/training_history.png'
os.makedirs(os.path.dirname(history_plot_path), exist_ok=True)
plt.savefig(history_plot_path, dpi=100)
plt.show()

print("Training history saved!")

## Save Final Model

In [None]:
# Save the final (last-epoch) weights next to the best checkpoint.
# The training loop already writes the best-validation-loss checkpoint to
# MODEL_PATHS['multi_task_model']; saving the last-epoch model to that same path
# would overwrite it with weights that may be worse (e.g. after early stopping).
import os

best_model_path = MODEL_PATHS['multi_task_model']
# assumes the configured path is a str or os.PathLike — TODO confirm against config
path_root, path_ext = os.path.splitext(os.fspath(best_model_path))
final_model_path = f"{path_root}_final{path_ext}"
save_multi_task_model(model, final_model_path)

print(f"Final model saved to: {final_model_path}")
print(f"Best checkpoint kept at: {best_model_path}")
print(f"Best validation loss: {best_val_loss:.4f}")
print("\nModel training complete!")