# Inference notebook

### 1. Device selection (CPU / GPU)

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
from PIL import Image
import zipfile
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import tifffile as tiff
from tifffile import imread
from io import BytesIO

# Run on the first CUDA device when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

### 2. Loading test data

In [None]:
# Read the split file and keep only the rows assigned to the test split.
csv_path = "data/dataset_split.csv"
df = pd.read_csv(csv_path)
test_df = df.loc[df["split"] == "test"]

# Ground-truth class labels of the test rows.
ytest = test_df["EUNIS_cls"].values

# Tabular features: every numeric column except the label column.
# NOTE(review): a numeric identifier column (e.g. 'id', if stored as a number)
# would also be selected here; confirm this matches the features used in training.
tabular_cols = (
    test_df.select_dtypes(include=[np.number])
    .drop(columns="EUNIS_cls", errors="ignore")
    .columns.tolist()
)

Xtest_tensor = torch.tensor(test_df[tabular_cols].values, dtype=torch.float32)
image_test_ids = test_df["id"].tolist()

### 3. Loading trained model

In [None]:
# Deterministic evaluation preprocessing: resize to 224x224, convert to a
# float tensor, then normalise each channel with the standard ImageNet mean/std.
eval_image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class CombinedTabularImageModel(nn.Module):
    """
    Classifier that fuses tabular features (MLP branch) with image features (CNN branch).

    The tabular MLP and the selected image backbone each produce a feature
    vector per sample. The two vectors are concatenated and passed to a small
    fusion head that outputs ``n_classes`` scores per sample.

    Args:
        tabular_input_dim: Number of tabular input features per sample.
        n_classes: Number of output classes of the fusion head.
        tabular_hidden: Widths of the tabular MLP hidden layers. Each layer is
            Linear -> BatchNorm1d -> ReLU -> Dropout. An empty sequence makes
            the tabular branch an identity mapping.
        image_model: Image branch to build: ``'Resnet50'``, ``'Hypercolumn'``
            or ``'CNNfeature'``.
        pretrained: Forwarded to the image-branch constructor.
        dropout: Dropout probability used in the tabular MLP and the fusion
            head (also forwarded to ``ImageCNNFeatureMLP``).

    Raises:
        NotImplementedError: If ``image_model`` is not a supported name.

    NOTE(review): ``ImageResNet50``, ``ImageHypercolumnResNet`` and
    ``ImageCNNFeatureMLP`` are not defined or imported in this notebook. They
    must be made available before instantiation, otherwise a NameError is raised.
    """

    def __init__(self, tabular_input_dim, n_classes=17, tabular_hidden=(128, 64), image_model='Resnet50', pretrained=True, dropout=0.3):
        super().__init__()

        # --- Tabular branch ---
        # Layer order and attribute names must stay as-is: they determine the
        # state_dict keys that saved checkpoints are loaded against.
        layers = []
        input_dim = tabular_input_dim
        for h in tabular_hidden:
            layers.append(nn.Linear(input_dim, h))
            layers.append(nn.BatchNorm1d(h))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            input_dim = h
        self.tabular_mlp = nn.Sequential(*layers)
        # Width of the last hidden layer, or the raw input width when there are no
        # hidden layers (indexing tabular_hidden[-1] would fail on an empty sequence).
        tabular_feat_dim = input_dim

        # --- Image branch ---
        # n_classes is used as the branch's output width so that it emits an
        # image_feat_dim-sized vector (assumes these wrappers size their final
        # output by n_classes — TODO confirm against their definitions).
        if image_model == 'Resnet50':
            self.image_model = ImageResNet50(n_classes=2048, pretrained=pretrained, freeze_backbone=True)
            self.image_feat_dim = 2048
        elif image_model == 'Hypercolumn':
            self.image_model = ImageHypercolumnResNet(n_classes=2048 + 1024 + 512 + 256, pretrained=pretrained)
            self.image_feat_dim = 2048 + 1024 + 512 + 256
        elif image_model == 'CNNfeature':
            self.image_model = ImageCNNFeatureMLP(n_classes=256, pretrained=pretrained, hidden_dim=256, dropout=dropout)
            self.image_feat_dim = 256
        else:
            raise NotImplementedError(f"{image_model} not supported")

        # --- Fusion classifier ---
        self.classifier = nn.Sequential(
            nn.Linear(tabular_feat_dim + self.image_feat_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, n_classes),
        )

    def forward(self, tabular_x, image_x):
        """
        Compute fusion-head scores for a batch.

        Args:
            tabular_x: Tabular features, one row per sample.
            image_x: Image batch passed unchanged to the image branch.

        Returns:
            The fusion head's output; no softmax is applied.
        """
        tab_feat = self.tabular_mlp(tabular_x)

        img_feat = self.image_model(image_x)
        # If the backbone returns a spatial map (B, C, H, W), global-average-pool
        # it down to (B, C) so it can be concatenated with the tabular features.
        if img_feat.ndim == 4:
            img_feat = torch.flatten(F.adaptive_avg_pool2d(img_feat, 1), 1)

        combined = torch.cat([tab_feat, img_feat], dim=1)
        return self.classifier(combined)

# Instantiate model
# The tabular branch must be built with the feature count of the test tensor.
input_dim = Xtest_tensor.shape[1]

# Checkpoint to evaluate — replace with your trained model filename.
# selected_image_model must match the image branch the checkpoint was trained
# with; otherwise load_state_dict (strict by default) raises on mismatched keys.
model_name = "Combined_Hypercolumn_1"
selected_image_model = "Hypercolumn"

model = CombinedTabularImageModel(
    tabular_input_dim=input_dim,
    n_classes=17,
    image_model=selected_image_model,
)
# NOTE(review): pretrained defaults to True, although the checkpoint below
# presumably overwrites every weight. pretrained=False may avoid an unneeded
# download — confirm against the image-branch implementation.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (weights_only=True restricts this on torch >= 1.13).
model.load_state_dict(torch.load(f"models/{model_name}.pt", map_location=device))
model.to(device).eval()

### 4. Selecting 5 test samples
We select the first 5 samples from the test set for reproducible inference.

In [None]:
def load_image_from_zip(zip_path, img_id):
    """
    Load the aerial image ``{img_id}.tif`` from a zip archive as an RGB PIL image.

    The TIFF is decoded with ``tifffile`` (robust for GeoTIFF). The array is
    brought to (height, width, channels) layout, and single-band images are
    replicated to three channels. The first three channels are kept and the
    pixels are cast to ``uint8``.

    Args:
        zip_path: Path to the zip archive containing the ``.tif`` files.
        img_id: Image identifier; the archive member ``f"{img_id}.tif"`` is read.

    Returns:
        PIL.Image.Image: An RGB image.

    Raises:
        KeyError: If ``f"{img_id}.tif"`` is not a member of the archive.
        ValueError: If the decoded array is neither a single-band image nor an
            image with at least three channels.
    """
    with zipfile.ZipFile(zip_path, "r") as archive:
        with archive.open(f"{img_id}.tif") as member:
            img_array = imread(BytesIO(member.read()))

    # Channel-first (C, H, W) layout: move the channel axis last. Without this,
    # the [:, :, :3] slice below would silently crop the width instead of
    # selecting colour channels.
    if (
        img_array.ndim == 3
        and img_array.shape[0] in (1, 3, 4)
        and img_array.shape[2] not in (1, 3, 4)
    ):
        img_array = np.moveaxis(img_array, 0, -1)

    # Single-band image, (H, W) or (H, W, 1): replicate it to three channels.
    if img_array.ndim == 3 and img_array.shape[2] == 1:
        img_array = img_array[:, :, 0]
    if img_array.ndim == 2:
        img_array = np.stack([img_array] * 3, axis=-1)

    if img_array.ndim != 3 or img_array.shape[2] < 3:
        raise ValueError(f"Unexpected image shape for {img_id}: {img_array.shape}")

    # Keep the first three channels (assumed to be R, G, B — TODO confirm the
    # band order of these GeoTIFFs).
    img_array = img_array[:, :, :3]

    # NOTE(review): astype(np.uint8) wraps values > 255 (e.g. 16-bit imagery)
    # and truncates floats. It is kept unchanged so inputs match whatever
    # preprocessing training used (not visible here) — confirm.
    if img_array.dtype != np.uint8:
        img_array = img_array.astype(np.uint8)

    # Pillow infers "RGB" from a uint8 (H, W, 3) array; the explicit mode=
    # argument of Image.fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(np.ascontiguousarray(img_array))

# Number of test rows to run through the model, capped at the size of the test
# split so the tabular features, labels and image ids below all have the same length.
num_samples = min(5, len(image_test_ids))
sample_indices = range(num_samples)
true_labels = ytest[:num_samples]

# Slice the features exactly like the labels and ids so the three stay aligned.
# (Indexing with range(5) raised IndexError on a short split while the others truncated.)
tabular_samples = Xtest_tensor[:num_samples].to(device)

zip_path = "data/images.zip"
image_ids_subset = image_test_ids[:num_samples]

# Decode and preprocess each image, then batch them into one tensor of shape
# (num_samples, 3, 224, 224).
image_samples = []
for img_id in image_ids_subset:
    img = load_image_from_zip(zip_path, img_id)
    img_tensor = eval_image_transform(img)
    image_samples.append(img_tensor)

image_samples = torch.stack(image_samples).to(device)

### 5. Running inference

In [None]:
# Forward pass without autograd tracking. Each sample's prediction is the
# index of its highest-scoring class.
with torch.no_grad():
    outputs = model(tabular_samples, image_samples)
pred_labels = outputs.argmax(dim=1).cpu().numpy()
print(f"Predicted labels for {num_samples} samples: {pred_labels}")

### 6. Visualizing predictions vs true labels

In [None]:
# Show each sample's image with its predicted and true label. Titles are green
# for correct predictions and red for mistakes.
# squeeze=False keeps `axes` a 2-D array even when num_samples == 1, where
# plt.subplots would otherwise return a single, non-iterable Axes.
fig, axes = plt.subplots(1, num_samples, figsize=(15, 5), squeeze=False)
for ax, sample_id, pred, true in zip(axes[0], image_ids_subset, pred_labels, true_labels):
    # Decode each sample's own image with the same loader used for the model
    # inputs. Previously the stale `img_id` left over from the loading loop was
    # used, so every panel showed the last image.
    img = load_image_from_zip(zip_path, sample_id)

    ax.imshow(img)
    color = "green" if pred == true else "red"
    ax.set_title(f"Pred: {pred}\nTrue: {true}", color=color)
    ax.axis('off')
plt.show()