In [10]:
# Cell 1: Imports and Setup
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from pathlib import Path

# Cell 2: Define TereNet Architecture
class TereNet(nn.Module):
    def __init__(self, num_classes):
        super(TereNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        self.pool = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(128 * 16 * 16, 512)
        self.bn_fc1 = nn.BatchNorm1d(512)

        self.fc2 = nn.Linear(512, num_classes)

        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))

        x = x.view(x.size(0), -1)

        x = self.dropout(F.relu(self.bn_fc1(self.fc1(x))))
        x = self.fc2(x)
        return x

# Cell 3: Define the infer function
def infer(data_dir, model_path):
    """
    Load trained model and predict classes for all images in data_dir.
    Saves results to results.json.
    
    Args:
        data_dir: Directory containing images to classify
        model_path: Path to saved model (.pth file)
    
    Returns:
        Dictionary mapping filenames to predicted class names
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Load the checkpoint
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    
    # Determine number of classes from the last layer
    if isinstance(checkpoint, torch.nn.Module):
        model = checkpoint
        print("Loaded complete model")
        num_classes = model.fc2.out_features
    else:
        print("Loaded state_dict, reconstructing model...")
        
        # Get number of classes from fc2 layer
        if 'fc2.weight' in checkpoint:
            num_classes = checkpoint['fc2.weight'].shape[0]
        else:
            raise ValueError("Could not determine number of classes from checkpoint")
        
        print(f"Detected {num_classes} classes from model")
        
        # Create model with detected number of classes
        model = TereNet(num_classes=num_classes)
        model.load_state_dict(checkpoint)
        print("Successfully loaded model weights!")
    
    model.eval()
    model.to(device)
    
    # Same transform as used during training/validation
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Load class names
    class_names = None
    
    # Try to find class names from q.json
    if os.path.exists('q.json'):
        with open('q.json', 'r') as f:
            class_names = json.load(f)
        print(f"Loaded {len(class_names)} classes from q.json")
    
    # Try to find from data directory
    elif os.path.exists('data_dir'):
        data_path = Path('data')
        class_dirs = sorted([d.name for d in data_path.iterdir() if d.is_dir()])
        if class_dirs:
            class_names = class_dirs
            print(f"Inferred {len(class_names)} classes from data directory")
            print(f"Classes: {class_names[:5]}..." if len(class_names) > 5 else f"Classes: {class_names}")
    


    
    # Get all images from data_dir
    data_path = Path(data_dir)
    if not data_path.exists():
        raise ValueError(f"Directory {data_dir} does not exist")
    
    image_files = [f for f in data_path.iterdir() 
                   if f.suffix.lower() in {'.jpg', '.jpeg', '.png', '.bmp','.webp'}]
    print(f"Found {len(image_files)} images to process")
    
    if len(image_files) == 0:
        print(f"Warning: No images found in {data_dir}")
        return {}
    
    # Predict
    predictions = {}
    with torch.no_grad():
        for i, img_path in enumerate(image_files, 1):
            try:
                # Load and transform image
                image = Image.open(img_path).convert('RGB')
                image_tensor = transform(image).unsqueeze(0).to(device)
                
                # Make prediction
                output = model(image_tensor)
                _, predicted_idx = torch.max(output, 1)
                predicted_class = class_names[predicted_idx.item()]
                
                predictions[img_path.name] = predicted_class
                
                if i % 10 == 0 or i == len(image_files):
                    print(f"Processed {i}/{len(image_files)} images")
                    
            except Exception as e:
                print(f"Error processing {img_path.name}: {e}")
                predictions[img_path.name] = "error"
    
    # Save results
    with open('results.json', 'w') as f:
        json.dump(predictions, f, indent=4)
    
    print(f"\nPredictions saved to results.json")
    return predictions




In [11]:
data_directory = 'data_dir'
model_file = 'terenet_model1.pth'



# Run inference
if Path(data_directory).exists() and Path(model_file).exists():
    results = infer(data_directory, model_file)
    print(f"\n{'='*60}")
    print(f"INFERENCE COMPLETE: Predicted {len(results)} images")
    print(f"{'='*60}")
    
    # Display all predictions with more detail
    print("\nPrediction Results:")
    print(f"{'='*60}")
    for filename, pred_class in results.items():
        # Clean up the class name for display
        display_name = pred_class.replace('_', ' ').title()
        print(f"  ðŸ“· {filename:20s} â†’ {display_name}")
    print(f"{'='*60}")
    



Using device: cpu
Loaded state_dict, reconstructing model...
Detected 42 classes from model
Successfully loaded model weights!
Loaded 42 classes from q.json
Found 2 images to process
Processed 2/2 images

Predictions saved to results.json

INFERENCE COMPLETE: Predicted 2 images

Prediction Results:
  ðŸ“· bart.png             â†’ Data\Archive\Characters Train\Bart Simpson
  ðŸ“· marge_simpson.jpg    â†’ Data\Archive\Characters Train\Marge Simpson


bart.png -> data\archive\characters_train\bart_simpson
marge_simpson.jpg -> data\archive\characters_train\marge_simpson
