1️⃣ Data Loader (data_loader.py)



In [None]:
import os
import h5py
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CrowdDataset(Dataset):
    """Crowd-counting dataset pairing each image with an HDF5 density map.

    Images are read from ``<root_dir>/images``; the density map for an image
    is read from ``<root_dir>/ground_truth/<image stem>.h5`` under the
    ``"density"`` key.
    """

    # Extensions (lower-case) that count as images when scanning the folder.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Path to the ShanghaiTech dataset.
            transform (callable, optional): Optional transform to be applied
                to the image of a sample (the target is returned as loaded).
        """
        self.root_dir = root_dir
        self.transform = transform

        images_dir = os.path.join(root_dir, "images")
        # Sorted for a reproducible sample order; non-image entries (hidden
        # files, sub-folders, ...) are skipped instead of becoming samples.
        self.image_paths = [
            os.path.join(images_dir, name)
            for name in sorted(os.listdir(images_dir))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        ]

    def __len__(self):
        """Return the number of images found under ``<root_dir>/images``."""
        return len(self.image_paths)

    def _ground_truth_path(self, img_path):
        """Return the density-map path that belongs to ``img_path``.

        Built from the file stem only, so an "images" substring elsewhere in
        the path (e.g. inside ``root_dir``) or a non-``.jpg`` extension cannot
        corrupt the result.
        """
        stem, _ = os.path.splitext(os.path.basename(img_path))
        return os.path.join(self.root_dir, "ground_truth", stem + ".h5")

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``image`` is the RGB image (after ``transform`` when one was given);
        ``target`` is the ``"density"`` array loaded from the HDF5 file.
        """
        img_path = self.image_paths[idx]
        gt_path = self._ground_truth_path(img_path)

        image = Image.open(img_path).convert("RGB")
        with h5py.File(gt_path, "r") as hf:
            target = np.asarray(hf["density"])

        if self.transform:
            image = self.transform(image)

        return image, target


2️⃣ CSRNet Model (model.py)



In [None]:
import torch
import torch.nn as nn
from torchvision import models

class CSRNet(nn.Module):
    """Density-map regressor: truncated VGG16-BN features + dilated conv head.

    ``frontend`` holds the first 33 modules of ``vgg16_bn().features``;
    ``backend`` is six dilated 3x3 conv/ReLU pairs followed by a 1x1 conv
    that produces a single-channel output.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16_bn(pretrained=True)
        self.frontend = nn.Sequential(*list(backbone.features.children())[:33])

        # (in_channels, out_channels) of each dilated 3x3 convolution.
        dilated_plan = (
            (512, 512),
            (512, 512),
            (512, 512),
            (512, 256),
            (256, 128),
            (128, 64),
        )
        stack = []
        for in_ch, out_ch in dilated_plan:
            stack.append(nn.Conv2d(in_ch, out_ch, 3, padding=2, dilation=2))
            stack.append(nn.ReLU())
        # Final 1x1 convolution collapses the 64 channels to one map.
        stack.append(nn.Conv2d(64, 1, 1))
        self.backend = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the frontend and then the backend."""
        return self.backend(self.frontend(x))



3️⃣ Training Script (train.py)


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models


class CSRNet(nn.Module):
    """Density-map regressor built on the first 33 VGG16-BN feature modules.

    Args:
        load_weights: Forwarded to ``vgg16_bn(pretrained=...)``.
    """

    def __init__(self, load_weights=True):
        super().__init__()
        vgg_features = models.vgg16_bn(pretrained=load_weights).features
        self.frontend = nn.Sequential(*list(vgg_features.children())[:33])

        # Channel widths along the dilated head; consecutive pairs give the
        # in/out channels of each 3x3 convolution.
        widths = [512, 512, 512, 512, 256, 128, 64]
        backend_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            backend_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=2, dilation=2)
            )
            backend_layers.append(nn.ReLU(inplace=True))
        self.backend = nn.Sequential(*backend_layers)

        # 1x1 convolution mapping the 64 backend channels to one output map.
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        """Apply frontend, backend and output layer in sequence."""
        features = self.backend(self.frontend(x))
        return self.output_layer(features)




4️⃣ Real-time Inference + Alerts (infer_realtime.py) — note: the cell below only writes the CSRNet definition to model.py




In [None]:
%%writefile model.py
import torch
import torch.nn as nn
import torchvision.models as models

class CSRNet(nn.Module):
    """CSRNet model: VGG16-BN frontend, dilated conv backend, 1x1 output conv.

    Args:
        load_weights: Forwarded to ``vgg16_bn(pretrained=...)``.
    """

    def __init__(self, load_weights=True):
        super().__init__()
        base = models.vgg16_bn(pretrained=load_weights)
        # Keep only the first 33 modules of the VGG feature extractor.
        kept = list(base.features.children())[:33]
        self.frontend = nn.Sequential(*kept)

        # (in, out) channels for each dilated conv; every conv is followed
        # by an in-place ReLU.
        specs = (
            (512, 512),
            (512, 512),
            (512, 512),
            (512, 256),
            (256, 128),
            (128, 64),
        )
        self.backend = nn.Sequential(*[
            layer
            for c_in, c_out in specs
            for layer in (
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=2, dilation=2),
                nn.ReLU(inplace=True),
            )
        ])
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        """Return the single-channel output map for input batch ``x``."""
        out = self.frontend(x)
        out = self.backend(out)
        return self.output_layer(out)


Writing model.py


5️⃣ Streamlit Dashboard (dashboard.py)


In [None]:

!pip install streamlit -q
!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb
!dpkg -i cloudflared-linux-amd64.deb


# Source of the Streamlit demo app. It shows the uploaded image and reports a
# random placeholder count (no model is loaded by this app yet).
app_code = """
import streamlit as st
import numpy as np
from PIL import Image

st.title("CSRNet Crowd Counting Demo (Colab + Cloudflare)")

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Display image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Dummy crowd count (replace with CSRNet later)
    st.success(f"Estimated Crowd Count: {np.random.randint(50,500)}")
else:
    st.info("Please upload an image to start crowd counting.")
"""


# Write the app next to the notebook. The encoding is explicit so the file
# content does not depend on the platform's default text encoding.
with open("app.py", "w", encoding="utf-8") as f:
    f.write(app_code)


!streamlit run app.py &>/dev/null&

import time; time.sleep(5)

!cloudflared tunnel --url http://localhost:8501 --no-autoupdate



[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m54.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m91.7 MB/s[0m eta [36m0:00:00[0m
[?25hSelecting previously unselected package cloudflared.
(Reading database ... 126435 files and directories currently installed.)
Preparing to unpack cloudflared-linux-amd64.deb ...
Unpacking cloudflared (2025.9.0) ...
Setting up cloudflared (2025.9.0) ...
Processing triggers for man-db (2.10.2-1) ...
[90m2025-09-18T12:17:24Z[0m [32mINF[0m Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee, are subject to the Cloudflare Online Services Terms of Use (https://www.cloudflare.com/website-terms/), and Cloudflare reserves the right to investigate your use of Tunnels for violations of such terms. If you inten

6️⃣ Install Required Dependencies


In [None]:
# Install required packages
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
%pip install h5py opencv-python matplotlib tqdm scipy


7️⃣ Generate Ground Truth Density Maps


In [None]:
import os
import cv2
import numpy as np
import h5py
from PIL import Image
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

def create_density_map(image_shape, points, sigma=15):
    """Create a density map from point annotations.

    Args:
        image_shape: ``(height, width)`` of the map to create.
        points: Iterable of ``(x, y)`` annotations; coordinates are truncated
            to ``int`` and points outside the image are ignored.
        sigma: Standard deviation passed to the Gaussian filter.

    Returns:
        ``float32`` array of shape ``image_shape`` holding the blurred
        annotation impulses.
    """
    density_map = np.zeros(image_shape, dtype=np.float32)

    for point in points:
        x, y = int(point[0]), int(point[1])
        if 0 <= x < image_shape[1] and 0 <= y < image_shape[0]:
            # Accumulate rather than assign: several annotations can land on
            # the same pixel and each of them must still count as one person.
            density_map[y, x] += 1.0

    # Apply Gaussian filter
    density_map = gaussian_filter(density_map, sigma=sigma)
    return density_map

def generate_synthetic_annotations(image_path, num_people_range=(10, 100)):
    """Generate synthetic (random) point annotations for an image.

    Args:
        image_path: Path of the image; it is only read to obtain its size.
        num_people_range: ``(low, high)`` passed to ``np.random.randint`` to
            draw the number of points (``high`` is exclusive).

    Returns:
        List of ``[x, y]`` integer positions inside the image bounds.

    Raises:
        OSError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path}")
    height, width = image.shape[:2]

    # Generate random points (simulating people locations)
    num_people = np.random.randint(num_people_range[0], num_people_range[1])
    return [
        [np.random.randint(0, width), np.random.randint(0, height)]
        for _ in range(num_people)
    ]

def process_dataset(images_dir, output_dir):
    """Generate a ground-truth density map for every image in ``images_dir``.

    For each ``.jpg``/``.jpeg``/``.png`` file (case-insensitive) a synthetic
    annotation set is drawn and the resulting density map is stored as
    ``<output_dir>/ground_truth/<image stem>.h5`` under the ``"density"`` key.
    Images that cannot be read are reported and skipped.

    NOTE(review): only density maps are written; the images themselves are
    not copied into ``output_dir`` — confirm that consumers of ``output_dir``
    can locate the images.
    """
    gt_dir = os.path.join(output_dir, "ground_truth")
    # Creates output_dir as well when it does not exist yet.
    os.makedirs(gt_dir, exist_ok=True)

    # Sorted so the processing order does not depend on the filesystem.
    image_files = sorted(
        f for f in os.listdir(images_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    print(f"Processing {len(image_files)} images...")

    for i, img_file in enumerate(image_files):
        img_path = os.path.join(images_dir, img_file)

        # Load image to get dimensions
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None for unreadable or corrupt files.
            print(f"Warning: could not read '{img_path}', skipping")
            continue
        height, width = image.shape[:2]

        # Generate synthetic annotations
        points = generate_synthetic_annotations(img_path)

        # Create density map
        density_map = create_density_map((height, width), points)

        # Save density map as HDF5. splitext keeps this correct for any
        # extension casing (e.g. "IMG.JPG"), matching the filter above.
        stem = os.path.splitext(img_file)[0]
        gt_path = os.path.join(gt_dir, stem + '.h5')

        with h5py.File(gt_path, 'w') as hf:
            hf['density'] = density_map

        if (i + 1) % 10 == 0:
            print(f"Processed {i + 1}/{len(image_files)} images")

    print("Ground truth generation completed!")

# Source folder with the raw images and destination folder for the
# generated density maps.
images_dir = "images"
output_dir = "dataset"

# Only run the generation when the source folder is actually present.
if not os.path.exists(images_dir):
    print(f"❌ Images directory '{images_dir}' not found!")
else:
    process_dataset(images_dir, output_dir)
    print(f"✅ Ground truth generated in '{output_dir}' directory")


8️⃣ Create Data Loader for Training


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CrowdDataset(Dataset):
    """Dataset for crowd counting with a fixed 80/20 train/validation split.

    Images are taken from ``<root_dir>/images`` (or from ``root_dir`` itself
    when that sub-folder does not exist). The density map of an image is read
    from ``<root_dir>/ground_truth/<image stem>.h5`` under the ``"density"``
    key.
    """

    # Extensions (lower-case) that count as images when scanning the folder.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None, train=True):
        """
        Args:
            root_dir (str): Dataset root directory.
            transform (callable, optional): Applied to the image only; the
                density map is returned at its stored resolution.
            train (bool): Select the training (first 80%) or validation
                (remaining 20%) part of the shuffled file list.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.train = train

        # Get all image files
        images_dir = os.path.join(root_dir, "images")
        if not os.path.exists(images_dir):
            images_dir = root_dir  # If images are directly in root_dir

        # Sorted so that the split below does not depend on the order in
        # which the filesystem happens to list the files.
        all_paths = [
            os.path.join(images_dir, img_file)
            for img_file in sorted(os.listdir(images_dir))
            if img_file.lower().endswith(self.IMAGE_EXTENSIONS)
        ]

        # Split into train/validation (80/20 split). A private RandomState
        # yields the permutation of seed 42 without reseeding NumPy's global
        # generator for the rest of the program.
        indices = np.random.RandomState(42).permutation(len(all_paths))
        split_idx = int(0.8 * len(all_paths))
        chosen = indices[:split_idx] if train else indices[split_idx:]
        self.image_paths = [all_paths[i] for i in chosen]

        print(f"{'Training' if train else 'Validation'} dataset: {len(self.image_paths)} images")

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_paths)

    def _ground_truth_path(self, img_path):
        """Return the density-map path for ``img_path``.

        Uses the file stem so that any extension casing (e.g. ``.JPG``)
        resolves to ``<stem>.h5``.
        """
        stem, _ = os.path.splitext(os.path.basename(img_path))
        return os.path.join(self.root_dir, "ground_truth", stem + ".h5")

    def __getitem__(self, idx):
        """Return ``(image, target)``; ``target`` is a float tensor."""
        img_path = self.image_paths[idx]

        # Load image
        image = Image.open(img_path).convert("RGB")

        # Load density map
        gt_path = self._ground_truth_path(img_path)
        if os.path.exists(gt_path):
            with h5py.File(gt_path, "r") as hf:
                target = np.asarray(hf["density"])
        else:
            # If no ground truth exists, create a dummy (all-zero) one sized
            # like the untransformed image (PIL size is (width, height)).
            target = np.zeros((image.size[1], image.size[0]), dtype=np.float32)
            print(f"Warning: No ground truth found for {os.path.basename(img_path)}")

        # Apply transforms
        if self.transform:
            image = self.transform(image)

        # Convert target to tensor
        target = torch.from_numpy(target).float()

        return image, target

def get_transforms():
    """Build the image transform pipelines.

    Returns:
        ``(train_transform, val_transform)``. Both resize to 512x512, convert
        to a tensor and normalise per channel; the training pipeline
        additionally applies a random horizontal flip and colour jitter
        between the resize and the tensor conversion.
    """
    target_size = (512, 512)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    def _tensor_tail():
        # Fresh transform objects on every call, so the two pipelines do not
        # share instances.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]

    train_steps = [
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
    train_steps.extend(_tensor_tail())
    train_transform = transforms.Compose(train_steps)

    val_steps = [transforms.Resize(target_size)]
    val_steps.extend(_tensor_tail())
    val_transform = transforms.Compose(val_steps)

    return train_transform, val_transform

def create_data_loaders(root_dir, batch_size=8, num_workers=4):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        root_dir: Dataset root handed to ``CrowdDataset``.
        batch_size: Batch size used by both loaders.
        num_workers: Worker count used by both loaders.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    train_tf, val_tf = get_transforms()

    # Both datasets are built before any loader.
    train_dataset = CrowdDataset(root_dir, transform=train_tf, train=True)
    val_dataset = CrowdDataset(root_dir, transform=val_tf, train=False)

    # Settings shared by the two loaders.
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_loader = DataLoader(train_dataset, shuffle=True, **shared)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared)

    return train_loader, val_loader

# Smoke-test the loaders against the generated dataset folder.
root_dir = "dataset"
if not os.path.exists(root_dir):
    print("❌ Dataset directory not found. Please run the ground truth generation cell first.")
else:
    train_loader, val_loader = create_data_loaders(root_dir, batch_size=4)
    print("✅ Data loaders created successfully!")


9️⃣ Training Script - Complete Training Pipeline


In [None]:
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt

class CrowdCountingLoss(nn.Module):
    """Pixel-wise MSE between a predicted and a ground-truth density map.

    The target may be ``(B, H, W)`` or ``(B, 1, H, W)`` and may have a
    different spatial size than the ``(B, 1, h, w)`` prediction; it is then
    resized to the prediction's size and rescaled so that its sum (the
    person count) is preserved.
    """

    def __init__(self):
        super(CrowdCountingLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, pred, target):
        """Return the scalar MSE loss of ``pred`` against ``target``."""
        # Give the target an explicit channel axis so it lines up with the
        # prediction; comparing (B, 1, h, w) with (B, h, w) would silently
        # broadcast to (B, B, h, w).
        if target.dim() == pred.dim() - 1:
            target = target.unsqueeze(1)

        # Ensure target has the same spatial dimensions as prediction
        if pred.shape[2:] != target.shape[2:]:
            src_area = target.shape[2] * target.shape[3]
            dst_area = pred.shape[2] * pred.shape[3]
            target = torch.nn.functional.interpolate(
                target,
                size=pred.shape[2:],
                mode='bilinear',
                align_corners=False
            )
            # Interpolation keeps the mean value, not the sum; scale by the
            # area ratio so the resized map still integrates to the count.
            target = target * (src_area / dst_area)

        return self.mse_loss(pred, target)

def calculate_mae_rmse(pred, target):
    """Calculate MAE and RMSE between predicted and true per-sample counts.

    A sample's count is the sum over all of its non-batch elements, so the
    prediction and the target may have different ranks or spatial sizes
    (e.g. ``(B, 1, h, w)`` against ``(B, H, W)``).

    Returns:
        ``(mae, rmse)`` as Python floats, computed over the batch.
    """
    # Flatten everything but the batch axis instead of naming dims 1-3,
    # which fails for targets without a channel axis.
    pred_count = pred.reshape(pred.shape[0], -1).sum(dim=1)
    target_count = target.reshape(target.shape[0], -1).sum(dim=1)

    count_error = pred_count - target_count
    mae = torch.mean(torch.abs(count_error))
    rmse = torch.sqrt(torch.mean(count_error ** 2))

    return mae.item(), rmse.item()

def train_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        ``(avg_loss, avg_mae, avg_rmse)``: per-batch values averaged over the
        number of batches in the loader.
    """
    model.train()
    running = {'loss': 0, 'mae': 0, 'rmse': 0}

    progress = tqdm(train_loader, desc=f'Epoch {epoch} [Train]')
    for images, targets in progress:
        images = images.to(device)
        targets = targets.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Per-batch metrics
        batch_mae, batch_rmse = calculate_mae_rmse(outputs, targets)
        batch_loss = loss.item()

        running['loss'] += batch_loss
        running['mae'] += batch_mae
        running['rmse'] += batch_rmse

        # Show the current batch values on the progress bar.
        progress.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'MAE': f'{batch_mae:.2f}',
            'RMSE': f'{batch_rmse:.2f}'
        })

    num_batches = len(train_loader)
    return (
        running['loss'] / num_batches,
        running['mae'] / num_batches,
        running['rmse'] / num_batches,
    )

def validate_epoch(model, val_loader, criterion, device, epoch):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Returns:
        ``(avg_loss, avg_mae, avg_rmse)``: per-batch values averaged over the
        number of batches in the loader.
    """
    model.eval()
    loss_sum, mae_sum, rmse_sum = 0, 0, 0

    with torch.no_grad():
        progress = tqdm(val_loader, desc=f'Epoch {epoch} [Val]')

        for images, targets in progress:
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)
            batch_loss = criterion(outputs, targets).item()

            # Per-batch metrics
            batch_mae, batch_rmse = calculate_mae_rmse(outputs, targets)

            loss_sum += batch_loss
            mae_sum += batch_mae
            rmse_sum += batch_rmse

            # Show the current batch values on the progress bar.
            progress.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'MAE': f'{batch_mae:.2f}',
                'RMSE': f'{batch_rmse:.2f}'
            })

    num_batches = len(val_loader)
    return loss_sum / num_batches, mae_sum / num_batches, rmse_sum / num_batches

def save_checkpoint(model, optimizer, epoch, loss, filepath):
    """Write a training checkpoint to ``filepath`` with ``torch.save``.

    The stored dict has the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``.
    """
    torch.save(
        dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            loss=loss,
        ),
        filepath,
    )

print("✅ Training functions defined successfully!")


🔟 Start Training the Model


In [None]:
# Training Configuration
# All tunables for the run live in this dict; the training loop further down
# reads 'num_epochs', 'save_interval', 'patience' and 'checkpoint_dir'.
config = {
    'batch_size': 8,
    'learning_rate': 1e-4,
    'num_epochs': 50,  # Reduced for faster training
    'num_workers': 2,
    'dataset_path': 'dataset',
    'checkpoint_dir': 'checkpoints',
    'save_interval': 10,
    'patience': 15  # Early stopping patience
}

# Create directories
os.makedirs(config['checkpoint_dir'], exist_ok=True)

# Device configuration
# Falls back to the CPU when CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Using device: {device}")

# Create data loaders
# Arguments are positional: (root_dir, batch_size, num_workers).
print("📊 Loading dataset...")
train_loader, val_loader = create_data_loaders(
    config['dataset_path'], 
    config['batch_size'], 
    config['num_workers']
)

# Initialize model
print("🧠 Initializing CSRNet model...")
model = CSRNet(load_weights=True).to(device)

# Loss function and optimizer
criterion = CrowdCountingLoss()
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
# NOTE: the scheduler's patience (10) is hard-coded here and is independent
# of the early-stopping patience in config['patience'] (15).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10, factor=0.5)

# Training history
# One entry is appended per completed epoch; the plotting cell reads these.
train_losses, val_losses = [], []
train_maes, val_maes = [], []

# Early-stopping bookkeeping: best validation loss seen so far and the number
# of consecutive epochs without an improvement.
best_val_loss = float('inf')
patience_counter = 0

print("🎯 Starting training...")
print("="*60)

# Wall-clock start, used for the total-duration report after training.
import time
start_time = time.time()

# Main loop: one training pass and one validation pass per epoch, followed by
# LR scheduling, history tracking, checkpointing and the early-stopping check.
for epoch in range(1, config['num_epochs'] + 1):
    # Train
    train_loss, train_mae, train_rmse = train_epoch(
        model, train_loader, criterion, optimizer, device, epoch
    )
    
    # Validate
    val_loss, val_mae, val_rmse = validate_epoch(
        model, val_loader, criterion, device, epoch
    )
    
    # Update learning rate
    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)
    
    # Store history
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_maes.append(train_mae)
    val_maes.append(val_mae)
    
    # Print epoch summary
    print(f'Epoch {epoch}/{config["num_epochs"]}:')
    print(f'  Train Loss: {train_loss:.4f}, Train MAE: {train_mae:.2f}, Train RMSE: {train_rmse:.2f}')
    print(f'  Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.2f}, Val RMSE: {val_rmse:.2f}')
    print(f'  Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
    print('-' * 50)
    
    # Save best model
    # A strict improvement resets the early-stopping counter; anything else
    # counts as an epoch without progress.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        save_checkpoint(
            model, optimizer, epoch, val_loss,
            os.path.join(config['checkpoint_dir'], 'best_model.pth')
        )
        print(f'✅ New best model saved! Val Loss: {val_loss:.4f}')
    else:
        patience_counter += 1
    
    # Save checkpoint at intervals
    if epoch % config['save_interval'] == 0:
        save_checkpoint(
            model, optimizer, epoch, val_loss,
            os.path.join(config['checkpoint_dir'], f'checkpoint_epoch_{epoch}.pth')
        )
    
    # Early stopping
    if patience_counter >= config['patience']:
        print(f'🛑 Early stopping triggered after {epoch} epochs')
        break

# Save final model
# NOTE: `epoch` and `val_loss` are the values left over from the last loop
# iteration; they would be undefined if the loop body never ran
# (config['num_epochs'] == 0).
save_checkpoint(
    model, optimizer, epoch, val_loss,
    os.path.join(config['checkpoint_dir'], 'final_model.pth')
)

# Report the elapsed wall-clock time in hours.
total_time = time.time() - start_time
print("="*60)
print(f'🎉 Training completed in {total_time/3600:.2f} hours')
print(f'🏆 Best validation loss: {best_val_loss:.4f}')
print(f'💾 Models saved in: {config["checkpoint_dir"]}')


1️⃣1️⃣ Plot Training Results


In [None]:
# Plot training history
def plot_training_history(train_losses, val_losses, train_maes, val_maes):
    """Plot loss and MAE curves side by side and save the figure.

    Each argument is a per-epoch sequence. The figure is written to
    ``training_history.png`` and then shown.
    """
    epochs = range(1, len(train_losses) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # One panel per metric: (axis, training series, validation series, name).
    panels = (
        (ax1, train_losses, val_losses, 'Loss'),
        (ax2, train_maes, val_maes, 'MAE'),
    )
    for axis, train_series, val_series, metric in panels:
        axis.plot(epochs, train_series, 'b-', label=f'Training {metric}', linewidth=2)
        axis.plot(epochs, val_series, 'r-', label=f'Validation {metric}', linewidth=2)
        axis.set_title(f'Training and Validation {metric}', fontsize=14, fontweight='bold')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(metric)
        axis.legend()
        axis.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
    plt.show()

# Plot only when the training cell has recorded at least one epoch.
if len(train_losses) == 0:
    print("❌ No training data to plot. Please run the training cell first.")
else:
    plot_training_history(train_losses, val_losses, train_maes, val_maes)
    print("📊 Training plots saved as 'training_history.png'")


1️⃣2️⃣ Test the Trained Model


In [None]:
# Load the best trained model and test it
def load_model_for_inference(checkpoint_path, device):
    """Rebuild CSRNet from a checkpoint and return it in eval mode.

    The model is created without pretrained weights, moved to ``device`` and
    populated from the checkpoint's ``'model_state_dict'`` entry.
    """
    net = CSRNet(load_weights=False).to(device)
    saved = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(saved['model_state_dict'])
    # Module.eval() returns the module itself.
    return net.eval()

def test_single_image(model, image_path, device):
    """Run ``model`` on one image file.

    Returns:
        Dict with ``'count'`` (sum of the model output), ``'density_map'``
        (squeezed output as a NumPy array), ``'original_image'`` (the RGB PIL
        image) and ``'original_size'`` (its PIL ``size``).
    """
    source = Image.open(image_path).convert('RGB')

    # Same preprocessing steps as the validation pipeline.
    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    batch = preprocess(source).unsqueeze(0).to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        predicted = model(batch)
        total = torch.sum(predicted).item()

    return {
        'count': total,
        'density_map': predicted.squeeze().cpu().numpy(),
        'original_image': source,
        'original_size': source.size
    }

def visualize_prediction(result, image_name):
    """Show the image, its predicted density map and a heat-map overlay.

    Args:
        result: Dict with the keys ``'original_image'``, ``'density_map'``,
            ``'count'`` and ``'original_size'`` (as returned by
            ``test_single_image``).
        image_name: Label used in the figure title.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    axes[0].imshow(result['original_image'])
    axes[0].set_title('Original Image', fontsize=12, fontweight='bold')
    axes[0].axis('off')

    # Density map
    im = axes[1].imshow(result['density_map'], cmap='hot')
    axes[1].set_title(f'Density Map (Count: {result["count"]:.1f})', fontsize=12, fontweight='bold')
    axes[1].axis('off')
    plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

    # Overlay: bring the density map to the image size and min-max normalise
    # it to [0, 1] for the colour map.
    overlay = result['original_image'].copy()
    density_resized = cv2.resize(result['density_map'], result['original_size'])
    density_min = density_resized.min()
    density_range = density_resized.max() - density_min
    if density_range > 0:
        density_resized = (density_resized - density_min) / density_range
    else:
        # A constant map has no range; dividing by zero would fill the
        # heat-map with NaNs, so show it as uniformly "cold" instead.
        density_resized = np.zeros_like(density_resized)

    # Create colored overlay (drop the alpha channel, scale to 8-bit)
    heatmap = plt.cm.hot(density_resized)[:, :, :3]
    heatmap = (heatmap * 255).astype(np.uint8)

    # Blend images
    blended = cv2.addWeighted(np.array(overlay), 0.7, heatmap, 0.3, 0)

    axes[2].imshow(blended)
    axes[2].set_title(f'Overlay (Count: {result["count"]:.1f})', fontsize=12, fontweight='bold')
    axes[2].axis('off')

    plt.suptitle(f'CSRNet Prediction: {image_name}', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Run the best checkpoint on a few sample images and visualise the results.
checkpoint_path = "checkpoints/best_model.pth"

if not os.path.exists(checkpoint_path):
    print("❌ Trained model not found. Please run the training cell first.")
    print(f"Expected checkpoint at: {checkpoint_path}")
else:
    print("🎯 Loading trained model...")
    trained_model = load_model_for_inference(checkpoint_path, device)

    # Use at most the first three image files of the "images" folder.
    test_images = []
    if os.path.exists("images"):
        image_files = [f for f in os.listdir("images") if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        test_images = [os.path.join("images", f) for f in image_files[:3]]

    if not test_images:
        print("❌ No test images found in 'images' directory")
    else:
        print(f"🧪 Testing on {len(test_images)} sample images...")
        for i, img_path in enumerate(test_images):
            print(f"\n--- Testing Image {i+1}: {os.path.basename(img_path)} ---")
            result = test_single_image(trained_model, img_path, device)
            print(f"Predicted crowd count: {result['count']:.1f}")
            visualize_prediction(result, os.path.basename(img_path))
