In [3]:
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from torchvision import models
from torch.utils.data import DataLoader
import pandas as pd
from chestxray_module.dataset import data_load, transform
import os

# =========================
# Argument parsing
# =========================
def parse_args():
    parser = argparse.ArgumentParser(description="Chest X-ray inference")

    parser.add_argument(
        "--checkpoint",
        type=str,
        default=DEFAULT_CHECKPOINT,
        help=f"Path to model checkpoint (default: {DEFAULT_CHECKPOINT})",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help="Batch size for inference",
    )

    parser.add_argument(
        "--input_path",
        type=str,
        required=True,
        help="Path to an image or a directory of images",
    )
    
    parser.add_argument(
    "--output_path",
    type=str,
    default=f"{os.getcwd()}",
    help="Path to save predictions CSV (default: predictions.csv)",
    )
    
    return parser.parse_args()


def resolve_image_paths(input_path):
    input_path = Path(input_path)

    if input_path.is_file():
        if input_path.suffix.lower() not in {".jpg", ".png"}:
            raise ValueError(f"Not a supported image file: {input_path}")
        return input_path.parent

    if input_path.is_dir():
         return input_path


# =========================
# Defaults (reviewer-friendly)
# =========================
DEFAULT_CHECKPOINT = f"{os.getcwd()}/models/best_model.pt"
DEFAULT_BATCH_SIZE = 32
NUM_CLASSES = 3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# =========================
# Model builder
# =========================
def build_model(num_classes: int = NUM_CLASSES) -> nn.Module:
    model = models.densenet121(weights=None)
    in_features = model.classifier.in_features
    model.classifier = nn.Linear(in_features, num_classes)
    return model



# =========================
# Load trained weights
# =========================
def load_model(checkpoint_path: str) -> nn.Module:
    checkpoint_path = Path(checkpoint_path)

    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    model = build_model()
    state = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(state)
    model.to(DEVICE)
    model.eval()

    return model


# =========================
# Inference / prediction
# =========================
@torch.no_grad()
def predict(model: nn.Module, dataloader: DataLoader):
    all_preds = []
    all_probs = []
    all_paths = []

    for batch in dataloader:
        images = batch["image"].to(DEVICE)
        paths = batch["path"]

        logits = model(images)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)

        all_preds.extend(preds.cpu().tolist())
        all_probs.extend(probs.cpu().tolist())
        all_paths.extend(paths)

    return all_preds, all_probs, all_paths

# =========================
# Output
# =========================

import pandas as pd

def save_predictions_csv(preds, probs, paths, out_path="predictions.csv"):
    records = []

    for p, prob, path in zip(preds, probs, paths):
        record = {
            "path": str(path),
            "predicted_class": int(p),
        }

        # add per-class probabilities
        for i, v in enumerate(prob):
            record[f"prob_class_{i}"] = float(v)

        records.append(record)

    df = pd.DataFrame(records)
    df.to_csv(out_path, index=False)

    print(f"[INFO] Predictions saved to: {out_path}")


# =========================
# Main entry point
# =========================
def main():
    args = parse_args()

    print(f"[INFO] Using device: {DEVICE}")
    print(f"[INFO] Loading model from: {args.checkpoint}")

    model = load_model(args.checkpoint)
    
    image_paths = resolve_image_paths(args.input_path)
    
    raw_data = data_load(data_dir=args.image_paths, inspect=False)
    transformed_data = transform(raw_data, "test") # test transformation
    
    input_data = DataLoader(
                    transformed_data,
                    batch_size = args.BATCH_SIZE,
                    shuffle=False,
                    num_workers=0,
                    pin_memory=True,
                    )

    preds, probs, paths = predict(model, input_data)
    save_predictions_csv(out_path=f"{args.output_path}/prediction.csv")

    print(f"[DONE] Inference complete on {len(preds)} images /n")
    print(f"Output saved {args.output_path} ")



In [None]:
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from torchvision import models
from torch.utils.data import DataLoader
import pandas as pd
from chestxray_module.dataset import data_load, transform
import os


def resolve_image_paths(input_path):
    """
    If input_path is:
    - directory → return all .jpg/.png images under it recursively
    - image file → return its parent directory
    """
    input_path = Path(input_path)

    if not input_path.exists():
        raise FileNotFoundError(f"Path does not exist: {input_path}")

    # Case 1: directory → return images recursively
    if input_path.is_dir():
        # image_paths = list(input_path.rglob("*.jpg")) + list(input_path.rglob("*.png"))
        #if not image_paths:
            #raise ValueError(f"No images found under directory: {input_path}")
        return input_path

    # Case 2: single image → return parent directory
    if input_path.is_file():
        if input_path.suffix.lower() not in {".jpg", ".png"}:
            raise ValueError(f"Not a supported image file: {input_path}")
        return input_path.parent

    
image_paths = resolve_image_paths("data/interim/cleaned_data/test/normal/normal-1.jpg")
image_paths

raw_data = data_load(data_dir=str(image_paths), inspect=False)
transformed_data = transform(raw_data, "test") # test transformation


from torch.utils.data import Dataset
class adjust(Dataset):
    
    """
    A custom Dataset wrapper that adds 'class' and 'split' metadata based on file paths without modifying image tensors or transforms.

    """
    
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        self.paths = base_dataset.data  # original image paths

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image = self.base_dataset[idx]  # tensor, unchanged
        path = Path(self.paths[idx])

        return {
            "image": image,
            "path": str(path),

        }
    

adjusted_data = adjust(transformed_data)
adjusted_data[0]



Transformation done successfully.


{'image': metatensor([[[ 2.2318,  1.9235, -0.4397,  ..., -1.8610, -1.8610, -1.8610],
          [ 2.2318,  1.9064, -0.7993,  ..., -1.8610, -1.8610, -1.8610],
          [ 2.2147,  1.8550, -0.7479,  ..., -1.8610, -1.8610, -1.8610],
          ...,
          [ 2.0263,  1.7865,  1.6153,  ..., -1.7754, -1.7412, -1.7754],
          [ 2.0434,  1.8208,  1.6153,  ..., -0.8849, -0.8507, -0.8164],
          [ 2.0948,  1.8722,  1.6838,  ...,  0.2111,  0.2796,  0.3138]],
 
         [[ 2.4111,  2.0959, -0.3200,  ..., -1.7731, -1.7731, -1.7731],
          [ 2.4111,  2.0784, -0.6877,  ..., -1.7731, -1.7731, -1.7731],
          [ 2.3936,  2.0259, -0.6352,  ..., -1.7731, -1.7731, -1.7731],
          ...,
          [ 2.2010,  1.9559,  1.7808,  ..., -1.6856, -1.6506, -1.6856],
          [ 2.2185,  1.9909,  1.7808,  ..., -0.7752, -0.7402, -0.7052],
          [ 2.2710,  2.0434,  1.8508,  ...,  0.3452,  0.4153,  0.4503]],
 
         [[ 2.6226,  2.3088, -0.0964,  ..., -1.5430, -1.5430, -1.5430],
          [ 2.6