# CogRRG Inference Demo

This notebook demonstrates how to run inference on a chest X-ray (CXR) study with the trained multi-view classifier, using a frontal view and an optional lateral view.

## Setup

In [None]:
import sys
from typing import Optional

sys.path.insert(0, '..')

import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T

from models import build_classifier
from evaluation import CHEXPERT_LABELS

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Checkpoint produced by training, relative to this notebook's directory.
CHECKPOINT_PATH = '../checkpoints/best.pt'

print(f'Using device: {DEVICE}')

## Load Model

In [None]:
# Build the architecture without downloading pretrained weights; the trained
# weights come from the checkpoint below.
model = build_classifier(backbone='convnext_tiny', pretrained=False)

# NOTE(security): torch.load unpickles the file -- only load checkpoints from a
# trusted source. PyTorch >= 2.6 defaults to weights_only=True, which may reject
# non-tensor entries (e.g. numpy arrays) stored alongside the weights.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model'])

# Per-label decision thresholds. Fall back to 0.5 for every label when the
# checkpoint has none; size the fallback from CHEXPERT_LABELS (not a hard-coded
# 14) so it always lines up with the labels reported by `predict`.
thresholds = checkpoint.get('thresholds', np.full(len(CHEXPERT_LABELS), 0.5))

# With map_location=DEVICE a stored tensor may live on the GPU; keep a CPU numpy
# copy so it compares directly against the numpy probabilities in `predict`.
if torch.is_tensor(thresholds):
    thresholds = thresholds.detach().cpu().numpy()
else:
    thresholds = np.asarray(thresholds)

model = model.to(DEVICE)
model.eval()  # inference mode for layers such as dropout / batch-norm
print('Model loaded successfully')

## Define Inference Function

In [None]:
# Preprocessing applied to every view before it reaches the model: resize to a
# fixed 224x224 and convert the PIL image to a float tensor.
# NOTE(review): there is no T.Normalize step here -- confirm this matches the
# preprocessing used at training time.
transform = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
    ]
)

def _load_view(path: str) -> torch.Tensor:
    """Open one image file, convert it to RGB and apply ``transform``."""
    # The ``with`` block closes the underlying file handle (the original code
    # leaked it until garbage collection); ``convert`` returns a new, fully
    # loaded image, so it remains usable after the source file is closed.
    with Image.open(path) as img:
        return transform(img.convert('RGB'))


def predict(frontal_path: str, lateral_path: Optional[str] = None):
    """Run multi-view inference on a single CXR study.

    Args:
        frontal_path: Path to the frontal image (always required).
        lateral_path: Optional path to the lateral image. When falsy (``None``
            or ``''``), a zero tensor is fed in its place and the view mask
            marks the lateral slot as absent.

    Returns:
        A list with one dict per label in ``CHEXPERT_LABELS`` containing
        ``'label'``, ``'probability'`` (sigmoid of the model logit) and
        ``'positive'`` (probability strictly above that label's threshold),
        sorted by probability in descending order.

    Raises:
        ValueError: If the number of model outputs does not match
            ``len(CHEXPERT_LABELS)``.

    Uses the notebook-level ``model``, ``transform``, ``thresholds`` and
    ``DEVICE`` defined in earlier cells.
    """
    frontal = _load_view(frontal_path)

    if lateral_path:
        lateral = _load_view(lateral_path)
        view_mask = torch.tensor([1.0, 1.0])
    else:
        # Placeholder for the missing view; presumably the model uses view_mask
        # to ignore this slot -- confirm in `models`.
        lateral = torch.zeros_like(frontal)
        view_mask = torch.tensor([1.0, 0.0])

    # Stack views: [2, 3, H, W] -> add batch dim -> [1, 2, 3, H, W]
    views = torch.stack([frontal, lateral]).unsqueeze(0).to(DEVICE)
    view_mask = view_mask.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(views, view_mask)
        probs = torch.sigmoid(logits).cpu().numpy()[0]

    if len(probs) != len(CHEXPERT_LABELS):
        raise ValueError(
            f'model produced {len(probs)} outputs but there are '
            f'{len(CHEXPERT_LABELS)} labels'
        )

    # Thresholds may come from the checkpoint as a list, numpy array or
    # (possibly CUDA) tensor; compare against a CPU numpy array in all cases.
    if torch.is_tensor(thresholds):
        label_thresholds = thresholds.detach().cpu().numpy()
    else:
        label_thresholds = np.asarray(thresholds)
    predictions = probs > label_thresholds

    results = [
        {
            'label': label,
            'probability': float(prob),
            'positive': bool(is_positive),
        }
        for label, prob, is_positive in zip(CHEXPERT_LABELS, probs, predictions)
    ]

    return sorted(results, key=lambda x: -x['probability'])

## Example Prediction

In [None]:
# Replace with your image path
# Replace with your image path (pass lateral_path=... to `predict` as well for a
# two-view study).
FRONTAL_IMAGE = '/path/to/frontal.jpg'

results = predict(FRONTAL_IMAGE)

# Split once into the two groups; `predict` already sorts by probability
# (descending), and the comprehensions preserve that order within each group.
positive_rows = [row for row in results if row['positive']]
negative_rows = [row for row in results if not row['positive']]

heavy_rule = '=' * 50
print('\n' + heavy_rule)
print('Positive Findings:')
print(heavy_rule)
for row in positive_rows:
    print(f"  {row['label']:<30} {row['probability']:.3f}")

print('\nNegative Findings:')
print('-' * 50)
for row in negative_rows:
    print(f"  {row['label']:<30} {row['probability']:.3f}")