# Hyperspectral Material Classification - Training (Google Colab)

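Trains a spectral classifier on hyperspectral data stored in Google Drive and saves the best checkpoint and training curves back to Drive. Optimized for Google Colab Pro+ with an A100 GPU.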

## 1. Check GPU, Mount Drive, and Install Packages

In [None]:
# Confirm a GPU runtime is attached; this notebook targets an A100 (Colab Pro+).
!nvidia-smi

In [None]:
# Mount Google Drive so the data/label/output paths in CONFIG resolve.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!pip install spectral scikit-learn matplotlib tqdm

## 2. Clone Repository

In [None]:
!git clone https://github.com/PlugNawapong/hsi-deeplearning.git
%cd hsi-deeplearning

## 3. Import Modules

In [None]:
import json
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from pipeline_preprocess import HyperspectralPreprocessor
from pipeline_dataset import HyperspectralDataset, create_dataloaders
from pipeline_model import create_model

## 4. Configuration

In [None]:
# MODIFY THESE PATHS to point to your Google Drive data
CONFIG = {
    'data_path': '/content/drive/MyDrive/training_dataset_normalized',
    'labels_path': '/content/drive/MyDrive/Ground_Truth/labels.json',
    'output_dir': '/content/drive/MyDrive/outputs/colab',
    'model_type': 'SpectralCNN1D',
    'num_classes': 7,
    'batch_size': 256,
    'num_epochs': 100,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'num_workers': 4,
    'pin_memory': True,
    'patience': 15,
    'seed': 42
}

print('Configuration:')
for k, v in CONFIG.items():
    print(f'  {k}: {v}')

## 5. Load Data

In [None]:
# Load preprocessed data
preprocessor = HyperspectralPreprocessor(CONFIG['data_path'])
data_cube = preprocessor.load_data()

# Load labels
with open(CONFIG['labels_path'], 'r') as f:
    labels_data = json.load(f)

# Create dataloaders
train_loader, val_loader, test_loader, class_weights = create_dataloaders(
    data_cube=data_cube,
    labels_data=labels_data,
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    pin_memory=CONFIG['pin_memory']
)

print(f'Training samples: {len(train_loader.dataset)}')
print(f'Validation samples: {len(val_loader.dataset)}')
print(f'Test samples: {len(test_loader.dataset)}')

## 6. Create Model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Get number of bands from one training batch.
# NOTE(review): assumes spectra batches are shaped (batch, num_bands) — confirm
# against pipeline_dataset if the input layout changes.
sample_batch = next(iter(train_loader))
num_bands = sample_batch[0].shape[1]

# Create model
model = create_model(CONFIG['model_type'], num_bands, CONFIG['num_classes'])
model = model.to(device)

# Class-weighted loss (presumably to counter class imbalance). torch.as_tensor
# accepts a list, NumPy array or tensor and places it directly on `device`,
# replacing the legacy torch.FloatTensor constructor.
loss_weights = (
    None if class_weights is None
    else torch.as_tensor(class_weights, dtype=torch.float32, device=device)
)
criterion = nn.CrossEntropyLoss(weight=loss_weights)
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
# Halve the LR when validation accuracy has not improved for 5 epochs.
# `verbose` is intentionally omitted: it is deprecated since PyTorch 2.2.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')

## 7. Training Loop

In [None]:
def run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_pct)``.

    Trains (forward, backward, optimizer step) when ``optimizer`` is given;
    otherwise evaluates with the model in eval mode and gradients disabled.
    ``mean_loss`` is the mean of the per-batch losses. An empty loader yields
    ``(0.0, 0.0)`` instead of raising ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    if desc is None:
        desc = 'Training' if training else 'Validation'

    loss_sum, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for spectra, labels in tqdm(loader, desc=desc):
            # non_blocking lets host->GPU copies overlap compute when the
            # loader uses pinned memory; it is harmless otherwise.
            spectra = spectra.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(spectra)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += labels.size(0)

    num_batches = len(loader)
    mean_loss = loss_sum / num_batches if num_batches else 0.0
    accuracy = 100.0 * correct / total if total else 0.0
    return mean_loss, accuracy


history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_val_acc = 0.0
patience_counter = 0
output_dir = Path(CONFIG['output_dir'])
output_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(CONFIG['num_epochs']):
    print(f'\nEpoch {epoch+1}/{CONFIG["num_epochs"]}')

    train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
    val_loss, val_acc = run_epoch(model, val_loader, criterion, device)

    # Update history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Step the plateau scheduler on validation accuracy and report LR drops
    # explicitly, without relying on the scheduler's deprecated `verbose` flag.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_acc)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f'Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}')

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'config': CONFIG
        }, output_dir / 'best_model.pth')
        print(f'Saved best model (val_acc: {val_acc:.2f}%)')
        patience_counter = 0
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= CONFIG['patience']:
        print(f'Early stopping triggered after {epoch + 1} epochs')
        break

print(f'\nBest validation accuracy: {best_val_acc:.2f}%')

## 8. Plot Training History

In [None]:
# Plot per-epoch curves against 1-based epoch numbers so the x-axis matches
# the "Epoch N/..." lines printed during training.
epochs = range(1, len(history['train_loss']) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# (axis, history key suffix, y-axis label, panel title)
panels = (
    (ax1, 'loss', 'Loss', 'Loss'),
    (ax2, 'acc', 'Accuracy (%)', 'Accuracy'),
)
for ax, metric, ylabel, title in panels:
    ax.plot(epochs, history[f'train_{metric}'], label='Train')
    ax.plot(epochs, history[f'val_{metric}'], label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

fig.tight_layout()
fig.savefig(output_dir / 'training_history.png', dpi=150)
plt.show()

## 9. Download Results

In [None]:
from google.colab import files

# Pull the best checkpoint, then the training-curve figure, to the local machine.
for artifact_name in ('best_model.pth', 'training_history.png'):
    files.download(str(output_dir / artifact_name))