# Test Notebook for Images (30 FPS)
This notebook loads the saved models (MobileViT, ResNet50, Random Forest) that were trained on 30 FPS image data and runs inference with each of them on a single, user-selected image file.

In [6]:
# --- User Input ---
# Relative path to the image file that should be run through the models.
# The literal is split across two lines purely for readability; Python joins
# adjacent string literals, so the resulting value is one unbroken path.
TEST_FILE_PATH = (
    "../NNATT dataset/Albit/30fps/"
    "SPECTRAL_Albit113_planar1_30FPS_395mW_30FPS_42_32986us_2025_09_05-08_35_27_785.bmp"
)

In [7]:
# --- Imports ---
import os
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import numpy as np
import joblib
import timm

# Device Configuration
# Pick the first available backend in the order CUDA -> MPS -> CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define Class Names (Must match training order)
# Based on the folders in 'NNATT dataset'
CLASS_NAMES = ['Albit', 'Calcite', 'Dolomit', 'Feldspat', 'Quarz', 'Rhodocrosite', 'Tile']
# Sort so that a class's position in the list is its alphabetical rank.
CLASS_NAMES.sort()
NUM_CLASSES = len(CLASS_NAMES)
print(f"Classes: {CLASS_NAMES}")

Using device: mps
Classes: ['Albit', 'Calcite', 'Dolomit', 'Feldspat', 'Quarz', 'Rhodocrosite', 'Tile']


In [8]:
# --- Model Definitions ---

# Model A: ResNet50
def get_resnet50(num_classes):
    """Build a torchvision ResNet50 with a classification head of ``num_classes`` outputs.

    The network is created with ``weights=None`` (no pretrained weights are
    fetched); the caller is expected to load its own state dict afterwards.

    Args:
        num_classes: Number of output units for the replaced final ``fc`` layer.

    Returns:
        The ResNet50 module with its ``fc`` layer swapped for a fresh
        ``nn.Linear`` of the requested width.
    """
    print("Initializing ResNet50...")
    net = models.resnet50(weights=None)
    # Keep the backbone's feature width, only change the number of outputs.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

# Model B: MobileViT (via timm)
def get_mobile_vit(num_classes):
    """Build a timm ``mobilevit_xs`` model with ``num_classes`` outputs.

    The model is created with ``pretrained=False``; the caller is expected to
    load its own saved weights afterwards.

    Args:
        num_classes: Number of classes passed through to ``timm.create_model``.

    Returns:
        The model object returned by ``timm.create_model``.
    """
    print("Initializing MobileViT...")
    return timm.create_model('mobilevit_xs', pretrained=False, num_classes=num_classes)

In [9]:
# --- Load Models ---

# Paths to saved models
MODEL_DIR = "../models"
MOBILEVIT_PATH = os.path.join(MODEL_DIR, "model_mobilevit_images_30fps.pth")
RESNET_PATH = os.path.join(MODEL_DIR, "model_resnet50_images_30fps.pth")
RF_PATH = os.path.join(MODEL_DIR, "model_rf_images_30fps.joblib")


def _load_weights(model, weights_path, label):
    """Load a saved state dict into ``model`` and report the outcome.

    Args:
        model: Module whose ``load_state_dict`` receives the loaded object.
        weights_path: Path of the checkpoint file to read.
        label: Human-readable model name used in the printed messages.

    Returns:
        True when the weights were loaded, False when the file is missing or
        loading raised. Failures are printed rather than raised so the rest of
        the notebook can still run.
    """
    if not os.path.exists(weights_path):
        print(f"{label} not found at {weights_path}")
        loaded = False
    else:
        try:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot run arbitrary code.
            state_dict = torch.load(weights_path, map_location=device, weights_only=True)
            model.load_state_dict(state_dict)
            print(f"{label} loaded successfully.")
            loaded = True
        except Exception as e:
            print(f"Error loading {label}: {e}")
            loaded = False

    if not loaded:
        # Without this notice a missing checkpoint is easy to overlook, and the
        # untrained network would still emit confident-looking predictions.
        print(f"WARNING: {label} is running with randomly initialized weights; "
              f"its predictions are not meaningful.")
    return loaded


# Load MobileViT
mobilevit = get_mobile_vit(NUM_CLASSES)
_load_weights(mobilevit, MOBILEVIT_PATH, "MobileViT")
mobilevit.to(device)
mobilevit.eval()

# Load ResNet50
resnet = get_resnet50(NUM_CLASSES)
_load_weights(resnet, RESNET_PATH, "ResNet50")
resnet.to(device)
resnet.eval()

# Load Random Forest
# rf_clf stays None unless loading succeeds; downstream code skips the
# Random Forest step in that case.
rf_clf = None
if os.path.exists(RF_PATH):
    try:
        # NOTE(security): joblib.load is pickle-based and can execute arbitrary
        # code from the file. Only load model files from a trusted source.
        rf_clf = joblib.load(RF_PATH)
        print("Random Forest loaded successfully.")
    except Exception as e:
        print(f"Error loading Random Forest: {e}")
        rf_clf = None
else:
    print(f"Random Forest not found at {RF_PATH}")

Initializing MobileViT...
MobileViT loaded successfully.
Initializing ResNet50...
ResNet50 loaded successfully.
Random Forest loaded successfully.


In [10]:
# --- Preprocessing ---

# Same transform as training
transform = transforms.Compose([
    # Force every input to a fixed 224x224 spatial size.
    transforms.Resize(size=(224, 224)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Per-channel normalisation with the given mean / std triplets.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def process_image(filepath):
    """Load an image file and convert it into a single-item batch tensor.

    Args:
        filepath: Path to an image file that PIL can open.

    Returns:
        The result of the module-level ``transform`` applied to the RGB image,
        with a leading batch dimension added ([1, 3, 224, 224] with the
        transform defined in this notebook), or None if opening or transforming
        the image raised. The error is printed, not re-raised.
    """
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # closes the handle deterministically. convert() returns an
        # independent in-memory image, so it stays valid after the file closes.
        with Image.open(filepath) as source:
            image = source.convert('RGB')
        tensor = transform(image)
        # Add batch dimension: [1, 3, 224, 224]
        return tensor.unsqueeze(0)
    except Exception as e:
        print(f"Error processing image: {e}")
        return None

# --- Inference ---

def _predict_top1(model, batch):
    """Run ``model`` on ``batch`` and return the top-1 result for the first sample.

    Args:
        model: Callable returning one row of class scores per input sample
            (softmax is taken over dim=1).
        batch: Input tensor, already moved to the model's device.

    Returns:
        Tuple ``(class_name, confidence)`` where ``class_name`` comes from
        ``CLASS_NAMES`` and ``confidence`` is the softmax probability as a float.
    """
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        # Take argmax along the class axis. A bare argmax works on the
        # flattened tensor and is only correct for a batch of exactly one.
        pred_idx = torch.argmax(probs, dim=1)[0].item()
    return CLASS_NAMES[pred_idx], probs[0, pred_idx].item()


if os.path.exists(TEST_FILE_PATH):
    print(f"Processing file: {TEST_FILE_PATH}")
    img_tensor = process_image(TEST_FILE_PATH)

    # process_image returns None (after printing the error) when it fails.
    if img_tensor is not None:
        img_tensor = img_tensor.to(device)

        # 1. MobileViT, 2. ResNet50 -- identical procedure for both networks.
        for title, net in (("MobileViT", mobilevit), ("ResNet50", resnet)):
            print(f"\n--- {title} Inference ---")
            label, confidence = _predict_top1(net, img_tensor)
            print(f"Prediction: {label}")
            print(f"Confidence: {confidence:.4f}")

        # 3. Random Forest
        # Compare against None explicitly: truth-testing an estimator object
        # can invoke its __len__ rather than checking that it was loaded.
        if rf_clf is not None:
            print("\n--- Random Forest Inference ---")
            # Flatten: [1, 3, 224, 224] -> [1, 3*224*224]
            # NumPy needs a CPU tensor; reshape (unlike view) does not require
            # the tensor to be contiguous.
            flat = img_tensor.cpu().reshape(1, -1).numpy()

            try:
                pred = rf_clf.predict(flat)
                rf_label = pred[0]
                if isinstance(rf_label, str):
                    # Classifier was fitted on string labels: already a class name.
                    rf_name = str(rf_label)
                else:
                    # Classifier was fitted on integer labels: index into CLASS_NAMES.
                    rf_name = CLASS_NAMES[int(rf_label)]
                print(f"Prediction: {rf_name}")
            except Exception as e:
                print(f"RF Inference failed: {e}")

else:
    print(f"File not found: {TEST_FILE_PATH}")

Processing file: ../NNATT dataset/Albit/30fps/SPECTRAL_Albit113_planar1_30FPS_395mW_30FPS_42_32986us_2025_09_05-08_35_27_785.bmp

--- MobileViT Inference ---
Prediction: Albit
Confidence: 0.9976

--- ResNet50 Inference ---
Prediction: Albit
Confidence: 1.0000

--- Random Forest Inference ---
Prediction: Albit
