# Hyperspectral Material Classification - Inference (Google Colab)

Optimized for Google Colab Pro+ with an A100 GPU

## 1. Check GPU and Install Packages

In [None]:
# Run the NVIDIA status tool in a shell to show which GPU (if any) this runtime has.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive; the paths configured later in this
# notebook point below /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install third-party packages into the runtime. matplotlib and tqdm are imported
# directly below; spectral and scikit-learn are presumably needed by the cloned
# pipeline modules (TODO confirm against the repository).
!pip install spectral scikit-learn matplotlib tqdm

## 2. Clone Repository

In [None]:
!git clone https://github.com/PlugNawapong/hsi-deeplearning.git
%cd hsi-deeplearning

## 3. Import Modules

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from tqdm import tqdm
from collections import Counter

# Import pipeline functions and constants
from pipeline_preprocess import load_hyperspectral_cube, preprocess_cube
from pipeline_dataset import CLASS_NAMES, NUM_CLASSES  # Import class info from pipeline
from pipeline_model import create_model

# Confirm the pipeline imports resolved and echo the class metadata they expose.
print('\n'.join((
    'Modules imported successfully!',
    f'Number of classes: {NUM_CLASSES}',
    f'Class names: {CLASS_NAMES}',
)))

## 4. Configuration

**IMPORTANT:** All your data should be in the `DeepLearning_Plastics` folder in Google Drive:
- Trained model: `DeepLearning_Plastics/outputs/colab/best_model.pth`
- Inference data: `DeepLearning_Plastics/Inference_dataset1_normalized/` (or dataset2/dataset3)
- Results will be saved to: `DeepLearning_Plastics/outputs/colab/inference_results/`

In [None]:
# Every input and output path hangs off one folder in Google Drive.
BASE_DIR = '/content/drive/MyDrive/DeepLearning_Plastics'

# Insertion order is kept so the summary printed below lists the settings in
# this order.
CONFIG = dict(
    # Checkpoint of the trained model.
    model_path=f'{BASE_DIR}/outputs/colab/best_model.pth',
    # Cube to classify - point this at dataset2 or dataset3 as needed.
    data_folder=f'{BASE_DIR}/Inference_dataset1_normalized',
    # Where prediction arrays, figures and statistics are written.
    output_dir=f'{BASE_DIR}/outputs/colab/inference_results',
    # Inference settings - a large batch chosen for the A100.
    batch_size=1024,
    num_workers=4,
)

print('Configuration:')
print('\n'.join(f'  {setting}: {value}' for setting, value in CONFIG.items()))
print(f'\nClass names (from labels.json): {CLASS_NAMES}')

## 5. Load Trained Model

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Restore the checkpoint straight onto the selected device.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# PyTorch version's weights_only default - only load checkpoints you trust.
print(f'\nLoading model from: {CONFIG["model_path"]}')
checkpoint = torch.load(CONFIG['model_path'], map_location=device)

# Metadata stored next to the weights. Class names fall back to the pipeline
# defaults when the checkpoint does not carry its own.
model_config = checkpoint['config']
num_bands = checkpoint['num_bands']
class_names = checkpoint.get('class_names', CLASS_NAMES)

print(f'Model type: {model_config["model_type"]}')
print(f'Number of classes: {len(class_names)}')
print(f'Classes: {class_names}')
print(f'Number of bands: {num_bands}')
print(f'Validation accuracy: {checkpoint["val_acc"]:.2f}%')

# Rebuild the architecture from the saved hyper-parameters, then load weights.
model_kwargs = {
    'num_bands': num_bands,
    'num_classes': len(class_names),
    'model_type': model_config['model_type'],
    'dropout_rate': model_config.get('dropout_rate', 0.5),
}
model = create_model(**model_kwargs)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()  # inference mode for layers such as dropout

print('\n✓ Model loaded and ready for inference')

## 6. Load Inference Data

In [None]:
# Read the cube together with its wavelength axis and header.
print(f'Loading hyperspectral cube from: {CONFIG["data_folder"]}')
cube, wavelengths, header = load_hyperspectral_cube(CONFIG['data_folder'])
print(f'Cube shape: {cube.shape}')
print(f'Wavelength range: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm')
print(f'Number of bands: {len(wavelengths)}')

# Re-apply the preprocessing recorded in the checkpoint (it should match
# training). An absent config, or one whose values are all falsy, means the
# cube is used as loaded.
preprocess_config = model_config.get('preprocess', {})
needs_preprocessing = any(preprocess_config.values())
if needs_preprocessing:
    print('\nApplying preprocessing...')
    cube, wavelengths = preprocess_cube(cube, wavelengths, preprocess_config)
    print(f'Processed shape: {cube.shape}')

# Fail early if the last cube axis does not match the band count stored in the
# checkpoint.
bands_match = cube.shape[2] == num_bands
if not bands_match:
    raise ValueError(f'Band mismatch! Cube has {cube.shape[2]} bands, model expects {num_bands}')

print('\n✓ Data loaded successfully')

## 7. Run Inference

In [None]:
# Flatten the (height, width, bands) cube to one spectrum per row so the
# pixels can be sent through the model in fixed-size batches.
height, width, bands = cube.shape
pixels = cube.reshape(-1, bands)
num_pixels = len(pixels)
print(f'Total pixels: {len(pixels):,}')

batch_size = CONFIG['batch_size']
num_classes = len(class_names)

# Preallocate the outputs and fill them slice by slice. Extending Python lists
# with per-batch arrays creates one Python object per pixel (and per
# probability row), which is slow and memory-hungry for a full image.
# argmax yields int64 indices; the probabilities are stored as float32.
all_predictions = np.empty(num_pixels, dtype=np.int64)
all_probabilities = np.empty((num_pixels, num_classes), dtype=np.float32)

print(f'\nRunning inference (batch size: {CONFIG["batch_size"]})...')
with torch.no_grad():  # no gradients are needed for inference
    for start in tqdm(range(0, num_pixels, batch_size)):
        stop = min(start + batch_size, num_pixels)
        batch_tensor = torch.FloatTensor(pixels[start:stop]).to(device)

        outputs = model(batch_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predictions = outputs.argmax(dim=1)

        # A model whose output width differs from len(class_names) fails here
        # with a ValueError instead of silently producing a mis-shaped map.
        all_predictions[start:stop] = predictions.cpu().numpy()
        all_probabilities[start:stop] = probabilities.cpu().numpy()

# Fold the flat results back into image layout.
prediction_map = all_predictions.reshape(height, width)
probability_maps = all_probabilities.reshape(height, width, num_classes)

print('\n✓ Inference complete!')
print(f'Prediction map shape: {prediction_map.shape}')
print(f'Probability maps shape: {probability_maps.shape}')

## 8. Display Results

In [None]:
# Colorblind-friendly base palette (RGB, 0-255); entry i colours class id i.
base_colors = [
    [0, 0, 0],        # Background - Black
    [230, 159, 0],    # Orange
    [86, 180, 233],   # Sky Blue
    [0, 158, 115],    # Bluish Green
    [240, 228, 66],   # Yellow
    [0, 114, 178],    # Blue
    [213, 94, 0],     # Vermillion
    [204, 121, 167],  # Reddish Purple
    [255, 0, 0],      # Red
    [0, 255, 0],      # Green
    [255, 0, 255],    # Magenta
]

# One colour per class. If the model has more classes than the base palette,
# take the extras from matplotlib's 'tab20' map. Otherwise they would get no
# colour and render as black, indistinguishable from background.
# Colours repeat if the overflow exceeds the size of that map.
num_classes = len(class_names)
colors = base_colors[:num_classes]
if num_classes > len(base_colors):
    fallback_cmap = plt.get_cmap('tab20')
    colors = colors + [
        [int(round(255 * channel)) for channel in fallback_cmap(extra % fallback_cmap.N)[:3]]
        for extra in range(num_classes - len(base_colors))
    ]

# Paint every pixel with the colour of its predicted class.
rgb_map = np.zeros((height, width, 3), dtype=np.uint8)
for class_id, color in enumerate(colors):
    mask = prediction_map == class_id
    rgb_map[mask] = color

# Grid layout: the classification map takes the first two cells of row 0, and
# each non-background probability map takes one cell after it. Count the map's
# two cells before taking the ceiling. The previous formula dropped the last
# row when the number of probability maps was 3, 7, 11, ...
num_prob_maps = len(class_names) - 1  # class 0 is treated as background and skipped
num_cols = 4
map_cells = 2
num_rows = (num_prob_maps + map_cells + num_cols - 1) // num_cols  # ceiling division

# Plot
fig = plt.figure(figsize=(20, 5 * num_rows))
gs = fig.add_gridspec(num_rows, num_cols)

# Prediction map (larger, spans 2 columns)
ax_main = fig.add_subplot(gs[0, :map_cells])
ax_main.imshow(rgb_map)
ax_main.set_title('Classification Map', fontsize=14, fontweight='bold')
ax_main.axis('off')

# Probability maps for each class (skip background)
for i in range(1, len(class_names)):
    # Position in the grid, counted after the cells used by the main map.
    row, col = divmod((i - 1) + map_cells, num_cols)

    ax = fig.add_subplot(gs[row, col])
    im = ax.imshow(probability_maps[:, :, i], cmap='hot', vmin=0, vmax=1)
    ax.set_title(f'{class_names[i]} Probability', fontsize=12)
    ax.axis('off')
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

## 9. Statistics

In [None]:
# Count pixels per predicted class with a vectorised pass rather than
# iterating every pixel in Python. The result stays a Counter keyed by plain
# ints, so `class_counts.get(class_id, 0)` keeps working.
class_ids, pixel_counts = np.unique(prediction_map, return_counts=True)
class_counts = Counter(dict(zip(class_ids.tolist(), pixel_counts.tolist())))
total_pixels = prediction_map.size

print('='*60)
print('CLASSIFICATION STATISTICS')
print('='*60)
for class_id, class_name in enumerate(class_names):
    count = class_counts.get(class_id, 0)
    # Guard against an empty prediction map so the report still prints.
    percentage = 100 * count / total_pixels if total_pixels else 0.0
    print(f'{class_name:12s}: {count:8,} pixels ({percentage:5.2f}%)')
print('='*60)
# Use double quotes inside the f-string: reusing the enclosing quote character
# is a SyntaxError on Python versions before 3.12.
print(f'{"Total":12s}: {total_pixels:8,} pixels (100.00%)')
print('='*60)

## 10. Save Results

In [None]:
from matplotlib.patches import Patch

# Make sure the destination folder exists before anything is written.
output_dir = Path(CONFIG['output_dir'])
output_dir.mkdir(parents=True, exist_ok=True)

# Raw arrays: per-pixel class ids and per-class probabilities.
np.save(output_dir / 'prediction_map.npy', prediction_map)
print(f'Saved: {output_dir}/prediction_map.npy')

np.save(output_dir / 'probability_maps.npy', probability_maps)
print(f'Saved: {output_dir}/probability_maps.npy')

# Standalone classification figure with a class legend.
map_fig, map_ax = plt.subplots(figsize=(12, 10))
map_ax.imshow(rgb_map)
map_ax.set_title('Classification Map', fontsize=16, fontweight='bold')
map_ax.axis('off')

# One legend patch per (class name, colour) pair; colours are rescaled from
# 0-255 to the 0-1 range matplotlib expects.
legend_elements = []
for name, color in zip(class_names, colors):
    legend_elements.append(Patch(facecolor=np.array(color)/255, label=name))
map_ax.legend(handles=legend_elements, loc='upper right', fontsize=12)

map_fig.tight_layout()
map_fig.savefig(output_dir / 'classification_map.png', dpi=150, bbox_inches='tight')
print(f'Saved: {output_dir}/classification_map.png')

# JSON summary; values are cast to built-in types so they serialise cleanly.
counts_by_name = {}
for class_index, class_name in enumerate(class_names):
    counts_by_name[class_name] = int(class_counts.get(class_index, 0))

stats = {
    'class_names': class_names,
    'class_counts': counts_by_name,
    'total_pixels': int(total_pixels),
    'image_shape': list(prediction_map.shape),
    'model_type': model_config['model_type'],
    'model_accuracy': float(checkpoint['val_acc'])
}

with open(output_dir / 'statistics.json', 'w') as f:
    f.write(json.dumps(stats, indent=2))
print(f'Saved: {output_dir}/statistics.json')

print(f'\n✓ All results saved to: {output_dir}')

## 11. Download Results (Optional)

Results are already saved in Google Drive, but you can also download them directly:

In [None]:
# `files` is only needed by the optional browser downloads below.
from google.colab import files

# Uncomment to download
# (these lines are left disabled on purpose so that running the cell does not
# start any downloads unless you opt in)
# files.download(str(output_dir / 'prediction_map.npy'))
# files.download(str(output_dir / 'probability_maps.npy'))
# files.download(str(output_dir / 'classification_map.png'))
# files.download(str(output_dir / 'statistics.json'))

# Remind the user where the results were written.
print(f'All results saved to Google Drive: {CONFIG["output_dir"]}')