In [13]:
# # Model Save, Reload, and Inference


In [14]:
# ## 1. Import Libraries


In [15]:
import torch
import torchvision
import torchvision.models as models
from pathlib import Path
import os


In [16]:
# ## 2. Define Model Architectures


In [17]:
def build_resnet50_classifier(num_classes=23):
    """
    Create a ResNet50 network whose final layer emits `num_classes` outputs.

    The backbone is instantiated with `weights=None`, and its `fc` layer is
    replaced by a new linear layer with the same number of input features.

    Args:
        num_classes: Number of output units for the replacement `fc` layer

    Returns:
        The torchvision ResNet50 module with the replaced head
    """
    backbone = models.resnet50(weights=None)
    head_inputs = backbone.fc.in_features
    backbone.fc = torch.nn.Linear(head_inputs, num_classes)
    return backbone

def build_efficientnet_classifier(num_classes=23):
    """
    Create an EfficientNet-B0 network whose classifier emits `num_classes` outputs.

    The backbone is instantiated with `weights=None`, and the layer at index 1
    of its `classifier` is replaced by a new linear layer with the same number
    of input features.

    Args:
        num_classes: Number of output units for the replacement linear layer

    Returns:
        The torchvision EfficientNet-B0 module with the replaced head
    """
    backbone = models.efficientnet_b0(weights=None)
    head_inputs = backbone.classifier[1].in_features
    backbone.classifier[1] = torch.nn.Linear(head_inputs, num_classes)
    return backbone


In [18]:
# ## 3. Save Model Function


In [19]:
def save_model_with_metadata(model, history, model_name, num_classes, class_to_idx, output_dir=None):
    """
    Save model with complete metadata.

    The checkpoint is a dict with the keys 'model_state_dict', 'history',
    'num_classes', 'class_to_idx' and 'model_name', written with torch.save
    to a file named '<model_name>_complete.pth'.

    Args:
        model: Trained PyTorch model
        history: Dictionary with training history (train_loss, train_acc, val_loss, val_acc)
        model_name: Name of the model (e.g., 'resnet50', 'efficientnet_b0')
        num_classes: Number of classes
        class_to_idx: Dictionary mapping class names to indices
        output_dir: Optional directory to write the checkpoint into; it is
            created if missing. When None (default) the file is written
            relative to the current working directory.

    Returns:
        str: Path of the checkpoint file that was written
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'history': history,
        'num_classes': num_classes,
        'class_to_idx': class_to_idx,
        'model_name': model_name
    }
    filename = f'{model_name}_complete.pth'
    if output_dir is not None:
        # Create the target directory up front so torch.save does not fail on a missing path
        os.makedirs(output_dir, exist_ok=True)
        filename = os.path.join(output_dir, filename)
    torch.save(checkpoint, filename)
    print(f"✓ Saved {model_name} to '{filename}'")
    # Hand the path back so callers can pass it straight to a loader
    return filename

# Example usage (from training notebook):
# save_model_with_metadata(resnet_model, resnet_history, 'resnet50',
#                          len(idx_to_class), class_to_idx)
# save_model_with_metadata(efficientnet_model, efficientnet_history, 'efficientnet_b0',
#                          len(idx_to_class), class_to_idx)


In [20]:
# ## 4. Load Model Function


In [21]:
def load_model(model_path, device='cpu', clean_dir=None):
    """
    Load a trained PyTorch model with backward compatibility.

    Supports two formats:
    1. Complete checkpoint: a dict holding 'model_state_dict', 'num_classes',
       'model_name' and 'class_to_idx'
    2. Legacy checkpoint (weights only: state_dict). The architecture and the
       number of classes are inferred from the classifier weight key, and the
       class mapping is rebuilt from an ImageFolder-style dataset directory.

    Args:
        model_path: Path to the saved model checkpoint
        device: Device to load the model on ('cpu' or 'cuda')
        clean_dir: Dataset directory used to rebuild the class mapping for
            legacy checkpoints. Defaults to '<cwd>/data/clean'.

    Returns:
        tuple: (model, class_to_idx, idx_to_class); the model is moved to
        `device` and put in eval mode

    Raises:
        FileNotFoundError: If `model_path` does not exist
        ValueError: If the model type cannot be determined from the
            checkpoint's 'model_name' or from its state_dict keys
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    print(f"Loading model from: {model_path}")

    # Load checkpoint
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from a trusted source. Consider weights_only=True
    # if every stored value is a tensor or a plain Python container.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # Check if it's a complete checkpoint or just weights
    is_complete = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint

    if is_complete:
        # ========== Complete checkpoint format ==========
        print("→ Detected: Complete checkpoint (with metadata)")

        # Extract metadata
        num_classes = checkpoint['num_classes']
        model_name = checkpoint['model_name']
        class_to_idx = checkpoint['class_to_idx']

        # Build model architecture
        normalized_name = model_name.lower()
        if 'resnet' in normalized_name:
            model = build_resnet50_classifier(num_classes)
        elif 'efficientnet' in normalized_name:
            model = build_efficientnet_classifier(num_classes)
        else:
            raise ValueError(f"Unknown model type: {model_name}")

        # Load trained weights
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        print(f"✓ Loaded {model_name} with {num_classes} classes")

    else:
        # ========== Legacy checkpoint format (weights only) ==========
        print("→ Detected: Legacy checkpoint (weights only)")

        # Infer the architecture and num_classes from the classifier weight key;
        # the first dimension of that weight is the number of output classes
        if 'fc.weight' in checkpoint:  # ResNet50
            num_classes = checkpoint['fc.weight'].shape[0]
            print("→ Detected model type: ResNet50")
            model = build_resnet50_classifier(num_classes)
        elif 'classifier.1.weight' in checkpoint:  # EfficientNet
            num_classes = checkpoint['classifier.1.weight'].shape[0]
            print("→ Detected model type: EfficientNet-B0")
            model = build_efficientnet_classifier(num_classes)
        else:
            raise ValueError("Cannot determine model type from checkpoint keys")

        # Load weights
        model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()

        # Legacy files carry no class names, so rebuild them from the dataset
        print("→ Reconstructing class mappings from dataset...")
        class_to_idx = _reconstruct_class_mapping(num_classes, clean_dir)

        print(f"✓ Loaded legacy checkpoint with {num_classes} classes")

    # Create idx_to_class mapping
    idx_to_class = {v: k for k, v in class_to_idx.items()}

    return model, class_to_idx, idx_to_class


def _reconstruct_class_mapping(num_classes, clean_dir=None):
    """Return a class_to_idx dict from `clean_dir`, or placeholders when it is unusable."""
    if clean_dir is None:
        clean_dir = os.path.join(Path().resolve(), "data", "clean")

    if os.path.exists(clean_dir):
        class_to_idx = torchvision.datasets.ImageFolder(root=clean_dir).class_to_idx
        # A folder whose class count differs from the classifier head would
        # attach wrong labels to predictions, so only accept an exact match
        if len(class_to_idx) == num_classes:
            print(f"✓ Reconstructed class mappings from {clean_dir}")
            return class_to_idx
        print(f"⚠ Warning: {clean_dir} has {len(class_to_idx)} classes "
              f"but the checkpoint has {num_classes}")
    else:
        print(f"⚠ Warning: {clean_dir} not found")

    print(f"→ Creating placeholder class mappings (0-{num_classes-1})")
    return {f"class_{i}": i for i in range(num_classes)}


In [22]:
# ## 5. Load Models


In [23]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Checkpoint files are looked up in the resolved current working directory
ROOT = Path(".").resolve()
RESNET_MODEL_PATH = str(ROOT / "best_resnet50.pth")
EFFICIENTNET_MODEL_PATH = str(ROOT / "best_efficientnet_b0.pth")

# Load each model together with its class-name mappings
print("\nLoading ResNet50...")
(
    resnet_model,
    resnet_class_to_idx,
    resnet_idx_to_class,
) = load_model(RESNET_MODEL_PATH, device=device)

print("\nLoading EfficientNet-B0...")
(
    efficientnet_model,
    efficientnet_class_to_idx,
    efficientnet_idx_to_class,
) = load_model(EFFICIENTNET_MODEL_PATH, device=device)

print("\n✓ All models loaded successfully!")


Using device: cuda

Loading ResNet50...
Loading model from: Q:\Users\zgl-7\Source\Repos\DS3000-25fall\best_resnet50.pth
→ Detected: Legacy checkpoint (weights only)
→ Detected model type: ResNet50
→ Reconstructing class mappings from dataset...
✓ Reconstructed class mappings from Q:\Users\zgl-7\Source\Repos\DS3000-25fall\data\clean
✓ Loaded legacy checkpoint with 23 classes

Loading EfficientNet-B0...
Loading model from: Q:\Users\zgl-7\Source\Repos\DS3000-25fall\best_efficientnet_b0.pth
→ Detected: Legacy checkpoint (weights only)
→ Detected model type: EfficientNet-B0
→ Reconstructing class mappings from dataset...
✓ Reconstructed class mappings from Q:\Users\zgl-7\Source\Repos\DS3000-25fall\data\clean
✓ Loaded legacy checkpoint with 23 classes

✓ All models loaded successfully!


In [24]:
# ## 6. Inference Setup


In [25]:
from PIL import Image
import torchvision.transforms as T

# Preprocessing settings: square input edge length and the per-channel
# mean / std values handed to T.Normalize
IMG_SIZE = 224
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Pipeline applied to every image before inference:
# resize to IMG_SIZE x IMG_SIZE, convert to a tensor, then normalize each channel
inference_transform = T.Compose(
    [
        T.Resize(size=(IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

def predict_image(model, image_path, idx_to_class, device, top_k=5):
    """
    Predict the class of an image.

    The image is converted to RGB, passed through the module-level
    `inference_transform`, and scored by `model` in eval mode without
    gradient tracking. Softmax is applied over dimension 1 of the output.

    Args:
        model: Trained PyTorch model
        image_path: Path to the image file
        idx_to_class: Dictionary mapping class indices to class names
        device: Device to run inference on
        top_k: Number of top predictions to return; capped at the number of
            classes the model outputs

    Returns:
        tuple: (results, image) where `results` is a list of dicts with the
        keys 'rank', 'class', 'probability' and 'confidence', ordered from
        most to least probable, and `image` is the RGB PIL image that was
        classified
    """
    # Load and preprocess image; the context manager releases the file handle
    # as soon as the RGB copy has been created
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    image_tensor = inference_transform(image).unsqueeze(0).to(device)

    # Predict
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        # torch.topk raises when k exceeds the size of the dimension, so never
        # ask for more entries than the model has classes
        k = min(top_k, probabilities.shape[1])
        top_probs, top_indices = torch.topk(probabilities, k, dim=1)

    # Convert to class names
    results = []
    ranked = zip(top_indices[0].tolist(), top_probs[0].tolist())
    for rank, (class_idx, prob) in enumerate(ranked, start=1):
        results.append({
            'rank': rank,
            'class': idx_to_class[class_idx],
            'probability': prob,
            'confidence': f"{prob * 100:.2f}%"
        })

    return results, image

print("✓ Inference functions ready!")


✓ Inference functions ready!


In [26]:
# ## 7. Example Prediction


In [27]:
# Example: Predict on an image
image_path = "path/to/your/image.jpg"  # Change this to your image path

# The example runs only when image_path points at an existing file, so a full
# "Run All" with the placeholder path reports what to do instead of failing
if os.path.isfile(image_path):
    results, image = predict_image(resnet_model, image_path, resnet_idx_to_class, device, top_k=5)

    print(f"\nTop {len(results)} Predictions:")
    for pred in results:
        print(f"{pred['rank']}. {pred['class']}: {pred['confidence']}")
else:
    print(f"⚠ Image not found: {image_path} (set image_path above to run the example)")


In [28]:
# ## 8. Export to HTML


In [29]:
# Export this notebook with outputs to HTML
import os
import subprocess
from datetime import datetime

notebook_name = "ModelInference.ipynb"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_html = f"ModelInference_output_{timestamp}.html"

# `nbconvert --execute` re-runs every cell of the target notebook. If this cell
# is part of that notebook, the nested run would start another export, and so
# on. The environment variable below marks a nested run so it skips the export.
# NOTE(review): this relies on the executed kernel inheriting the environment
# of the nbconvert process - confirm for non-default kernel setups.
EXPORT_GUARD_VAR = "MODEL_INFERENCE_EXPORT_RUNNING"

# Argument list instead of a shell string: no quoting issues, no shell involved
cmd = [
    "jupyter", "nbconvert",
    "--to", "html",
    "--execute", notebook_name,
    "--output", output_html,
]

if os.environ.get(EXPORT_GUARD_VAR):
    print("Skipping HTML export: already running inside an export")
else:
    try:
        print("Exporting notebook to HTML...")
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            env={**os.environ, EXPORT_GUARD_VAR: "1"},
        )
    except OSError as e:
        # Raised when the `jupyter` executable cannot be started at all
        print(f"Could not run jupyter nbconvert: {e}")
        print("Manual export: File > Download as > HTML (.html)")
    else:
        if result.returncode == 0:
            print(f"✓ Successfully exported to: {output_html}")
        else:
            # Show why the export failed instead of discarding the error output
            if result.stderr:
                print(result.stderr.strip())
            print("Note: Export command available, run manually if needed")
            print(f"Command: jupyter nbconvert --to html {notebook_name}")



Exporting notebook to HTML...
✓ Successfully exported to: ModelInference_output_20251113_055835.html
