In [None]:
import torch
import torch.nn as nn
from torchvision.models import efficientnet_v2_s
from torchvision import transforms
from PIL import Image, ImageOps

from typing import Dict, Union
from pathlib import Path

# --------------------------
# Configuration
# --------------------------
# Label name -> position of that label in the model's output vector.
CLASS2IDX = {
    "Fake": 0,
    "Real": 1,
}
# Inverse lookup: output position -> label name.
IDX2CLASS = dict(zip(CLASS2IDX.values(), CLASS2IDX.keys()))

# --------------------------
# Model Definition
# --------------------------
def create_model(num_classes: int = 2) -> nn.Module:
    """
    Build an EfficientNet-V2 Small network with a replaced classifier head.

    The backbone is obtained from `efficientnet_v2_s()` called with no
    arguments; the last layer of its `classifier` is then swapped for a new
    `nn.Linear` that keeps the same input width and emits `num_classes` values.

    Args:
        num_classes: Number of outputs of the new final linear layer.

    Returns:
        The modified model.
    """
    network = efficientnet_v2_s()
    original_head = network.classifier[-1]
    network.classifier[-1] = nn.Linear(original_head.in_features, num_classes)
    return network

# --------------------------
# Preprocessing (Transform)
# --------------------------
class ResizePadToSquare:
    """
    Transform that fits an image into a `size` x `size` canvas without
    distorting it: the longer side is scaled to `size`, and the remaining
    space on the shorter side is filled with `fill`.
    """
    def __init__(self, size: int, fill: int = 0, interpolation = Image.BICUBIC):
        self.size = size
        self.fill = fill
        self.interpolation = interpolation

    def __call__(self, img: Image.Image) -> Image.Image:
        # Work on an RGB image; anything else is converted first.
        rgb = img if img.mode == "RGB" else img.convert("RGB")

        width, height = rgb.size
        factor = self.size / max(width, height)
        # Never let a side collapse below one pixel after rounding.
        target_w = max(1, int(round(width * factor)))
        target_h = max(1, int(round(height * factor)))
        resized = rgb.resize((target_w, target_h), self.interpolation)

        # Split the leftover space between the two sides; when the leftover
        # is odd, the extra pixel goes to the right / bottom edge.
        extra_w = self.size - target_w
        extra_h = self.size - target_h
        left, top = extra_w // 2, extra_h // 2
        border = (left, top, extra_w - left, extra_h - top)

        return ImageOps.expand(resized, border=border, fill=self.fill)

def get_inference_transform(img_size: int) -> transforms.Compose:
    """
    Build the preprocessing pipeline applied to an image before inference.

    The pipeline runs, in order:
      1. `ResizePadToSquare` to an `img_size` square (bicubic, fill value 0),
      2. `transforms.ToTensor()`,
      3. `transforms.Normalize` with fixed per-channel mean and std.

    Args:
        img_size: Side length of the square the image is resized/padded to.
    """
    to_square = ResizePadToSquare(img_size, fill=0, interpolation=Image.BICUBIC)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return transforms.Compose([to_square, to_tensor, normalize])

# --------------------------
# Inference Class (Single Model)
# --------------------------

class ImagePredictor:
    """
    Loads a single model checkpoint and runs inference on image files.

    Instances expose `device`, `img_size`, `transform` (the preprocessing
    pipeline) and `model` (the loaded network in eval mode).
    """
    def __init__(self,
                 ckpt_path: str,
                 device: str = "cuda",
                 img_size: int = 384):
        """
        Initializes the predictor.

        Args:
            ckpt_path: The file path to the .pt or .pth model checkpoint.
            device: The device to run inference on (e.g., "cuda" or "cpu").
            img_size: The square size the image will be resized/padded to.

        Raises:
            FileNotFoundError: If `ckpt_path` does not exist.
        """
        self.device = torch.device(device)
        self.img_size = img_size
        self.transform = get_inference_transform(self.img_size)

        # Load the single model
        print(f"Loading model from {ckpt_path} onto {self.device}...")
        self.model = self._load_model(ckpt_path)
        print("Model loaded successfully.")

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the state dict from either a {"model": state_dict, ...} wrapper or a bare state dict."""
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            return checkpoint["model"]
        return checkpoint

    def _load_model(self, ckpt_path: str) -> nn.Module:
        """
        Private helper to load the model from its checkpoint.

        The checkpoint may be saved either as {"model": state_dict, ...} or
        as the state dict itself.

        Raises:
            FileNotFoundError: If `ckpt_path` does not exist.
            Exception: Any error raised while reading the checkpoint or
                loading its weights is logged and re-raised unchanged.
        """
        num_classes = len(CLASS2IDX)

        if not Path(ckpt_path).exists():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

        model = create_model(num_classes)
        try:
            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from a trusted source; consider weights_only=True
            # if the installed PyTorch supports it and the checkpoint holds
            # nothing but tensors.
            checkpoint = torch.load(ckpt_path, map_location=self.device)
            model.load_state_dict(self._extract_state_dict(checkpoint))
            model.to(self.device)
            model.eval()  # Set model to evaluation mode
            return model
        except Exception as e:
            print(f"Error loading checkpoint {ckpt_path}: {e}")
            raise

    @torch.no_grad()
    def predict(self, image_path: str) -> Dict[str, Union[str, int, float, Dict[str, float]]]:
        """
        Runs inference on a single image from a file path.

        Args:
            image_path: The file path to the image.

        Returns:
            On success, a dictionary with keys "predicted_label",
            "predicted_index" and "probabilities" (a dict mapping "Fake" and
            "Real" to their softmax probabilities). If the image cannot be
            opened, a dictionary with the single key "error" holding the
            exception message.
        """
        try:
            # 1. Open and convert image. The `with` block closes the file
            #    handle once convert() has produced an in-memory RGB copy.
            with Image.open(image_path) as source:
                img = source.convert("RGB")
        except Exception as e:
            print(f"Error opening image {image_path}: {e}")
            return {"error": str(e)}

        # 2. Preprocess the image and add a batch dimension
        tensor = self.transform(img).unsqueeze(0).to(self.device)

        # 3. Run inference. Mixed precision is only enabled on CUDA;
        #    torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        use_amp = self.device.type == "cuda"
        with torch.autocast(device_type="cuda", enabled=use_amp):
            logits = self.model(tensor)
            # Calculate probabilities
            probs = torch.softmax(logits, dim=1)

        # 4. Get probabilities for each class
        prob_fake = probs[0, CLASS2IDX["Fake"]].item()
        prob_real = probs[0, CLASS2IDX["Real"]].item()

        # 5. Determine final prediction
        pred_idx = torch.argmax(probs, dim=1).item()
        pred_label = IDX2CLASS[pred_idx]

        # 6. Format the output
        return {
            "predicted_label": pred_label,
            "predicted_index": pred_idx,
            "probabilities": {
                "Fake": prob_fake,
                "Real": prob_real
            }
        }

In [None]:
import json

# --- Configuration ---
CHECKPOINT_FILE = "bankk_runs_effv2s/your_model_checkpoint.pt"  # <--- CHANGE THIS
TEST_IMAGE = "imgs/my_test_image.jpg"                  # <--- CHANGE THIS
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Create a dummy image for testing if it doesn't exist ---
img_path = Path(TEST_IMAGE)
if not img_path.exists():
    print(f"Creating dummy test image at: {img_path}")
    # parents=True so a nested TEST_IMAGE path still works when the
    # intermediate directories do not exist yet.
    img_path.parent.mkdir(parents=True, exist_ok=True)
    Image.new('RGB', (500, 500), color = 'red').save(img_path)

# --- Initialize the predictor ---
# This loads the model onto the GPU (or CPU).
# Reset first so that a failed load never leaves a predictor from an earlier
# run of this cell available to the next cell.
predictor = None
try:
    predictor = ImagePredictor(
        ckpt_path=CHECKPOINT_FILE,
        device=DEVICE,
        img_size=384  # Adjust if your model was trained on a different size
    )
    print("\nPredictor is ready.")
except FileNotFoundError as e:
    # A missing checkpoint is reported instead of raised so the notebook can
    # keep running; any other loading error still propagates.
    print("\n--- ERROR ---")
    print(f"Could not find model file: {e}")
    print("Please update CHECKPOINT_FILE to the correct path.")

In [None]:
if not predictor:
    # No usable predictor was created by the setup cell.
    print("Predictor was not initialized. Please fix the checkpoint path in Cell 2.")
else:
    # --- Classify the configured test image ---
    print(f"Running prediction on: {TEST_IMAGE}")
    result = predictor.predict(TEST_IMAGE)

    # --- Show the outcome as indented JSON ---
    print("\n--- Prediction Result ---")
    formatted = json.dumps(result, indent=2)
    print(formatted)