# Hyperspectral Material Classification - Training (Google Colab)

Optimized for Google Colab Pro+ with A100 GPU

## 1. Check GPU, Mount Google Drive, and Install Packages

In [None]:
# Run nvidia-smi in a shell to confirm that a GPU is attached to this runtime.
!nvidia-smi

In [None]:
# Mount Google Drive so that files stored there are reachable from the runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!pip install spectral scikit-learn matplotlib tqdm

## 2. Clone Repository

In [None]:
!git clone https://github.com/PlugNawapong/hsi-deeplearning.git
%cd hsi-deeplearning

## 3. Import Modules

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from tqdm import tqdm

# Import pipeline functions and constants
from pipeline_preprocess import load_hyperspectral_cube, preprocess_cube, bin_spatial_labels
from pipeline_dataset import (
    HyperspectralDataset, create_train_val_split, create_dataloaders, load_label_image,
    CLASS_NAMES, NUM_CLASSES  # Import class info from pipeline
)
from pipeline_model import create_model

# Confirm the imports worked and show the class metadata exposed by
# pipeline_dataset.
print(
    'Modules imported successfully!',
    f'Number of classes: {NUM_CLASSES}',
    f'Class names: {CLASS_NAMES}',
    sep='\n',
)

## 4. Configuration

**IMPORTANT:** All your data should be in the `DeepLearning_Plastics` folder in Google Drive:
- `DeepLearning_Plastics/training_dataset_normalized/`
- `DeepLearning_Plastics/Ground_Truth/labels.png` (label image read via `label_path`)
- `DeepLearning_Plastics/Ground_Truth/labels.json`
- Output will be saved to `DeepLearning_Plastics/outputs/colab/`

In [None]:
# All inputs and outputs live under the DeepLearning_Plastics folder in Drive.
BASE_DIR = '/content/drive/MyDrive/DeepLearning_Plastics'

CONFIG = dict(
    # --- Data paths ---
    data_folder=f'{BASE_DIR}/training_dataset_normalized',
    label_path=f'{BASE_DIR}/Ground_Truth/labels.png',
    output_dir=f'{BASE_DIR}/outputs/colab',

    # --- Preprocessing (leave every entry as None for already-normalized data) ---
    preprocess=dict(
        wavelength_range=None,  # e.g. (450, 1000) when the data is not normalized yet
        spatial_binning=None,
        spectral_binning=None,
    ),

    # --- Model ---
    # Options: SpectralCNN1D, HybridSN, ResNet1D, SpectralAttentionNet, DeepSpectralCNN
    model_type='SpectralCNN1D',
    num_classes=NUM_CLASSES,  # taken from the value exported by pipeline_dataset
    dropout_rate=0.5,

    # --- Training hyperparameters (chosen for an A100) ---
    batch_size=256,
    num_epochs=100,
    learning_rate=0.001,
    weight_decay=1e-4,
    val_ratio=0.2,
    augment=True,

    # --- Data loader settings (chosen for Colab) ---
    num_workers=4,
    pin_memory=True,

    # --- Early stopping: epochs without improvement before giving up ---
    patience=15,

    # --- Random seed ---
    seed=42,
)

# Echo the flat settings; the nested preprocessing dict is left out.
print('Configuration:')
for name in CONFIG:
    if name == 'preprocess':
        continue
    print(f'  {name}: {CONFIG[name]}')
print(f'\nClass names: {CLASS_NAMES}')

## 5. Load and Preprocess Data

In [None]:
import random

# Seed every random number generator that may be involved so that runs are
# repeatable. Python's `random` is seeded as well in case the pipeline's
# split or augmentation code draws from it (not visible from this notebook).
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

# Load the hyperspectral cube together with its wavelengths and header.
print('Loading hyperspectral cube...')
cube, wavelengths, header = load_hyperspectral_cube(CONFIG['data_folder'])
print(f'Cube shape: {cube.shape}')

# Only run preprocessing when at least one preprocessing option is set.
if any(CONFIG['preprocess'].values()):
    print('Applying preprocessing...')
    cube, wavelengths = preprocess_cube(cube, wavelengths, CONFIG['preprocess'])
    print(f'Processed shape: {cube.shape}')

print(f'Wavelength range: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm')
print(f'Number of bands: {len(wavelengths)}')

## 6. Load Labels and Create Dataset

In [None]:
# Read the ground-truth label image named in the configuration.
print('Loading labels...')
labels = load_label_image(CONFIG['label_path'])
print(f'Labels shape: {labels.shape}')

# When spatial binning is configured, bin the labels with the same bin size.
bin_size = CONFIG['preprocess'].get('spatial_binning')
if bin_size:
    labels = bin_spatial_labels(labels, bin_size)
    print(f'Applied {bin_size}x{bin_size} spatial binning to labels')

# Build the dataset from the cube and its labels.
# NOTE(review): `augment` is set on the full dataset before the split below, so
# the validation subset presumably sees augmentation too - confirm in
# pipeline_dataset.
print('Creating dataset...')
dataset = HyperspectralDataset(cube, labels, augment=CONFIG['augment'])
print(f'Total samples: {len(dataset)}')
print(f'Number of classes: {NUM_CLASSES}')
print(f'Classes: {CLASS_NAMES}')

# Per-class weights, used later to weight the loss.
class_weights = dataset.get_class_weights()
formatted_weights = [f"{w:.3f}" for w in class_weights]
print(f'Class weights: {formatted_weights}')

# Train/validation split, seeded from the configuration.
train_dataset, val_dataset = create_train_val_split(
    dataset,
    val_ratio=CONFIG['val_ratio'],
    random_seed=CONFIG['seed'],
)
print(f'Train samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')

# Wrap both splits in data loaders.
train_loader, val_loader = create_dataloaders(
    train_dataset,
    val_dataset,
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
)
print(f'Batch size: {CONFIG["batch_size"]}')
print(f'Training batches: {len(train_loader)}')
print(f'Validation batches: {len(val_loader)}')

## 7. Create Model

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# The model input size is taken from the cube's third axis.
num_bands = cube.shape[2]

# Build the configured architecture and move it to the selected device.
model = create_model(
    num_bands=num_bands,
    num_classes=CONFIG['num_classes'],
    model_type=CONFIG['model_type'],
    dropout_rate=CONFIG['dropout_rate']
)
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Model: {CONFIG["model_type"]}')
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

# Class-weighted cross-entropy loss and the Adam optimizer.
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device))
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])

# Halve the learning rate after 5 epochs without improvement of the monitored
# metric (mode='max': higher is better). The `verbose` argument is not passed:
# it is deprecated since PyTorch 2.2 and newer releases may reject it with a
# TypeError.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

print('✓ Model created and ready for training')

## 8. Training Loop

In [None]:
# Setup: per-epoch metric history, best-score tracking for checkpointing and
# early stopping, and the directory that receives the checkpoint.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_val_acc = 0.0
patience_counter = 0
output_dir = Path(CONFIG['output_dir'])
output_dir.mkdir(parents=True, exist_ok=True)

print(f'Starting training for {CONFIG["num_epochs"]} epochs...\n')

for epoch in range(CONFIG['num_epochs']):
    print(f'Epoch {epoch+1}/{CONFIG["num_epochs"]}')

    # ---- Training pass ----
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    # Batch labels are called `targets` so that the notebook-level `labels`
    # variable (the label image) is not overwritten by the loop.
    for spectra, targets in tqdm(train_loader, desc='Training'):
        spectra, targets = spectra.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(spectra)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # Loss is averaged over batches; accuracy is a percentage over samples.
    train_loss = train_loss / len(train_loader)
    train_acc = 100. * correct / total

    # ---- Validation pass (no gradient tracking) ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for spectra, targets in tqdm(val_loader, desc='Validation'):
            spectra, targets = spectra.to(device), targets.to(device)
            outputs = model(spectra)
            loss = criterion(outputs, targets)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    val_loss = val_loss / len(val_loader)
    val_acc = 100. * correct / total

    # Update history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Step the scheduler on validation accuracy and report any change of the
    # learning rate explicitly instead of relying on scheduler verbosity.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_acc)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f'Learning rate changed: {lr_before:.2e} -> {lr_after:.2e}')

    # Checkpoint whenever validation accuracy improves; otherwise count the
    # epoch towards early stopping.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'config': CONFIG,
            'class_names': CLASS_NAMES,
            'num_bands': num_bands
        }, output_dir / 'best_model.pth')
        print(f'✓ Saved best model (val_acc: {val_acc:.2f}%)')
        patience_counter = 0
    else:
        patience_counter += 1

    # Stop once `patience` consecutive epochs brought no improvement.
    if patience_counter >= CONFIG['patience']:
        print(f'\nEarly stopping triggered after {epoch+1} epochs')
        break

    print()

print('Training complete!')
print(f'Best validation accuracy: {best_val_acc:.2f}%')

## 9. Plot Training History

In [None]:
# Plot loss (left) and accuracy (right) per epoch for both splits.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# One entry per panel: axis, train key, validation key, y-label, title.
panels = (
    (ax1, 'train_loss', 'val_loss', 'Loss', 'Training and Validation Loss'),
    (ax2, 'train_acc', 'val_acc', 'Accuracy (%)', 'Training and Validation Accuracy'),
)
for axis, train_key, val_key, y_label, title in panels:
    axis.plot(history[train_key], label='Train', linewidth=2)
    axis.plot(history[val_key], label='Validation', linewidth=2)
    axis.set_xlabel('Epoch', fontsize=12)
    axis.set_ylabel(y_label, fontsize=12)
    axis.set_title(title, fontsize=14)
    axis.legend(fontsize=11)
    axis.grid(True, alpha=0.3)

# Save the figure to the output directory, then display it inline.
plt.tight_layout()
plt.savefig(output_dir / 'training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'Saved to: {output_dir}/training_history.png')

## 10. Results Summary

In [None]:
# Assemble the end-of-run report line by line and print it in one call.
separator = '=' * 60
summary_lines = [
    separator,
    'TRAINING SUMMARY',
    separator,
    f'Model: {CONFIG["model_type"]}',
    f'Number of classes: {NUM_CLASSES}',
    f'Class names: {CLASS_NAMES}',
    f'Total epochs trained: {len(history["train_loss"])}',
    f'Best validation accuracy: {best_val_acc:.2f}%',
    # "Final" values are those of the last epoch recorded in the history.
    f'Final training accuracy: {history["train_acc"][-1]:.2f}%',
    f'Final validation accuracy: {history["val_acc"][-1]:.2f}%',
    f'\nModel saved to: {output_dir}/best_model.pth',
    f'Training plot saved to: {output_dir}/training_history.png',
    separator,
]
print('\n'.join(summary_lines))

## 11. Download Results (Optional)

Results are already saved in Google Drive, but you can also download them directly:

In [None]:
from google.colab import files

# Set to True to also download the artifacts through the browser. An explicit
# flag replaces the previous commented-out calls, so enabling the download
# does not require editing code lines.
DOWNLOAD_RESULTS = False

if DOWNLOAD_RESULTS:
    for artifact in ('best_model.pth', 'training_history.png'):
        files.download(str(output_dir / artifact))

print(f'All results saved to Google Drive: {CONFIG["output_dir"]}')