# Hyperspectral Material Classification - Inference (Google Colab)

Optimized for Google Colab Pro+ with an A100 GPU.

## 1. Setup Environment

In [None]:
# Shell command: report the NVIDIA GPU(s) visible to this runtime.
# If it fails, the runtime most likely has no GPU attached.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive; the paths configured later in this
# notebook all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install third-party packages into the Colab runtime.
# NOTE(review): `spectral` and `scikit-learn` are not imported directly in this
# notebook; presumably the cloned pipeline modules need them — confirm.
!pip install spectral scikit-learn matplotlib tqdm

## 2. Clone Repository

In [None]:
# Fetch the project code and make the repository the working directory so the
# `pipeline_preprocess` / `pipeline_model` imports in the next section resolve.
!git clone https://github.com/PlugNawapong/hsi-deeplearning.git
%cd hsi-deeplearning

## 3. Import Modules

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from tqdm import tqdm

from pipeline_preprocess import HyperspectralPreprocessor
from pipeline_model import create_model

## 4. Configuration

**IMPORTANT:** All your data should be in the `DeepLearning_Plastics` folder in Google Drive:
- Trained model: `DeepLearning_Plastics/outputs/colab/best_model.pth`
- Inference data: `DeepLearning_Plastics/Inference_dataset1_normalized/`
- Output directory: `DeepLearning_Plastics/outputs/inference/` (created automatically if it does not exist)

In [None]:
# Every path below is rooted at the DeepLearning_Plastics folder in Google Drive.
BASE_DIR = '/content/drive/MyDrive/DeepLearning_Plastics'

# Single source of settings read by the cells that follow.
CONFIG = dict(
    # Where the trained checkpoint lives, where the inference cube lives,
    # and where results are written.
    checkpoint_path=BASE_DIR + '/outputs/colab/best_model.pth',
    data_path=BASE_DIR + '/Inference_dataset1_normalized',
    output_dir=BASE_DIR + '/outputs/inference',

    # Number of pixels sent through the model per forward pass.
    # Inference keeps no gradients, so a large batch is used.
    batch_size=1024,
)

# Echo the effective settings, one "key: value" line each.
print(
    'Configuration:',
    *(f'  {key}: {value}' for key, value in CONFIG.items()),
    sep='\n',
)

## 5. Load Model

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Read the checkpoint, remapping its tensors onto the selected device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
checkpoint_path = CONFIG['checkpoint_path']
print(f'Loading model from: {checkpoint_path}')
checkpoint = torch.load(checkpoint_path, map_location=device)
train_config = checkpoint['config']
state_dict = checkpoint['model_state_dict']

# Rebuild the architecture from the training-time settings.
# The second dimension of `conv1.weight` is passed as the model's second
# argument — presumably the number of input spectral bands; confirm against
# `create_model` in pipeline_model.
model_type = train_config['model_type']
first_layer_in_dim = state_dict['conv1.weight'].shape[1]
num_classes = train_config['num_classes']
model = create_model(model_type, first_layer_in_dim, num_classes)

# Restore the trained weights, move to the device, and switch to eval mode.
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

print(f'Loaded model: {model_type}')
val_acc = checkpoint['val_acc']
print(f'Training validation accuracy: {val_acc:.2f}%')

## 6. Load Data

In [None]:
# Load the hyperspectral cube through the project's preprocessor.
data_path = CONFIG['data_path']
print(f'Loading data from: {data_path}')
preprocessor = HyperspectralPreprocessor(data_path)
data_cube = preprocessor.load_data()

# Report the raw shape first, then unpack it; the unpacking requires exactly
# three dimensions, which later cells use as (height, width, bands).
cube_shape = data_cube.shape
print(f'Data shape: {cube_shape}')
height, width, bands = cube_shape
print(
    f'Image size: {height} x {width} pixels',
    f'Spectral bands: {bands}',
    sep='\n',
)

## 7. Run Inference

In [None]:
# Flatten the cube to one spectrum per row and pre-allocate per-pixel outputs.
pixels = data_cube.reshape(-1, bands)
total_pixels = height * width
predictions = np.zeros(total_pixels, dtype=np.int32)
confidences = np.zeros(total_pixels, dtype=np.float32)

# Ceiling division: the last batch may be shorter than batch_size.
batch_size = CONFIG['batch_size']
num_batches = -(-len(pixels) // batch_size)

print(f'Running inference on {len(pixels):,} pixels in {num_batches} batches...')

# Gradients are not needed for prediction, so disable autograd for the loop.
with torch.no_grad():
    batch_starts = range(0, len(pixels), batch_size)
    for start_idx in tqdm(batch_starts, desc='Inference'):
        end_idx = min(start_idx + batch_size, len(pixels))
        rows = slice(start_idx, end_idx)

        batch = torch.FloatTensor(pixels[rows]).to(device)
        # Softmax over dim 1 turns the model outputs into per-class probabilities.
        probs = torch.softmax(model(batch), dim=1)

        # The highest probability is the confidence; its index is the class.
        max_probs, preds = torch.max(probs, dim=1)
        predictions[rows] = preds.cpu().numpy()
        confidences[rows] = max_probs.cpu().numpy()

# Fold the flat per-pixel results back into image layout.
image_shape = (height, width)
pred_map = predictions.reshape(image_shape)
conf_map = confidences.reshape(image_shape)

print(f'\nInference complete!')
print(f'Mean confidence: {conf_map.mean():.4f}')
print(f'Min confidence: {conf_map.min():.4f}')
print(f'Max confidence: {conf_map.max():.4f}')

## 8. Visualize Results

In [None]:
# Side-by-side figure: predicted class per pixel (left) and the confidence of
# that prediction (right).
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))

# One row per panel: axes, image, imshow options, title, colorbar label.
panels = [
    (ax1, pred_map, {'cmap': 'tab10'}, 'Prediction Map', 'Class'),
    # Confidence is a probability, so pin the color scale to [0, 1].
    (ax2, conf_map, {'cmap': 'viridis', 'vmin': 0, 'vmax': 1},
     'Confidence Map', 'Confidence'),
]

for axis, image, imshow_options, title, colorbar_label in panels:
    mappable = axis.imshow(image, **imshow_options)
    axis.set_title(title, fontsize=14)
    axis.axis('off')
    colorbar = fig.colorbar(mappable, ax=axis)
    colorbar.set_label(colorbar_label, fontsize=12)

fig.tight_layout()
plt.show()

## 9. Class Distribution Analysis

In [None]:
# Distinct predicted classes and how many pixels were assigned to each.
unique, counts = np.unique(pred_map, return_counts=True)

fig, (bar_ax, pie_ax) = plt.subplots(nrows=1, ncols=2, figsize=(14, 5))

# Left panel: absolute pixel count per class.
bar_ax.bar(unique, counts)
bar_ax.set_xlabel('Class', fontsize=12)
bar_ax.set_ylabel('Pixel Count', fontsize=12)
bar_ax.set_title('Predicted Class Distribution', fontsize=14)
bar_ax.grid(True, alpha=0.3)

# Right panel: share of the image per class.
pie_labels = [f'Class {class_id}' for class_id in unique]
pie_ax.pie(counts, labels=pie_labels, autopct='%1.1f%%')
pie_ax.set_title('Class Percentage', fontsize=14)

fig.tight_layout()
plt.show()

# Text summary of the same distribution.
print('\nClass distribution:')
pixel_total = predictions.size
for class_id, pixel_count in zip(unique, counts):
    share = 100*pixel_count/pixel_total
    print(f'  Class {class_id}: {pixel_count:,} pixels ({share:.2f}%)')

## 10. Save Results

In [None]:
# Make sure the destination exists (parents included); re-running is harmless.
output_dir = Path(CONFIG['output_dir'])
output_dir.mkdir(parents=True, exist_ok=True)

# Persist the two maps as .npy files, then report what was written.
arrays_to_save = {
    'predictions.npy': pred_map,
    'confidences.npy': conf_map,
}
for filename, array in arrays_to_save.items():
    np.save(output_dir / filename, array)
for filename in arrays_to_save:
    print(f'Saved: {output_dir}/{filename}')

# Summary statistics. NumPy scalars are converted to built-in types so the
# json module can serialize them.
stats = {
    'mean_confidence': float(conf_map.mean()),
    'min_confidence': float(conf_map.min()),
    'max_confidence': float(conf_map.max()),
    'image_size': {'height': int(height), 'width': int(width)},
    'total_pixels': int(predictions.size),
    'class_distribution': dict(zip(map(int, unique), map(int, counts))),
}
stats_path = output_dir / 'statistics.json'
stats_path.write_text(json.dumps(stats, indent=2))
print(f'Saved: {output_dir}/statistics.json')

print(f'\nAll results saved to Google Drive: {CONFIG["output_dir"]}')

## 11. Download Results (Optional)

Results are already saved in Google Drive, but you can also download them directly:

In [None]:
# Optional: push result files to the browser as downloads via Colab's
# `files` helper. `output_dir` comes from the "Save Results" cell, so run
# that cell first.
from google.colab import files

# Uncomment to download
# files.download(str(output_dir / 'predictions.npy'))
# files.download(str(output_dir / 'confidences.npy'))
# files.download(str(output_dir / 'statistics.json'))

# With the download lines left commented out, this cell only prints a reminder.
print('Files are available in Google Drive')