In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import h5py
import torch.nn.functional as F
import random

# Set random seed for reproducibility
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(38)

# Dataset definition
class CrowdDataset(Dataset):
    def __init__(self, image_dir, density_dir, transform=None):
        self.image_dir = image_dir
        self.density_dir = density_dir
        self.image_filenames = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        density_path = os.path.join(self.density_dir, self.image_filenames[idx].replace('.jpg', '.h5'))

        image = Image.open(img_path).convert('RGB')
        with h5py.File(density_path, 'r') as hf:
            density_map = np.array(hf['density'])

        if self.transform:
            image = self.transform(image)
        density_map = torch.from_numpy(density_map).unsqueeze(0).float() 

        return image, density_map

# MultiScale Attention with Adaptive Scale Weights
class MultiScaleAttention(nn.Module):
    def __init__(self):
        super(MultiScaleAttention, self).__init__()
        
        # Multi-Scale Feature Extraction with adaptive scale weights
        self.conv1 = nn.Conv2d(512, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(512, 128, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(512, 128, kernel_size=7, padding=3)
        self.conv4 = nn.Conv2d(512, 128, kernel_size=1, padding=0)
        
        # Adaptive scale weights
        self.scale_weights = nn.Parameter(torch.ones(4))

        # Channel Attention Mechanism
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 512)
        
        # Enhanced Spatial Attention Mechanism
        self.spatial_conv1 = nn.Conv2d(512, 512, kernel_size=1)
        self.spatial_conv2 = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        # Multi-Scale Feature Extraction with Adaptive Scale Weighting
        f1 = self.conv1(x)
        f2 = self.conv2(x)
        f3 = self.conv3(x)
        f4 = self.conv4(x)
        
        # Apply scale weights and concatenate features
        f_multi = torch.cat([f * self.scale_weights[i] for i, f in enumerate([f1, f2, f3, f4])], dim=1)
        
        # Channel Attention
        avg_pool = self.global_avg_pool(f_multi).view(f_multi.size(0), -1)
        max_pool = self.global_max_pool(f_multi).view(f_multi.size(0), -1)
        channel_weights = torch.sigmoid(self.fc2(F.relu(self.fc1(avg_pool + max_pool)))).view(f_multi.size(0), 512, 1, 1)
        f_channel = f_multi * channel_weights
        
        # Enhanced Spatial Attention
        f_spatial = F.relu(self.spatial_conv1(f_channel))
        avg_out = torch.mean(f_spatial, dim=1, keepdim=True)
        max_out, _ = torch.max(f_spatial, dim=1, keepdim=True)
        spatial_attention = torch.sigmoid(self.spatial_conv2(torch.cat([avg_out, max_out], dim=1)))
        f_attention = f_spatial * spatial_attention
        
        # Residual Connection
        f_attention += f_channel
        
        return f_attention

# Main model with attention mechanism and dilated regressors
class DConvNet_v1_with_Attention(nn.Module):
    def __init__(self, pretrained=True, num_regressors=5):
        super(DConvNet_v1_with_Attention, self).__init__()
        
        vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        self.features = nn.Sequential(*list(vgg16.features.children())[:23])
        self.features.add_module('pool4', nn.MaxPool2d(kernel_size=2, stride=1, padding=0))
        self.features.add_module('dilated_conv5_1', nn.Conv2d(512, 512, kernel_size=3, padding=2, dilation=2))
        self.features.add_module('bn5_1', nn.BatchNorm2d(512))
        self.features.add_module('relu5_1', nn.ReLU(inplace=True))
        self.features.add_module('dilated_conv5_2', nn.Conv2d(512, 512, kernel_size=3, padding=2, dilation=2))
        self.features.add_module('bn5_2', nn.BatchNorm2d(512))
        self.features.add_module('relu5_2', nn.ReLU(inplace=True))
        self.features.add_module('dilated_conv5_3', nn.Conv2d(512, 512, kernel_size=3, padding=2, dilation=2))
        self.features.add_module('bn5_3', nn.BatchNorm2d(512))
        self.features.add_module('relu5_3', nn.ReLU(inplace=True))

        # Unfreeze additional VGG layers for fine-tuning
        for param in self.features[:10].parameters():
            param.requires_grad = True

        self.attention = MultiScaleAttention()

        # Dilated Convolution Regressors
        self.regressors = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(512, 64, kernel_size=1, groups=64),
                nn.ReLU(inplace=True),
                nn.Dropout(0.05),
                nn.Conv2d(64, 64, kernel_size=3, padding=2, dilation=2),
                nn.Conv2d(64, 1, kernel_size=1)
            ) for _ in range(num_regressors)
        ])
        self.regressors.apply(self.init_weights)

    def forward(self, x):
        x = self.features(x)
        x = self.attention(x)
        outputs = [regressor(x) for regressor in self.regressors]
        return outputs

    def init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.fill_(0.01)

# Improved Negative Correlation Loss with Weighted MSE for Crowded Areas
def negative_correlation_loss(outputs, target, lambda_param=0.0001):
    target = target.to(outputs[0].device)
    mse_loss = nn.MSELoss(reduction='none')
    
    total_mse = sum([mse_loss(F.interpolate(output, size=target.shape[2:], mode='bilinear', align_corners=False), target) for output in outputs]) / len(outputs)

    # Compute weighted loss for crowded areas
    crowded_area_weight = 10
    weight_map = torch.where(target > 0.1, crowded_area_weight, 1.0).to(target.device)
    weighted_mse = (weight_map * total_mse).mean()
    
    # Adaptive lambda: decay over epochs (can adjust for time decay)
    adaptive_lambda = lambda_param
    correlations = []
    for i in range(len(outputs)):
        for j in range(i + 1, len(outputs)):
            o_i = outputs[i].view(-1)
            o_j = outputs[j].view(-1)
            corr = torch.corrcoef(torch.stack([o_i, o_j]))[0, 1]
            correlations.append(corr)

    correlation_penalty = abs(sum(correlations)) / (len(correlations) + 1e-8) if correlations else 0
    return weighted_mse + adaptive_lambda * correlation_penalty

# Optimizer and scheduler
def get_optimizer(model):
    return optim.AdamW([
        {'params': model.features.parameters(), 'lr': 2e-5},
        {'params': model.regressors.parameters(), 'lr': 5e-4}
    ], weight_decay=1e-4) 

def get_scheduler(optimizer):
    return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6, verbose=True)

# Data transformations
data_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(5),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Training and evaluation
def train_model(model, train_dataloader, test_dataloader, num_epochs=20, lambda_param=0.0001, save_path='model_checkpoint.pth'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = get_optimizer(model)
    scheduler = get_scheduler(optimizer)
    best_mae = float('inf')
    early_stop_patience = 15
    no_improve_epochs = 0

    for epoch in range(num_epochs):
        running_loss = 0.0
        model.train()
        for images, density_maps in train_dataloader:
            images = images.to(device)
            density_maps = density_maps.to(device)

            optimizer.zero_grad()
            outputs = model(images) 
            loss = negative_correlation_loss(outputs, density_maps, lambda_param)  
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step() 

            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_dataloader)}")

        model.eval()
        mae, rmse = evaluate_model(model, test_dataloader)
        if mae < best_mae:
            best_mae = mae
            best_rmse = rmse
            no_improve_epochs = 0
            torch.save(model.state_dict(), save_path)
        else:
            no_improve_epochs += 1

        if no_improve_epochs >= early_stop_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        scheduler.step()
    print(f" MAE: {best_mae}, Validation RMSE: {best_rmse}")
def custom_collate(batch):
    max_height = max([item[0].shape[1] for item in batch])
    max_width = max([item[0].shape[2] for item in batch])

    resized_images = []
    resized_density_maps = []
    for image, density_map in batch:
        image = F.interpolate(image.unsqueeze(0), size=(max_height, max_width), mode='bilinear', align_corners=False)
        image = image.squeeze(0)

        density_map = F.interpolate(density_map.unsqueeze(0), size=(max_height, max_width), mode='bilinear', align_corners=False)
        density_map = density_map.squeeze(0)

        resized_images.append(image)
        resized_density_maps.append(density_map)

    return torch.stack(resized_images), torch.stack(resized_density_maps)
def evaluate_model(model, dataloader):
    model.eval()
    mae, rmse = 0.0, 0.0
    with torch.no_grad():
        for images, density_maps in dataloader:
            images = images.to(device)
            density_maps = density_maps.to(device)

            outputs = model(images)
            avg_output = sum(outputs) / len(outputs)

            mae += torch.abs(avg_output.sum() - density_maps.sum()).item()
            rmse += ((avg_output.sum() - density_maps.sum()) ** 2).item()

    mae /= len(dataloader)
    rmse = (rmse / len(dataloader)) ** 0.5
    return mae, rmse
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset paths (Update these to your local paths)
train_image_dir = '/kaggle/input/shanghaitech-with-people-density-map/ShanghaiTech/part_B/train_data/images'
train_density_dir = '/kaggle/input/shanghaitech-with-people-density-map/ShanghaiTech/part_B/train_data/ground-truth-h5'

test_image_dir = '/kaggle/input/shanghaitech-with-people-density-map/ShanghaiTech/part_B/test_data/images'
test_density_dir = '/kaggle/input/shanghaitech-with-people-density-map/ShanghaiTech/part_B/test_data/ground-truth-h5'

train_dataset = CrowdDataset(train_image_dir, train_density_dir, transform=data_transforms)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=custom_collate)

test_dataset = CrowdDataset(test_image_dir, test_density_dir, transform=data_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4, collate_fn=custom_collate)

model = DConvNet_v1_with_Attention(pretrained=True)

train_model(model, train_dataloader, test_dataloader, num_epochs=20, lambda_param=0.0001, save_path='model_checkpoint.pth')

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 211MB/s]  


Epoch 1/20, Loss: 0.024170016511343418
Epoch 2/20, Loss: 0.0025231755757704377
Epoch 3/20, Loss: 0.0006837312786956318
Epoch 4/20, Loss: 0.000286365827487316
Epoch 5/20, Loss: 0.0001554872874112334
Epoch 6/20, Loss: 9.792719523829874e-05
Epoch 7/20, Loss: 7.235763008793583e-05
Epoch 8/20, Loss: 5.945373181020841e-05
Epoch 9/20, Loss: 5.209415303397691e-05
Epoch 10/20, Loss: 4.976085558155319e-05
Epoch 11/20, Loss: 4.8361351327912414e-05
Epoch 12/20, Loss: 4.759423762152437e-05
Epoch 13/20, Loss: 4.493484328122577e-05
Epoch 14/20, Loss: 4.138555563258706e-05
Epoch 15/20, Loss: 4.642780600988772e-05
Epoch 16/20, Loss: 4.780822229804471e-05
Epoch 17/20, Loss: 5.586158082223846e-05
Epoch 18/20, Loss: 5.886226113943849e-05
Epoch 19/20, Loss: 6.0089974194852406e-05
Epoch 20/20, Loss: 5.2746081055374814e-05
 MAE: 20.02309754528577, Validation RMSE: 22.565459738731498
