In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import scipy.io

In [6]:
IMG_HEIGHT = 256 
IMG_WIDTH = 256 
BATCH_SIZE = 8
PATCH_SIZE = 64
NUM_ROUTER_CLASSES = 3

TRAIN_IMAGE_DIR = '/kaggle/input/crowd-dataset/crowd_wala_dataset/train_data/images'
TRAIN_GT_DIR = '/kaggle/input/crowd-dataset/crowd_wala_dataset/train_data/ground_truth'
TEST_IMAGE_DIR = '/kaggle/input/crowd-dataset/crowd_wala_dataset/test_data/images'
TEST_GT_DIR = '/kaggle/input/crowd-dataset/crowd_wala_dataset/test_data/ground_truth'

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [8]:
if not os.path.exists(TRAIN_IMAGE_DIR):
    print(f"Warning: {TRAIN_IMAGE_DIR} does not exist.")
if not os.path.exists(TRAIN_GT_DIR):
    print(f"Warning: {TRAIN_GT_DIR} does not exist.")
if not os.path.exists(TEST_IMAGE_DIR):
    print(f"Warning: {TEST_IMAGE_DIR} does not exist.")
if not os.path.exists(TEST_GT_DIR):
    print(f"Warning: {TEST_GT_DIR} does not exist.")

In [9]:
class CustomCrowdDataset(Dataset):
    def __init__(self, img_folder_path, gt_folder_path, img_height, img_width, transform=None):
        self.img_folder_path = img_folder_path
        self.gt_folder_path = gt_folder_path
        self.img_height = img_height
        self.img_width = img_width
        self.transform = transform
        self.image_filenames = sorted([f for f in os.listdir(img_folder_path) if f.endswith(('.jpg', '.jpeg', '.png'))])

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_filename = self.image_filenames[idx]
        img_path = os.path.join(self.img_folder_path, img_filename)
        
        base_filename = os.path.splitext(img_filename)[0]
        gt_filename = f"GT_{base_filename}.mat" 
        gt_path = os.path.join(self.gt_folder_path, gt_filename)

        if not os.path.exists(gt_path):
            print(f"Warning: GT not found for {img_filename} (expected {gt_path}). Returning dummy data. 🧊")
            img = Image.new('RGB', (self.img_width, self.img_height), color='black')
            dummy_img = self.transform(img) if self.transform else transforms.ToTensor()(img)
            return dummy_img, torch.tensor(0.0, dtype=torch.float32)

        try:
            img = Image.open(img_path).convert("RGB")
            img = img.resize((self.img_width, self.img_height)) 

            data = scipy.io.loadmat(gt_path)
            
            count = 0
            if 'image_info' in data and data['image_info'].size > 0:
                info_entry = data['image_info'][0,0] 
                
                try:
                    points_data = info_entry[0][0][0][0][0] 
                    
                    if isinstance(points_data, np.ndarray):
                        if points_data.size > 0:
                            count = points_data.shape[0]
                        else:
                            count = 0 
                    elif isinstance(points_data, (int, float, np.int_, np.float_)):
                        count = float(points_data)
                    else:
                        print(f"Warning: Unexpected data type after deep access for {gt_filename}: {type(points_data)}. Setting count to 0")

                except (IndexError, KeyError, TypeError) as e:
                    print(f"Error accessing nested 'points' in 'image_info' for {gt_filename}: {e}.")
                    print(f"Please confirm the exact nested structure if not 'image_info[0,0][0][0][0][0]'.")
                    print("Setting count to 0.")
                    count = 0
            else:
                print(f"Warning: 'image_info' not found or is empty in {gt_filename}.")
                print(f"Available top-level keys in {gt_filename}: {data.keys()}")
                print("Setting count to 0.")
                count = 0
            
            if self.transform:
                img = self.transform(img)

            return img, torch.tensor(float(count), dtype=torch.float32)

        except Exception as e:
            print(f"Error processing {img_filename} or its GT: {e}. Returning dummy data. 💥")
            img = Image.new('RGB', (self.img_width, self.img_height), color='black')
            dummy_img = self.transform(img) if self.transform else transforms.ToTensor()(img)
            return dummy_img, torch.tensor(0.0, dtype=torch.float32)

image_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = CustomCrowdDataset(TRAIN_IMAGE_DIR, TRAIN_GT_DIR, IMG_HEIGHT, IMG_WIDTH, transform=image_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_dataset = CustomCrowdDataset(TEST_IMAGE_DIR, TEST_GT_DIR, IMG_HEIGHT, IMG_WIDTH, transform=image_transforms)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training set size: {len(train_dataset)} images")
print(f"Test set size: {len(test_dataset)} images")

Training set size: 400 images
Test set size: 316 images


In [10]:
def get_patches(image_tensor, patch_size):
    assert image_tensor.shape[1] % patch_size == 0, 
    assert image_tensor.shape[2] % patch_size == 0, 

    patches = image_tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)

    patches = patches.permute(1, 2, 0, 3, 4).contiguous()
    patches = patches.view(-1, image_tensor.shape[0], patch_size, patch_size)
    
    return patches

class RouterNetwork(nn.Module):
    def __init__(self, in_channels, num_router_classes):
        super(RouterNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 

        self.fc1 = nn.Linear(64 * (PATCH_SIZE // 4) * (PATCH_SIZE // 4), 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, num_router_classes) 

    def forward(self, x):
        x = self.pool1(self.relu(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu(self.bn2(self.conv2(x))))
        x = x.view(x.size(0), -1) 
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x 


class MCNNColumn(nn.Module):

    def __init__(self, in_channels):
        super(MCNNColumn, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2), 
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2) 
        )

        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, 20, kernel_size=7, padding=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2), 
            nn.Conv2d(20, 40, kernel_size=7, padding=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2) 
        )

        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, 24, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(24, 48, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2) 
        )

        self.merge_conv = nn.Conv2d(32 + 40 + 48, 80, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)


        self.global_pool = nn.AdaptiveAvgPool2d(1) 
        self.final_fc = nn.Linear(80, 1) 

    def forward(self, x):
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)

        x = torch.cat((x1, x2, x3), dim=1)
        x = self.relu(self.merge_conv(x))
        
        x = self.global_pool(x)
        x = x.view(x.size(0), -1) 
        x = self.final_fc(x) 
        return x 

In [12]:
class SwitchMCNN(nn.Module):

    def __init__(self, in_channels=3, patch_size=PATCH_SIZE, num_router_classes=NUM_ROUTER_CLASSES):
        super(SwitchMCNN, self).__init__()
        self.patch_size = patch_size
        self.num_router_classes = num_router_classes

        self.router = RouterNetwork(in_channels, num_router_classes)
    
        self.mcnn_columns = nn.ModuleList([MCNNColumn(in_channels) for _ in range(num_router_classes)])

    def forward(self, x):
        batch_size, channels, img_h, img_w = x.shape

        num_patches_h = img_h // self.patch_size
        num_patches_w = img_w // self.patch_size
        total_patches_per_image = num_patches_h * num_patches_w

        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
 
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()

        patches_flat = patches.view(-1, channels, self.patch_size, self.patch_size)

        router_logits = self.router(patches_flat)

        a, routed_columns_indices = torch.max(router_logits, dim=1)

        final_counts = torch.zeros(batch_size, dtype=torch.float32).to(x.device)

        for i in range(batch_size):
            image_patches_start_idx = i * total_patches_per_image
            image_patches_end_idx = (i + 1) * total_patches_per_image
            
            current_image_patches = patches_flat[image_patches_start_idx : image_patches_end_idx]
            current_image_routed_indices = routed_columns_indices[image_patches_start_idx : image_patches_end_idx]

            image_total_count = torch.tensor(0.0).to(x.device)

            for patch_idx in range(total_patches_per_image):
                patch = current_image_patches[patch_idx].unsqueeze(0) 
                assigned_column_index = current_image_routed_indices[patch_idx].item()

                mcnn_column = self.mcnn_columns[assigned_column_index]

                patch_count_prediction = mcnn_column(patch)
                image_total_count += patch_count_prediction.squeeze()

            final_counts[i] = image_total_count
            
        return final_counts.unsqueeze(1) 

model = SwitchMCNN(in_channels=3, patch_size=PATCH_SIZE, num_router_classes=NUM_ROUTER_CLASSES).to(device)

criterion = nn.L1Loss() 
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) # Learning rate might need tuning

NUM_EPOCHS = 3
def train_model(model, train_loader, criterion, optimizer, num_epochs):
    model.train() 
    print("Starting training...")
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, counts) in enumerate(train_loader):
            images = images.to(device)
            counts = counts.unsqueeze(1).to(device) 

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, counts)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)

            if (i + 1) % 50 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")
        
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {epoch_loss:.4f}")

if __name__ == '__main__':
    print(f"Model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters.")
    train_model(model, train_loader, criterion, optimizer, NUM_EPOCHS)
    model_path = 'switch_mcnn_crowd_counter.pth'
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")



Model initialized with 2612926 trainable parameters.
Starting training...
Epoch [1/3], Step [50/50], Loss: 68.6814
Epoch [1/3] completed. Average Loss: 236.9926
Epoch [2/3], Step [50/50], Loss: 227.8536
Epoch [2/3] completed. Average Loss: 199.1722
Epoch [3/3], Step [50/50], Loss: 157.1553
Epoch [3/3] completed. Average Loss: 196.9798
Model saved to switch_mcnn_crowd_counter.pth


In [13]:
def evaluate_model(model, test_loader, criterion):
    model.eval()
    total_loss = 0.0
    total_mae = 0.0 
    total_mse = 0.0 
    total_samples = 0
    
    print("Starting evaluation...")
    with torch.no_grad(): 
        for images, counts in test_loader:
            images = images.to(device)
            counts = counts.unsqueeze(1).to(device)

            outputs = model(images)

            loss = criterion(outputs, counts)
            total_loss += loss.item() * images.size(0)

            mae_batch = torch.abs(outputs - counts).sum()
            total_mae += mae_batch.item()

            mse_batch = F.mse_loss(outputs, counts, reduction='sum')
            total_mse += mse_batch.item()
            
            total_samples += images.size(0)

    avg_loss = total_loss / total_samples
    avg_mae = total_mae / total_samples
    avg_mse = total_mse / total_samples 

    print(f"Evaluation Complete.")
    print(f"Average MAE: {avg_mae:.4f}")
    print(f"Average MSE: {avg_mse:.4f}")
    
    return avg_loss, avg_mae, avg_mse 


if __name__ == '__main__':

    loaded_model = SwitchMCNN(in_channels=3, patch_size=PATCH_SIZE, num_router_classes=NUM_ROUTER_CLASSES).to(device)

    model_path = 'switch_mcnn_crowd_counter.pth'
    
    try:

        loaded_model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Model loaded successfully from {model_path}")

        criterion_eval = nn.L1Loss() 

        evaluate_model(loaded_model, test_loader, criterion_eval)

    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}. Please ensure the training block was run and the model was saved.")
    except Exception as e:
        print(f"An error occurred during model loading or evaluation: {e}")

Model loaded successfully from switch_mcnn_crowd_counter.pth
Starting evaluation...
Evaluation Complete.
Average MAE: 169.0175
Average MSE: 60315.3825
