<a href="https://colab.research.google.com/github/NOWAYTE/neuromorphic-adas/blob/main/Untitled37.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import sys
import zipfile
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict
import matplotlib.pyplot as plt
from tqdm import tqdm
import librosa
import h5py
import cv2

!mkdir -p models utils data/raw data/processed trained_models logs

In [2]:
import json
import os

from google.colab import files
uploaded = files.upload()
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

Saving kaggle.json to kaggle.json


In [3]:
!gdown --id 1IBxEHX1bO124z9MGqrDjIPFCZmd1fFRk -O n_cars.7z

Downloading...
From (original): https://drive.google.com/uc?id=1IBxEHX1bO124z9MGqrDjIPFCZmd1fFRk
From (redirected): https://drive.google.com/uc?id=1IBxEHX1bO124z9MGqrDjIPFCZmd1fFRk&confirm=t&uuid=894f6a4a-7bda-4c1b-afc4-beb98e734676
To: /content/n_cars.7z
100% 299M/299M [00:06<00:00, 44.7MB/s]


In [4]:
!7z x n_cars.7z -odata/raw/events_ncars


7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,2 CPUs Intel(R) Xeon(R) CPU @ 2.00GHz (50653),ASM,AES-NI)

Scanning the drive for archives:
  0M Scan         1 file, 299228901 bytes (286 MiB)

Extracting archive: n_cars.7z
--
Path = n_cars.7z
Type = 7z
Physical Size = 299228901
Headers Size = 238856
Method = LZMA:23
Solid = +
Blocks = 1

  0%      1% 175 - Prophesee_Dataset_n_cars/n-cars_train/background/obj_006566_td.dat                                                                               1% 469 - Prophesee_Dataset_n_cars/n-cars_train/background/obj_006241_td.dat                                                                         

In [5]:
!kaggle datasets download -d breejeshdhar/thermal-image-dataset-for-object-classification -p ./data
with zipfile.ZipFile('./data/thermal-image-dataset-for-object-classification.zip', 'r') as zip_ref:
    zip_ref.extractall('data/raw/thermal')

!kaggle datasets download -d mmoreaux/environmental-sound-classification-50 -p ./data
with zipfile.ZipFile('./data/environmental-sound-classification-50.zip', 'r') as zip_ref:
    zip_ref.extractall('data/raw/audio')

Dataset URL: https://www.kaggle.com/datasets/breejeshdhar/thermal-image-dataset-for-object-classification
License(s): CC-BY-SA-4.0
Downloading thermal-image-dataset-for-object-classification.zip to ./data
 99% 3.23G/3.24G [00:45<00:00, 53.8MB/s]
100% 3.24G/3.24G [00:45<00:00, 77.3MB/s]
Dataset URL: https://www.kaggle.com/datasets/mmoreaux/environmental-sound-classification-50
License(s): CC-BY-NC-SA-4.0
Downloading environmental-sound-classification-50.zip to ./data
 97% 1.38G/1.42G [00:06<00:00, 100MB/s] 
100% 1.42G/1.42G [00:06<00:00, 222MB/s]


In [94]:
!kaggle datasets download -d apemangr/neuromorphic-falling-detection-dataset

with zipfile.ZipFile('neuromorphic-falling-detection-dataset.zip', 'r') as zip_ref:
    zip_ref.extractall('data/raw/events')

Dataset URL: https://www.kaggle.com/datasets/apemangr/neuromorphic-falling-detection-dataset
License(s): GPL-3.0
Downloading neuromorphic-falling-detection-dataset.zip to /content
100% 2.68G/2.69G [02:29<00:00, 238MB/s]
100% 2.69G/2.69G [02:29<00:00, 19.3MB/s]


In [107]:
%%writefile models/hybrid_fusion.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from .thermal_processor import ThermalEncoder

class HybridFusionModel(nn.Module):
    def __init__(self, num_classes=3, thermal_feat_dim=128):
        """
        Hybrid Neuromorphic-Acoustic-Thermal Fusion Model
        num_classes: 0=normal, 1=siren, 2=hazard
        thermal_feat_dim: Dimension of thermal features
        """
        super().__init__()

        # Neuromorphic Processing Branch
        self.event_encoder = nn.Sequential(
            nn.Conv2d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(32 * 65 * 86, 256)
        )

        # Acoustic Processing Branch
        self.audio_encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(32, 128)
        )

        # Thermal Processing Branch
        self.thermal_encoder = ThermalEncoder(output_dim=thermal_feat_dim)

        # Feature dimensions for attention
        self.event_dim = 256
        self.audio_dim = 128
        self.thermal_dim = thermal_feat_dim

        # Attention-based Fusion
        self.attention = nn.MultiheadAttention(
            embed_dim=self.event_dim + self.audio_dim + self.thermal_dim,
            num_heads=4,
            batch_first=True
        )

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.event_dim + self.audio_dim + self.thermal_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        # Confidence Head
        self.confidence = nn.Sequential(
            nn.Linear(self.event_dim + self.audio_dim + self.thermal_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, event_input, audio_input, thermal_input=None):
        # Process event data
        batch_size, seq_len = event_input.shape[0], event_input.shape[1]
        event_input = event_input.reshape(-1, *event_input.shape[2:])
        event_features = self.event_encoder(event_input)
        event_features = event_features.reshape(batch_size, seq_len, -1).mean(dim=1)

        # Process audio data
        audio_features = self.audio_encoder(audio_input.unsqueeze(1))

        # Process thermal data if provided
        if thermal_input is not None:
            thermal_features = self.thermal_encoder(thermal_input)
            combined = torch.cat([event_features, audio_features, thermal_features], dim=1)
        else:
            combined = torch.cat([event_features, audio_features], dim=1)

        # Apply attention
        attn_output, _ = self.attention(
            combined.unsqueeze(1),
            combined.unsqueeze(1),
            combined.unsqueeze(1)
        )
        fused = attn_output.squeeze(1)

        # Outputs
        classification = self.classifier(fused)
        confidence = self.confidence(fused)

        return classification, confidence


class EventProcessing(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv3d = nn.Conv3d(2, 16, kernel_size=(3,3,3), padding=1)
        self.pool = nn.MaxPool3d((1,2,2))

    def forward(self, x):  # x: [B, T, C, H, W]
        x = x.permute(0, 2, 1, 3, 4)  # [B, C, T, H, W]
        return self.pool(F.relu(self.conv3d(x)))


class HybridFusion(nn.Module):
    def __init__(self, feat_dims=[256, 128, 64]):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=sum(feat_dims),
            num_heads=4,
            batch_first=True
        )

    def forward(self, event_feats, audio_feats, thermal_feats):
        combined = torch.cat([event_feats, audio_feats, thermal_feats], dim=-1)
        attn_out, _ = self.attn(combined, combined, combined)
        return attn_out

Overwriting models/hybrid_fusion.py


In [8]:
%%writefile models/thermal_processor.py
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms

class ThermalProcessor:
    def __init__(self, input_size=(224, 224), normalize=True):
        """
        Initialize thermal image processor

        Args:
            input_size (tuple): Target size for resizing images (height, width)
            normalize (bool): Whether to normalize thermal values
        """
        self.input_size = input_size
        self.normalize = normalize
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(input_size),
            transforms.ToTensor(),
        ])

    def preprocess(self, thermal_image):
        """
        Preprocess a single thermal image

        Args:
            thermal_image (numpy.ndarray): Input thermal image (single channel)

        Returns:
            torch.Tensor: Preprocessed thermal image tensor
        """
        # Convert to float32 if needed
        if thermal_image.dtype != np.float32:
            thermal_image = thermal_image.astype(np.float32)

        # Normalize to [0, 1] if not already
        if thermal_image.max() > 1.0:
            thermal_image = thermal_image / thermal_image.max()

        # Apply transformations
        tensor = self.transform(thermal_image)

        # Add batch dimension if needed
        if len(tensor.shape) == 3:
            tensor = tensor.unsqueeze(0)

        return tensor

    def add_thermal_noise(self, thermal_image, noise_level=0.1):
        """Add realistic thermal noise to the image"""
        noise = np.random.normal(0, noise_level, thermal_image.shape).astype(np.float32)
        return np.clip(thermal_image + noise, 0, 1)

    def adjust_thermal_contrast(self, thermal_image, alpha=1.0, beta=0.0):
        """Adjust contrast of thermal image"""
        return np.clip(alpha * thermal_image + beta, 0, 1)


class ThermalEncoder(nn.Module):
    def __init__(self, input_channels=1, base_channels=32, output_dim=128):
        """
        CNN-based thermal feature extractor
        """
        super().__init__()
        self.features = nn.Sequential(
            # Initial conv block
            nn.Conv2d(input_channels, base_channels, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # Residual blocks
            self._make_residual_block(base_channels, base_channels * 2, 2),
            self._make_residual_block(base_channels * 2, base_channels * 4, 2),
            self._make_residual_block(base_channels * 4, base_channels * 8, 2),

            # Final pooling
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),

            # Projection head
            nn.Linear(base_channels * 8, output_dim)
        )

    def _make_residual_block(self, in_channels, out_channels, stride):
        """Create a residual block with skip connection"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.features(x)

Writing models/thermal_processor.py


In [99]:
import os
from glob import glob

def explore_specific_directories(base_path):
    """
    Explore specific directories: audio, events/NFDD, and thermal/Thermal Image Dataset
    """
    print("Exploring specific directories...")

    # Check if base path exists
    if not os.path.exists(base_path):
        print(f"Base path {base_path} does not exist!")
        return

    # Explore audio directory
    audio_path = os.path.join(base_path, 'audio')
    if os.path.exists(audio_path):
        print("\n=== AUDIO DIRECTORY ===")
        print("Contents of audio directory:")
        for item in os.listdir(audio_path):
            item_path = os.path.join(audio_path, item)
            if os.path.isdir(item_path):
                print(f"  {item}/")
                # Check if there are subdirectories
                for subitem in os.listdir(item_path):
                    subitem_path = os.path.join(item_path, subitem)
                    if os.path.isdir(subitem_path):
                        print(f"    {subitem}/")
                        # Check one more level
                        for subsubitem in os.listdir(subitem_path):
                            subsubitem_path = os.path.join(subitem_path, subsubitem)
                            if os.path.isdir(subsubitem_path):
                                print(f"      {subsubitem}/")
                            else:
                                print(f"      {subsubitem}")
                    else:
                        print(f"    {subitem}")
            else:
                print(f"  {item}")

        # Look for WAV files and CSV files
        wav_files = glob(os.path.join(audio_path, '**', '*.wav'), recursive=True)
        csv_files = glob(os.path.join(audio_path, '**', '*.csv'), recursive=True)
        print(f"\nFound {len(wav_files)} WAV files and {len(csv_files)} CSV files in audio directory")
        if csv_files:
            print("CSV files found:")
            for csv_file in csv_files:
                print(f"  {os.path.relpath(csv_file, audio_path)}")

    # Explore events/NFDD directory
    events_path = os.path.join(base_path, 'events', 'NFDD')
    if os.path.exists(events_path):
        print("\n=== EVENTS/NFDD DIRECTORY ===")
        print("Contents of events/NFDD directory:")
        for item in os.listdir(events_path):
            item_path = os.path.join(events_path, item)
            if os.path.isdir(item_path):
                print(f"  {item}/")
                # Check subdirectories
                for subitem in os.listdir(item_path):
                    subitem_path = os.path.join(item_path, subitem)
                    if os.path.isdir(subitem_path):
                        print(f"    {subitem}/")
                        # Check for files in the subdirectory
                        files = os.listdir(subitem_path)
                        print(f"      Found {len(files)} files")
                        if files:
                            # Show first 5 files and their extensions
                            extensions = set()
                            for file in files[:5]:
                                ext = os.path.splitext(file)[1]
                                extensions.add(ext)
                                print(f"      {file}")
                            if len(files) > 5:
                                print(f"      ... and {len(files) - 5} more files")
                            print(f"      File extensions: {', '.join(extensions)}")
                    else:
                        print(f"    {subitem}")
            else:
                print(f"  {item}")

        # Count files by type
        h5_files = glob(os.path.join(events_path, '**', '*.h5'), recursive=True)
        dat_files = glob(os.path.join(events_path, '**', '*.dat'), recursive=True)
        aedat_files = glob(os.path.join(events_path, '**', '*.aedat'), recursive=True)
        print(f"\nFound {len(h5_files)} H5 files, {len(dat_files)} DAT files, {len(aedat_files)} AEDAT files")

    # Explore thermal/Thermal Image Dataset directory
    thermal_path = os.path.join(base_path, 'thermal', 'Thermal Image Dataset')
    if os.path.exists(thermal_path):
        print("\n=== THERMAL/Thermal Image Dataset DIRECTORY ===")
        print("Contents of thermal/Thermal Image Dataset directory:")
        for item in os.listdir(thermal_path):
            item_path = os.path.join(thermal_path, item)
            if os.path.isdir(item_path):
                print(f"  {item}/")
                # Check subdirectories
                for subitem in os.listdir(item_path):
                    subitem_path = os.path.join(item_path, subitem)
                    if os.path.isdir(subitem_path):
                        print(f"    {subitem}/")
                        # Check for files in the subdirectory
                        files = os.listdir(subitem_path)
                        image_files = [f for f in files if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
                        print(f"      Found {len(files)} files ({len(image_files)} images)")
                        if image_files:
                            # Show first 5 image files
                            for file in image_files[:5]:
                                print(f"      {file}")
                            if len(image_files) > 5:
                                print(f"      ... and {len(image_files) - 5} more images")
                    else:
                        if subitem.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                            print(f"    {subitem} (image)")
                        else:
                            print(f"    {subitem}")
            else:
                if item.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                    print(f"  {item} (image)")
                else:
                    print(f"  {item}")

        # Count image files
        image_files = glob(os.path.join(thermal_path, '**', '*.*'), recursive=True)
        image_files = [f for f in image_files if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
        print(f"\nFound {len(image_files)} image files in thermal directory")

# Run the exploration
base_path = "data/raw"
explore_specific_directories(base_path)

Exploring specific directories...

=== AUDIO DIRECTORY ===
Contents of audio directory:
  utils2.py
  utils.py
  bc_utils.py
  esc50.csv
  audio/
    audio/
      5-201194-A-38.wav
      3-163727-A-3.wav
      1-23094-B-15.wav
      2-117615-B-48.wav
      5-260875-A-35.wav
      5-243448-A-14.wav
      4-167077-C-20.wav
      1-18755-A-4.wav
      3-159346-B-36.wav
      5-181977-A-35.wav
      2-73027-A-10.wav
      1-115521-A-19.wav
      4-152995-A-24.wav
      1-23222-A-19.wav
      3-110536-A-26.wav
      3-118069-B-27.wav
      1-68670-A-34.wav
      5-198891-C-8.wav
      2-122820-B-36.wav
      4-185575-C-20.wav
      2-82274-B-5.wav
      2-72547-C-14.wav
      5-235507-A-44.wav
      2-82274-A-5.wav
      1-20545-A-28.wav
      5-260011-A-34.wav
      2-87780-A-33.wav
      3-102583-B-49.wav
      3-160119-A-15.wav
      2-84943-A-18.wav
      4-173865-B-9.wav
      1-91359-A-11.wav
      4-188191-A-29.wav
      2-85434-A-27.wav
      1-30709-B-23.wav
      44100/
      1-54

In [104]:
import pandas as pd
import os
from glob import glob
import random

def create_balanced_metadata_csv(data_root, output_path, max_samples=None):
    """
    Create unified metadata CSV for hybrid ADAS training using NFDD dataset.
    Combines neuromorphic events (NFDD), audio, and thermal datasets with proper labels.
    """
    print("Creating metadata with improved thermal class detection...")

    # --- EVENT FILES (NFDD) ---
    event_files_fall = glob(os.path.join(data_root, 'events', 'NFDD', '**', 'Fall', '*.h5'), recursive=True)
    event_files_sit = glob(os.path.join(data_root, 'events', 'NFDD', '**', 'Sit', '*.h5'), recursive=True)
    event_files_walk = glob(os.path.join(data_root, 'events', 'NFDD', '**', 'Walk', '*.h5'), recursive=True)

    event_files = event_files_fall + event_files_sit + event_files_walk
    print(f"Found {len(event_files)} event files (NFDD)")
    print(f"Breakdown: {len(event_files_fall)} Fall, {len(event_files_sit)} Sit, {len(event_files_walk)} Walk")

    # --- AUDIO FILES ---
    audio_files = glob(os.path.join(data_root, 'audio', '**', '*.wav'), recursive=True)
    print(f"Found {len(audio_files)} audio files")

    # Load and process audio metadata
    audio_csv_files = glob(os.path.join(data_root, 'audio', '**', '*.csv'), recursive=True)
    audio_metadata_map = {}

    if audio_csv_files:
        print(f"Found audio CSV files: {[os.path.basename(f) for f in audio_csv_files]}")
        try:
            audio_df = pd.read_csv(audio_csv_files[0])
            print(f"Audio CSV columns: {list(audio_df.columns)}")

            # Create a mapping from filename to category
            for _, row in audio_df.iterrows():
                filename = row['filename']
                category = row['category']
                audio_metadata_map[filename] = category

            print(f"Created audio metadata map with {len(audio_metadata_map)} entries")
        except Exception as e:
            print(f"Error reading audio CSV: {e}")

    # --- THERMAL FILES ---
    # Get all thermal image files with improved class detection
    thermal_files = []
    thermal_classes = {}

    # Define the thermal directories to search
    thermal_dirs = [
        os.path.join(data_root, 'thermal', 'Thermal Image Dataset', 'FLIR', 'Train'),
        os.path.join(data_root, 'thermal', 'Thermal Image Dataset', 'FLIR', 'Test'),
        os.path.join(data_root, 'thermal', 'Thermal Image Dataset', 'SeekThermal', 'Train'),
        os.path.join(data_root, 'thermal', 'Thermal Image Dataset', 'SeekThermal', 'Test')
    ]

    for thermal_dir in thermal_dirs:
        if os.path.exists(thermal_dir):
            # Find all subdirectories (these should be the class names)
            for class_name in os.listdir(thermal_dir):
                class_path = os.path.join(thermal_dir, class_name)
                if os.path.isdir(class_path):
                    # Find all image files in this class directory
                    class_files = glob(os.path.join(class_path, '**', '*.*'), recursive=True)
                    class_files = [f for f in class_files if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]

                    # Add to our lists
                    thermal_files.extend(class_files)
                    for file_path in class_files:
                        thermal_classes[file_path] = class_name.lower()

    print(f"Found {len(thermal_files)} thermal image files")

    # Print the detected thermal classes
    unique_thermal_classes = set(thermal_classes.values())
    print(f"Detected thermal classes: {unique_thermal_classes}")

    # --- LIMIT SAMPLES ---
    if max_samples is None:
        max_samples = min(len(event_files), len(audio_files), len(thermal_files))
    else:
        max_samples = min(max_samples, len(event_files), len(audio_files), len(thermal_files))
    print(f"Creating metadata for {max_samples} samples")

    # --- IMPROVED CLASS MAPPING ---
    class_mapping = {
        # Event classes (NFDD)
        'fall': 0,      # Anomaly (falling)
        'sit': 1,       # Normal activity
        'walk': 1,      # Normal activity

        # Thermal classes
        'man': 2,
        'woman': 2,
        'child': 2,
        'person': 2,
        'human': 2,
        'car': 3,
        'cat': 4,
        'dog': 4,
        'animal': 4,

        # Audio categories - more comprehensive mapping
        'siren': 0,           # Emergency sound (anomaly)
        'car_horn': 3,        # Vehicle sound
        'engine': 3,          # Vehicle sound
        'train': 3,           # Vehicle sound
        'helicopter': 3,      # Vehicle sound
        'airplane': 3,        # Vehicle sound
        'dog': 4,             # Animal sound
        'cat': 4,             # Animal sound
        'cow': 4,             # Animal sound
        'pig': 4,             # Animal sound
        'sheep': 4,           # Animal sound
        'chirping_birds': 4,  # Animal sound
        'crickets': 4,        # Animal sound
        'frog': 4,            # Animal sound
        'insects': 4,         # Animal sound
        'crow': 4,            # Animal sound
        'rooster': 4,         # Animal sound
        'hen': 4,             # Animal sound
        'rain': 5,            # Weather condition
        'thunderstorm': 5,    # Weather condition
        'wind': 5,            # Weather condition
        'sea_waves': 5,       # Weather condition
        'water_drops': 5,     # Weather condition
        'pouring_water': 5,   # Weather condition
    }

    # Default class for unmapped audio categories
    audio_default_class = 1  # Normal

    metadata = []

    for i in range(max_samples):
        # Cycle through files
        event_path = event_files[i % len(event_files)]
        audio_path = audio_files[i % len(audio_files)]
        thermal_path = thermal_files[i % len(thermal_files)]

        # --- EXTRACT CLASSES ---
        # For NFDD, the class is in the parent directory name
        event_class = os.path.basename(os.path.dirname(event_path)).lower()

        # For thermal, use the class we extracted earlier
        thermal_class = thermal_classes.get(thermal_path, 'unknown')

        # For audio, try to match the filename with the CSV metadata
        audio_filename = os.path.basename(audio_path)
        audio_category = 'unknown'

        # Try exact match first
        if audio_filename in audio_metadata_map:
            audio_category = audio_metadata_map[audio_filename].lower()
        else:
            # Try partial match
            for csv_filename, category in audio_metadata_map.items():
                if csv_filename in audio_filename or audio_filename in csv_filename:
                    audio_category = category.lower()
                    break

        # --- DETERMINE LABEL AND MODALITY ---
        modalities = []

        # Always include event modality
        if event_class in class_mapping:
            modalities.append(('event', class_mapping[event_class]))

        # Include audio modality if mapped, otherwise use default
        if audio_category in class_mapping:
            modalities.append(('audio', class_mapping[audio_category]))
        else:
            modalities.append(('audio', audio_default_class))

        # Include thermal modality if mapped
        if thermal_class in class_mapping:
            modalities.append(('thermal', class_mapping[thermal_class]))

        # Randomly choose a modality to determine the label
        if modalities:
            source_modality, label = random.choice(modalities)
        else:
            label = 1  # default normal
            source_modality = 'unknown'

        metadata.append({
            'sample_id': f'sample_{i:04d}',
            'event_path': os.path.relpath(event_path, data_root),
            'audio_path': os.path.relpath(audio_path, data_root),
            'thermal_path': os.path.relpath(thermal_path, data_root),
            'label': label,
            'event_class': event_class,
            'audio_category': audio_category,
            'thermal_class': thermal_class,
            'modality_source': source_modality
        })

    # --- SAVE METADATA CSV ---
    df = pd.DataFrame(metadata)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    df.to_csv(output_path, index=False)
    print(f"\nBalanced metadata CSV created at {output_path} with {len(df)} samples")

    # --- LABEL DISTRIBUTION ---
    print("\nLabel distribution:")
    label_counts = df['label'].value_counts().sort_index()
    for lbl, count in label_counts.items():
        print(f"  Class {lbl}: {count} samples ({count/len(df)*100:.1f}%)")

    # --- EVENT CLASS DISTRIBUTION ---
    print("\nEvent class distribution:")
    event_class_counts = df['event_class'].value_counts()
    for cls, count in event_class_counts.items():
        print(f"  {cls}: {count} samples")

    # --- AUDIO CATEGORY DISTRIBUTION ---
    print("\nAudio category distribution:")
    audio_category_counts = df['audio_category'].value_counts()
    for cls, count in audio_category_counts.items():
        print(f"  {cls}: {count} samples")

    # --- THERMAL CLASS DISTRIBUTION ---
    print("\nThermal class distribution:")
    thermal_class_counts = df['thermal_class'].value_counts()
    for cls, count in thermal_class_counts.items():
        print(f"  {cls}: {count} samples")

    return df

# --- USAGE ---
data_root = "data/raw"
metadata_path = "data/processed/metadata_balanced.csv"

# Create the metadata
metadata_df = create_balanced_metadata_csv(data_root, metadata_path, max_samples=2000)

print("\nFirst few rows of balanced metadata:")
print(metadata_df.head())

Creating metadata with improved thermal class detection...
Found 1200 event files (NFDD)
Breakdown: 400 Fall, 400 Sit, 400 Walk
Found 6000 audio files
Found audio CSV files: ['esc50.csv']
Audio CSV columns: ['filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take']
Created audio metadata map with 2000 entries
Found 6843 thermal image files
Detected thermal classes: {'car', 'man', 'cat'}
Creating metadata for 1200 samples

Balanced metadata CSV created at data/processed/metadata_balanced.csv with 1200 samples

Label distribution:
  Class 0: 135 samples (11.2%)
  Class 1: 484 samples (40.3%)
  Class 2: 120 samples (10.0%)
  Class 3: 192 samples (16.0%)
  Class 4: 213 samples (17.8%)
  Class 5: 56 samples (4.7%)

Event class distribution:
  fall: 400 samples
  sit: 400 samples
  walk: 400 samples

Audio category distribution:
  engine: 30 samples
  washing_machine: 29 samples
  thunderstorm: 28 samples
  cow: 28 samples
  cat: 28 samples
  pouring_water: 28 samples
  cracklin

In [108]:
%%writefile utils/neuromorphic_loader.py
import h5py
import numpy as np
import torch
import os

class EventDataLoader:
    def __init__(self, time_window=50.0, height=260, width=346):
        """
        Neuromorphic event data loader
        :param time_window: Time window in milliseconds
        """
        self.time_window = time_window * 1000  # Convert to µs
        self.height = height
        self.width = width

    def load_events(self, file_path):
        """
        Load events from file, supporting .h5, .dat, .aedat
        """
        _, ext = os.path.splitext(file_path)
        ext = ext.lower()

        try:
            if ext in ['.h5', '.hdf5']:
                return self._load_events_h5(file_path)
            elif ext in ['.dat', '.txt', '.csv']:
                return self._load_events_dat(file_path)
            elif ext in ['.aedat']:
                return self._load_events_aedat(file_path)
            else:
                return self._load_events_autodetect(file_path)
        except Exception as e:
            print(f"Error loading events from {file_path}: {e}")
            return self._create_dummy_events()

    def _load_events_h5(self, file_path):
        """Load events from HDF5 file"""
        with h5py.File(file_path, 'r') as f:
            return {
                't': np.array(f['events/t']),
                'x': np.array(f['events/x']),
                'y': np.array(f['events/y']),
                'p': np.array(f['events/p'])
            }

    def _load_events_dat(self, file_path):
        """Load N-CARS binary .dat files"""
        # First try binary
        try:
            data = np.fromfile(file_path, dtype=np.int32)
            if len(data) % 4 != 0:
                raise ValueError("Binary .dat file malformed")
            events = {
                't': data[0::4],
                'x': data[1::4],
                'y': data[2::4],
                'p': data[3::4]
            }
            return events
        except:
            # Fallback to text format
            with open(file_path, 'r') as f:
                lines = f.readlines()
            data_start = 0
            for i, line in enumerate(lines):
                if not line.startswith('%'):
                    data_start = i
                    break
            t, x, y, p = [], [], [], []
            for line in lines[data_start:]:
                parts = line.strip().split()
                if len(parts) >= 4:
                    t.append(float(parts[0]))
                    x.append(int(parts[1]))
                    y.append(int(parts[2]))
                    p.append(int(parts[3]))
            return {
                't': np.array(t),
                'x': np.array(x),
                'y': np.array(y),
                'p': np.array(p)
            }

    def _load_events_aedat(self, file_path):
        """Load events from AEDAT file (simplified)"""
        with open(file_path, 'rb') as f:
            data = np.fromfile(f, dtype=np.uint32)
        t = data[0::4]
        x = data[1::4]
        y = data[2::4]
        p = data[3::4]
        return {'t': t, 'x': x, 'y': y, 'p': p}

    def _load_events_autodetect(self, file_path):
        """Try H5 -> DAT -> AEDAT"""
        for loader in [self._load_events_h5, self._load_events_dat, self._load_events_aedat]:
            try:
                return loader(file_path)
            except:
                continue
        print(f"Could not determine format of {file_path}, returning dummy events")
        return self._create_dummy_events()

    def _create_dummy_events(self):
        """Create dummy events for debugging"""
        num_events = 1000
        return {
            't': np.linspace(0, 1000000, num_events),
            'x': np.random.randint(0, self.width, num_events),
            'y': np.random.randint(0, self.height, num_events),
            'p': np.random.randint(0, 2, num_events)
        }

    def events_to_tensor(self, events):
        """Convert events to tensor representation"""
        min_t = np.min(events['t'])
        max_t = np.max(events['t'])
        if min_t == max_t:
            num_frames = 1
        else:
            num_frames = max(1, int(np.ceil((max_t - min_t) / self.time_window)))

        tensor = torch.zeros((num_frames, 2, self.height, self.width))

        for frame_idx in range(num_frames):
            start_t = min_t + frame_idx * self.time_window
            end_t = start_t + self.time_window
            mask = (events['t'] >= start_t) & (events['t'] < end_t)
            frame_events = {k: v[mask] for k, v in events.items()}

            for t, x, y, p in zip(frame_events['t'], frame_events['x'], frame_events['y'], frame_events['p']):
                channel = 0 if p > 0 else 1
                if 0 <= y < self.height and 0 <= x < self.width:
                    tensor[frame_idx, channel, y, x] += 1

        return tensor

    def normalize_events(self, tensor):
        """Normalize event counts per frame"""
        frame_sums = tensor.sum(dim=(1, 2, 3), keepdim=True) + 1e-8
        return tensor / frame_sums


Overwriting utils/neuromorphic_loader.py


In [109]:
%%writefile training/train_utils.py
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
import cv2

# Import the necessary processors
from utils.neuromorphic_loader import EventDataLoader
from utils.audio_processor import AudioProcessor
from utils.thermal_processor import ThermalProcessor

class MultiModalDataset(Dataset):
    def __init__(self, metadata_path, data_root, transform=None):
        """
        Dataset class for multimodal data (events, audio, thermal)
        """
        # Read metadata with error handling
        if isinstance(metadata_path, str):
            if not os.path.exists(metadata_path):
                raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
            self.metadata = pd.read_csv(metadata_path)
        elif isinstance(metadata_path, pd.DataFrame):
            self.metadata = metadata_path
        else:
            raise ValueError("metadata_path must be a string path or DataFrame")

        self.data_root = data_root
        self.transform = transform

        # Initialize processors
        self.event_loader = EventDataLoader()
        self.audio_processor = AudioProcessor()
        self.thermal_processor = ThermalProcessor()

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # Get sample metadata
        sample = self.metadata.iloc[idx]

        try:
            # Load event data
            event_path = os.path.join(self.data_root, sample['event_path'])
            events = self.event_loader.load_events(event_path)
            event_data = self.event_loader.events_to_tensor(events)

            # Load audio data
            audio_path = os.path.join(self.data_root, sample['audio_path'])
            audio = self.audio_processor.load_audio(audio_path)
            audio_data = self.audio_processor.extract_features(audio)

            # Load thermal data
            thermal_path = os.path.join(self.data_root, sample['thermal_path'])
            thermal_image = cv2.imread(thermal_path, cv2.IMREAD_GRAYSCALE)
            if thermal_image is None:
                raise ValueError(f"Could not load thermal image: {thermal_path}")
            thermal_data = self.thermal_processor.preprocess(thermal_image)

            # Get label
            label = sample['label']

            return {
                'events': event_data,
                'audio': audio_data,
                'thermal': thermal_data,
                'label': torch.tensor(label, dtype=torch.long)
            }
        except Exception as e:
            print(f"Error loading sample {idx} ({sample['event_path']}): {e}")
            # Return a dummy sample if there's an error
            return self._create_dummy_sample(label=sample['label'])

    def _create_dummy_sample(self, label=0):
        """Create a dummy sample for debugging"""
        return {
            'events': torch.randn(10, 2, 260, 346),
            'audio': torch.randn(64, 64),
            'thermal': torch.randn(1, 224, 224),
            'label': torch.tensor(label, dtype=torch.long)
        }

def create_data_loaders(metadata_path, data_root, batch_size=8, validation_split=0.2):
    """Create training and validation data loaders"""
    # Create dataset
    dataset = MultiModalDataset(metadata_path, data_root)

    # Split into train and validation
    train_size = int((1 - validation_split) * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    return train_loader, val_loader

Overwriting training/train_utils.py


In [120]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import os
import sys

# Add the project root to the path - This line caused the error in Colab notebooks
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.hybrid_fusion import HybridFusionModel
from training.train_utils import MultiModalDataset, create_data_loaders

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25, device='cuda'):
    """
    Training function for the hybrid fusion model
    """
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'confidence': []
    }

    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0
        running_conf = 0.0

        # Iterate over data
        for batch_idx, batch in enumerate(tqdm(train_loader, desc='Training')):
            # Move data to device
            events = batch['events'].to(device)
            audio = batch['audio'].to(device)
            thermal = batch['thermal'].to(device)
            labels = batch['label'].to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs, confidence = model(events, audio, thermal)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Statistics
            running_loss += loss.item() * events.size(0)
            running_corrects += torch.sum(preds == labels.data)
            running_conf += torch.mean(confidence).item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects.double() / len(train_loader.dataset)
        epoch_conf = running_conf / len(train_loader)

        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc.cpu().numpy())
        history['confidence'].append(epoch_conf)

        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} Conf: {epoch_conf:.4f}')

        # Validation phase
        model.eval()
        running_loss = 0.0
        running_corrects = 0

        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(val_loader, desc='Validation')):
                events = batch['events'].to(device)
                audio = batch['audio'].to(device)
                thermal = batch['thermal'].to(device)
                labels = batch['label'].to(device)

                outputs, confidence = model(events, audio, thermal)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * events.size(0)
                running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(val_loader.dataset)
        epoch_acc = running_corrects.double() / len(val_loader.dataset)

        history['val_loss'].append(epoch_loss)
        history['val_acc'].append(epoch_acc.cpu().numpy())

        print(f'Val Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Step the scheduler
        scheduler.step()

        # Save best model
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), 'best_model.pth')
            print(f'New best model saved with accuracy: {best_acc:.4f}')

        print()

    print(f'Best val Acc: {best_acc:4f}')
    return history

def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Paths
    data_root = "data/raw"
    metadata_path = "data/processed/metadata_balanced.csv"

    # Create data loaders
    print("Creating data loaders...")
    train_loader, val_loader = create_data_loaders(
        metadata_path, data_root, batch_size=4, validation_split=0.2
    )

    # Test a single batch
    try:
        batch = next(iter(train_loader))
        print("Batch loaded successfully")
        print(f"Events shape: {batch['events'].shape}")
        print(f"Audio shape: {batch['audio'].shape}")
        print(f"Thermal shape: {batch['thermal'].shape}")
        print(f"Labels: {batch['label']}")
    except Exception as e:
        print(f"Error loading batch: {e}")
        return

    # Initialize model
    model = HybridFusionModel(num_classes=3, thermal_feat_dim=128).to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Train the model
    print("Starting training...")
    history = train_model(
        model, train_loader, val_loader,
        criterion, optimizer, scheduler,
        num_epochs=25, device=device
    )

    print("Training completed!")

    # Save the final model
    torch.save(model.state_dict(), 'final_model.pth')
    print("Model saved as final_model.pth")

if __name__ == "__main__":
    main()

Using device: cuda
Creating data loaders...


TypeError: argument of type 'method' is not iterable

In [None]:
# --- MODEL EXPORT CELL ---
# Export model for local deployment
import json

# Save model configuration
config = {
    "model_type": "HybridFusionModel",
    "num_classes": 3,
    "thermal_feat_dim": 128,
    "audio_params": {
        "sample_rate": 16000,
        "n_fft": 1024,
        "hop_length": 512,
        "n_mels": 64
    },
    "event_params": {
        "time_window": 50.0,
        "height": 260,
        "width": 346
    },
    "thermal_params": {
        "input_size": [224, 224],
        "normalize": True
    }
}

with open('trained_models/model_config.json', 'w') as f:
    json.dump(config, f, indent=4)

# Convert to ONNX format (optional)
dummy_events = torch.randn(1, 10, 2, 260, 346).to(device)
dummy_audio = torch.randn(1, 64, 64).to(device)  # Example mel spectrogram shape
dummy_thermal = torch.randn(1, 1, 224, 224).to(device)

torch.onnx.export(
    model,
    (dummy_events, dummy_audio, dummy_thermal),
    "trained_models/model.onnx",
    input_names=["events", "audio", "thermal"],
    output_names=["classification", "confidence"],
    dynamic_axes={
        "events": {0: "batch_size"},
        "audio": {0: "batch_size"},
        "thermal": {0: "batch_size"},
        "classification": {0: "batch_size"},
        "confidence": {0: "batch_size"}
    }
)

# Create a zip file with all necessary files for local deployment
!zip -r adas_deployment.zip trained_models/ utils/ models/ config.py

# Download the deployment package
from google.colab import files
files.download('adas_deployment.zip')

print("Model exported and deployment package created!")

In [15]:
!mkdir -p training

In [119]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
import cv2

# Import the necessary processors
from utils.neuromorphic_loader import EventDataLoader
from utils.audio_processor import AudioProcessor
from utils.thermal_processor import ThermalProcessor

class MultiModalDataset(Dataset):
    def __init__(self, metadata_path, data_root, transform=None):
        """
        Dataset class for multimodal data (events, audio, thermal)
        """
        # Read metadata with error handling
        if isinstance(metadata_path, str):
            if not os.path.exists(metadata_path):
                raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
            with open(metadata_path, 'r') as f:
                self.metadata = pd.read_csv(f)
        elif isinstance(metadata_path, pd.DataFrame):
            self.metadata = metadata_path
        else:
            raise ValueError("metadata_path must be a string path or DataFrame")

        self.data_root = data_root
        self.transform = transform

        # Initialize processors
        self.event_loader = EventDataLoader()
        self.audio_processor = AudioProcessor()
        self.thermal_processor = ThermalProcessor()

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        # Get sample metadata
        sample = self.metadata.iloc[idx]

        try:
            # Load event data
            event_path = os.path.join(self.data_root, sample['event_path'])
            events = self.event_loader.load_events(event_path)
            event_data = self.event_loader.events_to_tensor(events)

            # Load audio data
            audio_path = os.path.join(self.data_root, sample['audio_path'])
            audio = self.audio_processor.load_audio(audio_path)
            audio_data = self.audio_processor.extract_features(audio)

            # Load thermal data
            thermal_path = os.path.join(self.data_root, sample['thermal_path'])
            thermal_image = cv2.imread(thermal_path, cv2.IMREAD_GRAYSCALE)
            if thermal_image is None:
                raise ValueError(f"Could not load thermal image: {thermal_path}")
            thermal_data = self.thermal_processor.preprocess(thermal_image)

            # Get label
            label = sample['label']

            return {
                'events': event_data,
                'audio': audio_data,
                'thermal': thermal_data,
                'label': torch.tensor(label, dtype=torch.long)
            }
        except Exception as e:
            print(f"Error loading sample {idx} ({sample['event_path']}): {e}")
            # Return a dummy sample if there's an error
            return self._create_dummy_sample(label=sample['label'])

    def _create_dummy_sample(self, label=0):
        """Create a dummy sample for debugging"""
        return {
            'events': torch.randn(10, 2, 260, 346),
            'audio': torch.randn(64, 64),
            'thermal': torch.randn(1, 224, 224),
            'label': torch.tensor(label, dtype=torch.long)
        }

def create_data_loaders(metadata_path, data_root, batch_size=8, validation_split=0.2):
    """Create training and validation data loaders"""
    # Create dataset
    dataset = MultiModalDataset(metadata_path, data_root)

    # Split into train and validation
    train_size = int((1 - validation_split) * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    return train_loader, val_loader

In [115]:
%%writefile utils/audio_processor.py
import librosa
import numpy as np
import torch

class AudioProcessor:
    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=512, n_mels=64):
        self.sr = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels

    def load_audio(self, file_path):
        y, _ = librosa.load(file_path, sr=self.sr)
        return y

    def extract_features(self, audio):
        """Extract mel spectrogram features"""
        # Handle short audio clips
        if len(audio) < self.n_fft:
            audio = np.pad(audio, (0, self.n_fft - len(audio)))

        S = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels
        )
        log_S = librosa.power_to_db(S, ref=np.max)
        return torch.tensor(log_S, dtype=torch.float32)

    def augment_audio(self, audio, noise_level=0.005):
        """Add realistic noise augmentation"""
        noise = np.random.normal(0, noise_level, len(audio))
        return audio + noise

Overwriting utils/audio_processor.py


In [42]:
import sys
sys.path.append('/content/drive/MyDrive/ADAS_Project')
sys.path.append('/content/utils')

In [114]:
%%writefile utils/thermal_processor.py
import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms

class ThermalProcessor:
    def __init__(self, input_size=(224, 224), normalize=True):
        """
        Initialize thermal image processor

        Args:
            input_size (tuple): Target size for resizing images (height, width)
            normalize (bool): Whether to normalize thermal values
        """
        self.input_size = input_size
        self.normalize = normalize
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(input_size),
            transforms.ToTensor(),
        ])

    def preprocess(self, thermal_image):
        """
        Preprocess a single thermal image

        Args:
            thermal_image (numpy.ndarray): Input thermal image (single channel)

        Returns:
            torch.Tensor: Preprocessed thermal image tensor
        """
        # Convert to float32 if needed
        if thermal_image.dtype != np.float32:
            thermal_image = thermal_image.astype(np.float32)

        # Normalize to [0, 1] if not already
        if thermal_image.max() > 1.0:
            thermal_image = thermal_image / thermal_image.max()

        # Apply transformations
        tensor = self.transform(thermal_image)

        # Add batch dimension if needed
        if len(tensor.shape) == 3:
            tensor = tensor.unsqueeze(0)

        return tensor

    def add_thermal_noise(self, thermal_image, noise_level=0.1):
        """Add realistic thermal noise to the image"""
        noise = np.random.normal(0, noise_level, thermal_image.shape).astype(np.float32)
        return np.clip(thermal_image + noise, 0, 1)

    def adjust_thermal_contrast(self, thermal_image, alpha=1.0, beta=0.0):
        """Adjust contrast of thermal image"""
        return np.clip(alpha * thermal_image + beta, 0, 1)


class ThermalEncoder(nn.Module):
    def __init__(self, input_channels=1, base_channels=32, output_dim=128):
        """
        CNN-based thermal feature extractor
        """
        super().__init__()
        self.features = nn.Sequential(
            # Initial conv block
            nn.Conv2d(input_channels, base_channels, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # Residual blocks
            self._make_residual_block(base_channels, base_channels * 2, 2),
            self._make_residual_block(base_channels * 2, base_channels * 4, 2),
            self._make_residual_block(base_channels * 4, base_channels * 8, 2),

            # Final pooling
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),

            # Projection head
            nn.Linear(base_channels * 8, output_dim)
        )

    def _make_residual_block(self, in_channels, out_channels, stride):
        """Create a residual block with skip connection"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.features(x)

Overwriting utils/thermal_processor.py
