In [1]:
# %%  
# Cell 1: Imports & Paths

import os, glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn, torch.nn.functional as F, torch.optim as optim

# Folder holding the trimmed .npz cubes:
PROCESSED_DIR = 'prepocessing/processed_cubes/'

# Species names in label order. Each name must equal the part of a cube's
# file name that precedes '_block', because labels are looked up by that prefix.
species_list = ['abachi','afromasia','ipe','iroko','merbau','ovangol','padauk','sapelimahonki','tiiki']
label_map    = {sp:i for i,sp in enumerate(species_list)}

# Hyperparams
BATCH_SIZE = 8
LR         = 1e-3
EPOCHS     = 50
SEED       = 42
DEVICE     = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the RNGs up front so weight initialisation and batch shuffling are
# repeatable after Restart Kernel -> Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f'Using device: {DEVICE}')

Using device: cpu


In [2]:
# %%  
# Cell 2: Determine a Uniform Crop Size

# Scan all .npz files to find the minimum H and W, so every cube can later be
# center-cropped to one common spatial size.
files = sorted(glob.glob(os.path.join(PROCESSED_DIR, "*.npz")))
if not files:
    # Without this guard the failure would be an opaque "min() arg is an empty
    # sequence" further down.
    raise FileNotFoundError(f"No .npz files found in {PROCESSED_DIR!r}")

heights = []
widths  = []
for f in files:
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager closes it once the array has been read.
    with np.load(f) as data:
        arr = data['block']   # shape (H, W, C)
    h, w, _ = arr.shape
    heights.append(h)
    widths.append(w)

crop_h = min(heights)
crop_w = min(widths)
print(f"→ Center‐crop size: {crop_h}×{crop_w}")


→ Center‐crop size: 263×88


In [3]:
# %%  
# Cell 3: Dataset & DataLoader with PCA Reduction to 15 Bands

import joblib

# Load the PCA model you fitted earlier
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — only load a .joblib file you produced yourself.
pca = joblib.load("pca_15bands.joblib")
# Number of components the loaded PCA outputs; the expected value (15) is not
# checked here.
N_BANDS = pca.n_components_  # should be 15

class ProcessedHSIDatasetPCA(Dataset):
    """Dataset of .npz cubes, center-cropped and band-reduced with a PCA model.

    Each file in ``data_dir`` must store an array under the key ``'block'``
    with three axes (H, W, C). The label is taken from the file name: the
    text before ``'_block'`` is looked up in ``label_map``.

    Args:
        data_dir:  folder that is globbed for ``*.npz`` files.
        label_map: dict mapping file-name prefix -> integer class id.
        crop_h:    height of the center crop.
        crop_w:    width of the center crop.
        pca_model: object with a ``transform(X)`` method mapping
                   (n_pixels, C) -> (n_pixels, n_components). When omitted,
                   the notebook-level ``pca`` is used.
    """

    def __init__(self, data_dir, label_map, crop_h, crop_w, pca_model=None):
        # Sorted so that index -> file is the same on every run and platform.
        self.files     = sorted(glob.glob(os.path.join(data_dir, "*.npz")))
        self.label_map = label_map
        self.crop_h    = crop_h
        self.crop_w    = crop_w
        self.pca       = pca if pca_model is None else pca_model

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(patch, label)`` for sample ``idx``.

        ``patch`` is a float32 tensor of shape (n_components, crop_h, crop_w);
        ``label`` is a 0-dim tensor holding the class id.

        Raises:
            KeyError:   the file-name prefix is not a key of ``label_map``.
            ValueError: the cube is smaller than the requested crop.
        """
        path  = self.files[idx]
        fname = os.path.basename(path)
        sp    = fname.split('_block')[0]
        if sp not in self.label_map:
            raise KeyError(f"Unknown species prefix {sp!r} in file name {fname!r}")
        y     = self.label_map[sp]

        # Close the archive as soon as the array has been read.
        with np.load(path) as data:
            block = data['block']           # (H, W, C_orig)
        H, W, C = block.shape
        if H < self.crop_h or W < self.crop_w:
            # Negative offsets would otherwise slice the wrong region and fail
            # later in reshape with an unrelated-looking message.
            raise ValueError(
                f"{fname}: cube is {H}x{W}, smaller than crop {self.crop_h}x{self.crop_w}"
            )

        # 1) center-crop to (crop_h, crop_w)
        top  = (H - self.crop_h)//2
        left = (W - self.crop_w)//2
        crop = block[top:top+self.crop_h, left:left+self.crop_w, :]

        # 2) flatten pixels, project the band axis, restore the spatial grid
        flat    = crop.reshape(-1, C)
        reduced = self.pca.transform(flat)  # (crop_h*crop_w, n_components)
        # Band count comes from the transform output rather than a global.
        patch   = reduced.reshape(self.crop_h, self.crop_w, reduced.shape[1])

        # 3) to tensor (C, H, W)
        patch = np.transpose(patch, (2,0,1)).astype(np.float32)
        return torch.from_numpy(patch), torch.tensor(y)

# Instantiate PCA dataset
ds = ProcessedHSIDatasetPCA(PROCESSED_DIR, label_map, crop_h, crop_w)
# num_workers=0: the Dataset class is defined inside this notebook, so worker
# processes started with "spawn" (Windows/macOS) cannot import it and the
# loader fails or hangs. Loading in the main process avoids that, and for a
# dataset this small worker start-up would cost more than it saves.
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

print(f"Dataset size: {len(ds)}")
sample, label = ds[0]
print(f"One sample: {sample.shape} (bands×H×W), label={label}")


Dataset size: 26
One sample: torch.Size([15, 263, 88]) (bands×H×W), label=7


In [None]:
# Walk the dataset once, collect each sample's integer label, and compare the
# set of labels that occur with the number of classes in label_map.
all_labels = []
for i in range(len(ds)):
    _, lbl = ds[i]
    all_labels.append(lbl.item())

present = sorted(set(all_labels))
print("Unique labels in data:", present)
print("Total classes defined:", len(label_map))

In [None]:
# %%  
# Cell 4: Simple 3D‐CNN Definition

class Simple3DCNN(nn.Module):
    """Two-layer 3D CNN followed by a linear classifier.

    The forward pass takes a 4-D batch, inserts a singleton channel axis at
    position 1, runs two Conv3d+ReLU stages (max-pool between them, global
    average pool after), and maps the 32 pooled features to ``num_classes``
    logits.

    Args:
        num_classes: size of the output layer.
        in_bands:    accepted but not read by this class.
    """

    def __init__(self, num_classes, in_bands):
        super().__init__()
        # Layers are listed in execution order; nn.Sequential numbers them
        # 0..5, so the convolutions sit at indices 0 and 3.
        layers = [
            nn.Conv3d(1, 16, kernel_size=(5, 3, 3), padding=(2, 1, 1)),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2),
            nn.Conv3d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool3d(output_size=1),
        ]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, x):
        # Add the channel axis the 3D convolutions expect.
        volume = x[:, None]
        # Global pooling leaves one value per feature map: (B, 32, 1, 1, 1).
        pooled = self.features(volume)
        # Collapse everything after the batch axis -> (B, 32).
        flat = pooled.flatten(1)
        return self.classifier(flat)

# Read the band count off the leading axis of the first sample, then build the
# network on DEVICE together with its Adam optimizer.
sample = ds[0][0]
bands = sample.shape[0]

model = Simple3DCNN(num_classes=len(species_list), in_bands=bands)
model = model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=LR)


In [None]:
# %%  
# Cell 5: Training Loop (reports training loss/accuracy only; no separate validation set)

# One pass over `dl` per epoch. Loss is accumulated weighted by batch size so
# the printed value is a per-sample mean even when the last batch is smaller.
for epoch in range(1, EPOCHS+1):
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for x, y in dl:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

        # Statistics use the logits computed before this step's update.
        batch_n = y.size(0)
        loss_sum  += loss.item() * batch_n
        n_correct += (logits.argmax(1) == y).sum().item()
        n_seen    += batch_n

    epoch_loss = loss_sum / n_seen
    epoch_acc  = n_correct / n_seen
    print(f"Epoch {epoch}/{EPOCHS} — Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.3f}")


Epoch 1/5 — Loss: 2.2438, Acc: 0.038


In [None]:
# %%  
# Cell 6: Final Metrics

from sklearn.metrics import classification_report, confusion_matrix, cohen_kappa_score

# Collect predictions over `dl`. This is the same loader the model was trained
# on, so the numbers below are training-set metrics, not held-out performance.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for x,y in dl:
        x = x.to(DEVICE)
        p = model(x).argmax(1).cpu().numpy()
        all_preds.append(p)
        all_labels.append(y.numpy())

all_preds  = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

# Pass the full list of class ids explicitly. Without `labels=`,
# classification_report raises when the classes present in the data do not
# match the length of target_names, and the confusion matrix silently shrinks
# to the classes that happen to occur.
class_ids = list(range(len(species_list)))

print("\nClassification Report:\n", classification_report(
    all_labels, all_preds,
    labels=class_ids,
    target_names=species_list,
    zero_division=0,   # classes with no samples/predictions score 0 without warnings
))
print("Overall Accuracy:", (all_preds==all_labels).mean())
print("Cohen's Kappa:", cohen_kappa_score(all_labels, all_preds))
print("Confusion Matrix:\n", confusion_matrix(all_labels, all_preds, labels=class_ids))
