In [1]:
# Environment Setup and Dependencies
import torch
import torch.nn as nn
import torch.optim as optim
import timm
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
from torchvision import transforms
from torchvision import models
import open3d as o3d
import sofar
import sys, os, random
sys.path.append('/autofs/thau04b/hghallab/comp/Final model')
from utils_d import SonicomDatabase, DataConfiguration
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import wandb
import random


from HRTFNet_onefreq import MultiViewHRTFPredictionModel

  warn(


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# API_KEY = "WANDB API"
# wandb.login(key=API_KEY)

In [3]:
# Implement CFG class to handle configurations
class CFG:
    seed = 42
    sonicom_root = '/autofs/thau04b/hghallab/comp/Huawei/TechArena20241016/data/'
    batch_size = 4
    image_size = [1024, 1024]
    downsampled_size = [256, 256]
    task = 0
    num_workers = 4
    num_images = 19
    num_epochs = 500000000000
    learning_rate = 1e-3
    output_size = 256  # Adjust based on HRTF size
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model_save_path = './best_onefreq_centered_model_unfrozen.pth'
    # Wandb configurations
    log_to_wandb = True
    project_name = '2D_to_HRTF'
    run_name = 'onefreq'
    
    precompute_sht = False
    sht_order = 7
    num_coeffs = (sht_order + 1) ** 2
    
    comments: """MultiViewHRTFPredictionModel
    A neural network model for predicting single-frequency Head-Related Transfer Functions (HRTFs) from multi-view point clouds of ears.
    Components:
        - PointNetFeatureExtractor: Extracts features from point clouds.
            Input: [batch_size, 3, num_points]
            Output: [batch_size, 1024, 1]
        - ViewTransformer: Pools features from multiple views using transformer attention mechanism.
            Input: [batch_size, num_views, 1024]
            Output: [batch_size, 1024]
        - Frequency Embedding: Encodes frequency information.
            Input: [batch_size] (frequency indices)
            Output: [batch_size, 16]
        - MLP Regressor: Generates HRTF predictions for a single frequency.
            Input: [batch_size, 2048 + 16] (combined ear features + frequency embedding)
            Output: [batch_size, 793 * 2 * 2] (positions * ears * mag/phase)
    Input:
        - point_clouds: Tensor [batch_size, 2, num_views, num_points, 3]
        - frequency: Tensor [batch_size] frequency indices
    Output:
        - HRTF predictions: [batch_size, 793, 2, 2] (positions, ears, magnitude/phase)
        
    The model processes both ears simultaneously and predicts magnitude and phase 
    for a specific frequency across all positions. The design allows efficient 
    feature extraction by processing point cloud features once and reusing them
    across frequency predictions."""

cfg = CFG()

In [4]:
wandb.init(
    project=cfg.project_name,
    config=cfg,
    name=cfg.run_name
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhouda-ghallab[0m ([33mhouda222[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
def set_seed(seed = 42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
set_seed(CFG.seed)

> SEEDING DONE


In [6]:
# Data Pipeline Implementation

# Create DataConfiguration
config = DataConfiguration(
    use_2d=True,
    use_3d=True,
    use_hrtf=True,
    use_3d_head=False,
    precompute_sht=cfg.precompute_sht
)

# Initialize dataset
sd = SonicomDatabase(root_dir=cfg.sonicom_root, config=config, sht_order=cfg.sht_order, training_data=True, task_id=cfg.task)

# Split the dataset into train and val sets
train_size = int(1 * len(sd))
val_size = len(sd) - train_size
train_dataset, val_dataset = random_split(sd, [train_size, val_size])

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

Number of subjects:  90
Total number of images:  3420


Computing SHT coefficients: 100%|██████████| 90/90 [00:03<00:00, 29.59it/s]

SHT coefficients loaded.





In [7]:
# Load the test set using the dataset class
test_sd = SonicomDatabase(root_dir=cfg.sonicom_root, config=config, sht_order=cfg.sht_order, training_data=False, task_id=0)
test_dataloader = DataLoader(test_sd, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

Number of subjects:  10
Total number of images:  380


Computing SHT coefficients:   0%|          | 0/10 [00:00<?, ?it/s]

Computing SHT coefficients: 100%|██████████| 10/10 [00:00<00:00, 31.36it/s]

SHT coefficients loaded.





In [8]:
_, _, pt, hrtf, _, _ = next(iter(test_dataloader))

In [9]:
pt.shape

torch.Size([4, 2, 19, 30000, 3])

In [10]:
# for patient in pt:
#     for ear in patient:
#         for view in ear:
#             print(view.shape)
#             pcd = o3d.geometry.PointCloud()
#             pcd.points = o3d.utility.Vector3dVector(view.cpu().numpy())
#             o3d.io.write_point_cloud(f"patient_.ply", pcd)
#     break

In [11]:
def downsize_image(imgs, final_shape=(512, 512)):
    """
    Downsizes the input image tensor to the specified final shape.
    
    Parameters:
    imgs (torch.Tensor): Input image tensor with shape [B, E, I, C, H, W]
    final_shape (tuple): Target shape (height, width)
    
    Returns:
    torch.Tensor: Resized tensor with shape [B, I, 2, final_H, final_W]
    """
    B, E, I, C, H, W = imgs.shape
    imgs = imgs.view(-1, H, W)  # Flatten to [B*E*I*C, H, W]
    
    # Add channel dimension for interpolate
    imgs = imgs.unsqueeze(1)  # [B*E*I*C, 1, H, W]
    
    # Resize
    imgs = F.interpolate(imgs, size=final_shape, mode='bilinear', align_corners=False)
    
    # Reshape back
    imgs = imgs.squeeze(1)  # Remove the temporary channel dim
    imgs = imgs.view(B, E, I, C, final_shape[0], final_shape[1])
    
    # Prepare final format: [B, I, 2, H, W]
    imgs = imgs.squeeze(1)  # Remove ear dimension since we process one ear at a time
    imgs = imgs.squeeze(2)  # Remove the extra channel dimension
    
    return imgs

def normalize_image(image):
    return image.float() / 255.0


def concatenate_images_and_depth_maps(images, depth_maps):
    return torch.cat([images, depth_maps], dim=3)  # Concatenate along channel dimension

def preprocess_data(images, depth_maps):
    images = normalize_image(images)
    inputs = torch.cat([images, depth_maps], dim=3) 
    return inputs


In [12]:
def transform_points(point_data):
    points = point_data[:, :, :3]  # Shape: [batch_size, num_points, 3]
    return points.permute(0, 2, 1).float()

In [13]:
def complex_hrtf(hrtf_mag_phase):
    # Split the real and imaginary parts
    hrtf_magnitude, hrtf_phase = torch.chunk(hrtf_mag_phase, 2, dim=-1)
    # Combine them to form the complex hrtf
    hrtf = hrtf_magnitude * torch.exp(1j * hrtf_phase)
    return hrtf

def hrtf2magnitudephase(hrtf):
    hrtf_magnitude = torch.abs(hrtf)
    hrtf_phase = torch.angle(hrtf)
    return torch.cat([hrtf_magnitude, hrtf_phase], dim=-1).float()

In [14]:
def reconstruct_hrtf_from_magnitude(magnitude):
    magnitude = magnitude.to(dtype=torch.float32)
    # Add epsilon to avoid log(0) and clamp negative values
    epsilon = 1e-8
    magnitude = magnitude.clamp(min=epsilon)
    log_magnitude = torch.log(magnitude)

    if torch.isnan(log_magnitude).any():
        print("NaNs detected after logarithm")
        return None  # Exit or handle appropriately

    # Compute FFT of log magnitude
    fft_log_magnitude = torch.fft.fft(log_magnitude)
    N = log_magnitude.shape[-1]
    h = torch.zeros(N, device=magnitude.device, dtype=fft_log_magnitude.dtype)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1:N // 2] = 2
    else:
        h[0] = 1
        h[1:(N + 1) // 2] = 2
    fft_log_magnitude *= h
    analytic_signal = torch.fft.ifft(fft_log_magnitude)
    phase = -analytic_signal.imag
    complex_hrtf = magnitude * torch.exp(1j * phase)

    return complex_hrtf

def reconstruct_complex_hrtf(hrtf_magnitude_phase):
    hrtf_magnitude, hrtf_phase = torch.chunk(hrtf_magnitude_phase, 2, dim=-1)
    hrtf = hrtf_magnitude * torch.exp(1j * hrtf_phase)
    return hrtf

In [15]:
def unconcatenate_point_clouds(combined_point_cloud):
    # Assuming the combined point cloud is concatenated along axis 1
    num_points = combined_point_cloud.shape[1] // 2
    left_point_cloud = combined_point_cloud[:, :num_points]
    right_point_cloud = combined_point_cloud[:, num_points:]
    return transform_points(left_point_cloud).to(cfg.device), transform_points(right_point_cloud).to(cfg.device)

In [16]:
# Verify that data was loaded successfully by visualizing some samples
%matplotlib inline

def visualize_data(imgs=None, depth_imgs=None, pt=None, true_hrtf=None, predicted_hrtf=None, plot_whole_batch=False, sampling_rate=48000):

    if imgs is not None:
        # imgs shape: [batch_size, num_ears, num_images_per_ear, channels, H, W]
        batch_size, num_ears, num_images_per_ear, channels, H, W = imgs.shape
        indices = range(batch_size) if plot_whole_batch else [0]
        for idx in indices:
            for ear in range(num_ears):
                fig, axes = plt.subplots(1, num_images_per_ear, figsize=(15, 5))
                for img_idx in range(num_images_per_ear):
                    img = imgs[idx, ear, img_idx].cpu().numpy()
                    if channels == 1:
                        img = img[0]  # Extract the single channel
                    else:
                        img = np.transpose(img, (1, 2, 0))  # Convert to HxWxC
                    axes[img_idx].imshow(img, cmap='gray' if channels == 1 else None)
                    axes[img_idx].axis('off')
                plt.suptitle(f'Sample {idx+1}, Ear {ear+1}')
                plt.show()

    if depth_imgs is not None:
        # depth_imgs shape: [batch_size, num_ears, num_images_per_ear, channels, H, W]
        batch_size, num_ears, num_images_per_ear, channels, H, W = depth_imgs.shape
        indices = range(batch_size) if plot_whole_batch else [0]
        for idx in indices:
            for ear in range(num_ears):
                fig, axes = plt.subplots(1, num_images_per_ear, figsize=(15, 5))
                for img_idx in range(num_images_per_ear):
                    img = depth_imgs[idx, ear, img_idx].cpu().numpy()
                    if channels == 1:
                        img = img[0]  # Extract the single channel
                    else:
                        img = np.transpose(img, (1, 2, 0))  # Convert to HxWxC
                    axes[img_idx].imshow(img, cmap='plasma')
                    axes[img_idx].axis('off')
                plt.suptitle(f'Depth Sample {idx+1}, Ear {ear+1}')
                plt.show()
    
    if pt is not None:
        # Split the point clouds for left and right ears
        pt1, pt2 = unconcatenate_point_clouds(pt)
        
        # Create a figure with two subplots side by side
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
        
        # Plot left ear
        point_cloud1 = pt1[0].permute(1, 0).detach().cpu().numpy()
        indices1 = np.random.choice(point_cloud1.shape[0], size=10000, replace=False)
        sampled_points1 = point_cloud1[indices1]
        ax1.scatter(sampled_points1[:, 0], sampled_points1[:, 1], sampled_points1[:, 2], s=1)
        ax1.view_init(elev=90, azim=180)
        ax1.set_title('Left Ear')
        
        # Plot right ear
        point_cloud2 = pt2[0].permute(1, 0).detach().cpu().numpy()
        indices2 = np.random.choice(point_cloud2.shape[0], size=10000, replace=False)
        sampled_points2 = point_cloud2[indices2]
        ax2.scatter(sampled_points2[:, 0], sampled_points2[:, 1], sampled_points2[:, 2], s=1)
        ax2.view_init(elev=90, azim=0)
        ax2.set_title('Right Ear')
        
        plt.tight_layout()
        plt.show()
    
    # HRTF Visualization
    if true_hrtf is not None or predicted_hrtf is not None:
        batch_size = (
            true_hrtf.shape[0] if true_hrtf is not None else predicted_hrtf.shape[0]
        )
        indices = range(batch_size) if plot_whole_batch else [0]

        # Iterate over batch indices
        for idx in indices:
            # Determine HRTF size
            if true_hrtf is not None:
                hrtf_size = true_hrtf.shape[-1] // 2
            else:
                hrtf_size = predicted_hrtf.shape[-1] // 2

            # Frequency axis
            freq_axis = np.linspace(0, sampling_rate / 2, hrtf_size)

            # Prepare figures for each ear
            for ear_index in range(2):  # 0 for left ear, 1 for right ear
                fig, ax = plt.subplots(figsize=(10, 6))

                if true_hrtf is not None:
                    hrtf_true = true_hrtf[idx].cpu().numpy()  # Shape: [positions, ears, total_bins]
                    # Split real and imaginary parts
                    hrtf_true_real = hrtf_true[:, ear_index, :hrtf_size]
                    hrtf_true_imag = hrtf_true[:, ear_index, hrtf_size:]
                    # Construct complex HRTF
                    hrtf_true_complex = hrtf_true_real + 1j * hrtf_true_imag
                    # Compute magnitude in dB
                    magnitude_true = 20 * np.log10(np.abs(hrtf_true_complex) + 1e-6)
                    # Average over positions
                    magnitude_true_mean = np.mean(magnitude_true, axis=0)  # Shape: [hrtf_size]

                    # Plot true HRTF
                    ax.plot(freq_axis, magnitude_true_mean, label='True HRTF')

                if predicted_hrtf is not None:
                    hrtf_pred = predicted_hrtf[idx].cpu().numpy()
                    hrtf_pred_real = hrtf_pred[:, ear_index, :hrtf_size]
                    hrtf_pred_imag = hrtf_pred[:, ear_index, hrtf_size:]
                    hrtf_pred_complex = hrtf_pred_real + 1j * hrtf_pred_imag
                    magnitude_pred = 20 * np.log10(np.abs(hrtf_pred_complex) + 1e-6)
                    # Average over positions
                    magnitude_pred_mean = np.mean(magnitude_pred, axis=0)

                    # Plot predicted HRTF
                    ax.plot(
                        freq_axis,
                        magnitude_pred_mean,
                        label='Predicted HRTF',
                        linestyle='--',
                    )

                ax.set_title(f'Sample {idx + 1} Ear {ear_index + 1} HRTF Magnitude Response')
                ax.set_xlabel('Frequency (Hz)')
                ax.set_ylabel('Magnitude (dB)')
                ax.legend()
                ax.grid(True)

                plt.tight_layout()
                plt.show()

In [17]:
# for i, (images, depth_maps, pt) in enumerate(train_dataloader):
#     print(f"Sample {i+1}:")
#     print(f"Images shape: {images.shape}")
#     print(f"Depth maps shape: {depth_maps.shape}")
#     print(f"Point cloud shape: {pt.shape}")
#     print("-" * 30)

In [18]:
# Initialize WandB
wandb.init(project=cfg.project_name, name=cfg.run_name, config=vars(cfg))

In [19]:
set_seed(cfg.seed)

> SEEDING DONE


In [20]:
model = MultiViewHRTFPredictionModel().to(cfg.device)



In [21]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

num_params = count_parameters(model)
print(f'The model has {num_params:,} trainable parameters.')

num_params = count_parameters(model.feat)
print(f'The feature extractor has {num_params:,} trainable parameters.')

num_params = count_parameters(model.regression)
print(f'The regressor has {num_params:,} trainable parameters.')



The model has 16,680,884 trainable parameters.
The feature extractor has 807,744 trainable parameters.
The regressor has 5,365,860 trainable parameters.


In [22]:
import gc
def manage_memory():
    gc.collect()
    torch.cuda.empty_cache()

In [23]:
### Redefine MSD as a loss function
class MeanSpectralDistortion_loss(nn.Module):
    def __init__(self):
        super().__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.avg_hrir = sofar.read_sofa("/autofs/thau04b/hghallab/comp/Huawei/TechArena20241016/data/Average_HRTFs.sofa", verbose=False)
        self.source_positions = self.avg_hrir.SourcePosition
        
        # Convert indices and weights to tensors and register as buffers
        elevation_index = torch.tensor(self._get_elevation_index(), device=device)
        weights = torch.tensor(self._get_weights(), device=device)
        
        # Register as buffers so they move to GPU with the model
        self.register_buffer('elevation_index', elevation_index)
        self.register_buffer('weights', weights)

    def forward(self, hrtf_ground_truth: torch.Tensor, hrtf_predicted: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hrtf_ground_truth: [batch_size, positions, channels, frequencies]
            hrtf_predicted: [batch_size, positions, channels, frequencies]
        """
        # Select elevations for batch
        gt_selected = torch.index_select(hrtf_ground_truth, 1, self.elevation_index)
        pred_selected = torch.index_select(hrtf_predicted, 1, self.elevation_index)
        
        # Compute magnitude difference
        magnitude_diff = gt_selected.abs() - pred_selected.abs()
        
        # Apply weights (broadcasting over batch, positions, channels)
        weighted_diff = self.weights.view(1, 1, 1, -1) * magnitude_diff
        
        # Compute mean squared error
        weighted_error = (weighted_diff ** 2).mean()
        
        # Convert to dB
        return 10 * torch.log10(weighted_error + 1e-8)
    
    def _get_weights(self):
        """
        This function load the weights which are used when you calculate the spectral distortion/ baseline predictions
        weights were calculated based on the paper "Looking for a relevant similarity criterion fo HRTF clustering: a comparative study - Rozenn Nicol".

        Returns:
               normalized_weights: torch.tensor

        """
        # Generate a list of frequencies up to 24 kHz
        frequencies_Hz = torch.linspace(0, 24000, 129)  # 129 points between 0 Hz and 24 kHz
        frequencies_kHz = frequencies_Hz / 1000
        inv_cb = 1 / (25 + 75 * (1 + 1.4 * frequencies_kHz**2) ** 0.69)  # inverse of delta (critical bandwidth)
        a0 = sum(inv_cb)
        normalized_weights = inv_cb / a0
        return normalized_weights

    def _get_elevation_index(self):
        """
        Helper function to get elevation indexes.
        Args:
            You can change the elevation range as required. We will use the elvation range between -30 to 30
            Returns:
             all index for the elevation range"""
        # this function gives the index of the directions for which you need to evaluate your results.

        azimuths = self.source_positions[:, 0]
        elevations = self.source_positions[:, 1]

        # Define the elevation range
        elevation_min = -30
        elevation_max = 30
        # Find the indices for the specific elevation range
        elevation_indices = np.where((elevations >= elevation_min) & (elevations <= elevation_max))[0]

        # Ensure that elevation_indices is a NumPy array of integers
        return np.array(elevation_indices, dtype=int)

In [24]:
checkpoint ="/mnt/thau04b/hghallab/comp/Final model/best_onefreq_centered_model.pth"
checkpoint = torch.load(checkpoint)
model.load_state_dict(checkpoint, strict=True)
for param in model.feat.parameters():
    param.requires_grad = True

for param in model.parameters():
    param.requires_grad = True

  checkpoint = torch.load(checkpoint)


In [25]:
criterion = MeanSpectralDistortion_loss()
mse = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=cfg.learning_rate)

  weights = torch.tensor(self._get_weights(), device=device)


In [26]:
# Initialize lists to store losses
training_losses = []
training_metrics = []
validation_losses = []
validation_metrics = []

In [27]:
# # Initialize a tensor to accumulate the sum of HRTFs
# hrtf_sum_real = torch.zeros((793, 2, 129), device=cfg.device)
# hrtf_sum_imag = torch.zeros((793, 2, 129), device=cfg.device)
# num_examples = 0

# # Iterate through the train loader
# for _, _, pts, hrtf in tqdm.tqdm(train_dataloader, desc="Calculating Mean HRTF"):
#     hrtf = hrtf.to(cfg.device)
#     hrtf_sum_real += hrtf.real.sum(dim=0)
#     hrtf_sum_imag += hrtf.imag.sum(dim=0)
#     num_examples += hrtf.size(0)

# # Calculate the mean HRTF
# mean_hrtf_real = hrtf_sum_real / num_examples
# mean_hrtf_imag = hrtf_sum_imag / num_examples

# mean_hrtf = mean_hrtf_real + 1j * mean_hrtf_imag
# mean_hrtf = mean_hrtf.unsqueeze(0)  # Add a batch dimension

# # Save the mean HRTF to a file
# mean_hrtf_path = os.path.join(cfg.sonicom_root, 'mean_hrtf.pt')
# torch.save(mean_hrtf, mean_hrtf_path)

# print("Mean HRTF calculated.")

In [28]:
mean_hrtf_path = os.path.join(cfg.sonicom_root, 'mean_hrtf.pt')
mean_hrtf = torch.load(mean_hrtf_path).to(CFG.device)
print("Mean HRTF loaded.")

Mean HRTF loaded.


  mean_hrtf = torch.load(mean_hrtf_path).to(CFG.device)


In [29]:
def train_one_epoch(epoch, model, dataloader, optimizer, criterion, metric_criterion, device):
    model.train()
    running_loss = 0.0
    running_metric = 0.0
    for _, _, pts, hrtf, _, _ in tqdm.tqdm(dataloader, desc=f"Training Epoch {epoch+1}"):
        # Process depth maps
        pts = pts.to(device)  # [batch_size, num_ears, num_views, H, W]
        hrtf = hrtf.to(device) # [batch_size, 793, 2, 129*2]
        batch_size = hrtf.shape[0]
        num_freqs = hrtf.shape[-1]
        
        hrtf_magnitude_phase = hrtf2magnitudephase(hrtf)
        
        all_predictions = []
        optimizer.zero_grad()

        output_magnitude_phase = model(pts)
        
        output_complex = complex_hrtf(output_magnitude_phase)
        
        uncentered_output_complex = output_complex + mean_hrtf.expand(batch_size, -1, -1, -1)
        uncentered_output_mag_phase = hrtf2magnitudephase(uncentered_output_complex)
        
        loss = criterion(uncentered_output_complex, hrtf)
        metric_loss = metric_criterion(uncentered_output_mag_phase, hrtf_magnitude_phase)

        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        torch.cuda.empty_cache()
        
        
        running_loss += loss.item()
        running_metric += metric_loss.item()
        
    avg_loss = running_loss / len(dataloader)
    avg_metric = running_metric / len(dataloader)
    training_losses.append(avg_loss)
    training_metrics.append(avg_metric)
    
    if cfg.log_to_wandb:
        wandb.log({'Train Loss': avg_loss, 'Epoch': epoch+1})
    
    return avg_loss, avg_metric

def evaluate(epoch, model, dataloader, criterion, metric_criterion, device):
    model.eval()
    running_loss = 0.0  # Track losses for each option
    running_metric = 0.0 # Track metrics for each option
    
    with torch.no_grad():
        for _, _, pts, hrtf, _, _ in tqdm.tqdm(dataloader, desc="Validation"):
            pts = pts.to(device)  # [batch_size, num_ears, num_views, H, W]
            hrtf = hrtf.to(device) # [batch_size, 793, 2, 129*2]
            batch_size = hrtf.shape[0]
            num_freqs = hrtf.shape[-1]
            
            hrtf_magnitude_phase = hrtf2magnitudephase(hrtf)
            
            output_magnitude_phase = model(pts)
            
            
            output_complex = complex_hrtf(output_magnitude_phase)
            
            uncentered_output_complex = output_complex + mean_hrtf.expand(batch_size, -1, -1, -1)
            uncentered_output_mag_phase = hrtf2magnitudephase(uncentered_output_complex)
            
            loss = criterion(uncentered_output_complex, hrtf)
            metric_loss = metric_criterion(uncentered_output_mag_phase, hrtf_magnitude_phase)

            torch.cuda.empty_cache()
            running_loss += loss.item()
            running_metric += metric_loss.item()
            
    
    # Calculate average losses for each option
    num_batches = len(dataloader)
    avg_loss = running_loss / num_batches
    avg_metric = running_metric / num_batches
    
    validation_losses.append(avg_loss)
    validation_metrics.append(avg_metric)
    
    if cfg.log_to_wandb:
        wandb.log({
            'Validation Loss': avg_loss,
            'Validation Metric': avg_metric,
            'Epoch': epoch+1
        })
    
    return avg_loss, avg_metric

In [None]:
best_val_loss = float('inf')

for epoch in range(cfg.num_epochs):
    manage_memory()
    train_loss, train_metric = train_one_epoch(epoch, model, train_dataloader, optimizer, criterion, mse, cfg.device)
    manage_memory()
    val_loss, val_metric = evaluate(epoch, model, test_dataloader, criterion, mse, cfg.device)
    # manage_memory()
    # test_loss, test_metric = evaluate(epoch, model, test_dataloader, criterion, mse, cfg.device)
    print(f'Epoch [{epoch+1}/{cfg.num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    # Log train and validation loss to wandb
    if cfg.log_to_wandb:
        wandb.log({
            "train/loss_MSD": train_loss,
            "val/loss_MSD": val_loss,
            "train/MSE": train_metric,
            "val/MSE": val_metric,
            "learning_rate": optimizer.param_groups[0]['lr'],
            "epoch": epoch
        })
    # Save the model if validation loss has decreased
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), cfg.model_save_path)
        print(f"Model saved at epoch {epoch+1} with validation loss {best_val_loss:.4f}")
        # Save checkpoint to wandb
        if cfg.log_to_wandb:
            wandb.save(cfg.model_save_path)


Training Epoch 1: 100%|██████████| 23/23 [01:27<00:00,  3.80s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.79s/it]


Epoch [1/500000000000], Train Loss: -62.4593, Val Loss: -62.8875
Model saved at epoch 1 with validation loss -62.8875


Training Epoch 2: 100%|██████████| 23/23 [01:24<00:00,  3.68s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.78s/it]


Epoch [2/500000000000], Train Loss: -62.5828, Val Loss: -63.0615
Model saved at epoch 2 with validation loss -63.0615


Training Epoch 3: 100%|██████████| 23/23 [01:24<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.76s/it]


Epoch [3/500000000000], Train Loss: -62.5305, Val Loss: -63.0599


Training Epoch 4: 100%|██████████| 23/23 [01:23<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [4/500000000000], Train Loss: -62.4784, Val Loss: -63.0619
Model saved at epoch 4 with validation loss -63.0619


Training Epoch 5: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.73s/it]


Epoch [5/500000000000], Train Loss: -62.5891, Val Loss: -63.0524


Training Epoch 6: 100%|██████████| 23/23 [01:24<00:00,  3.68s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.77s/it]


Epoch [6/500000000000], Train Loss: -62.6287, Val Loss: -63.0609


Training Epoch 7: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [7/500000000000], Train Loss: -62.5440, Val Loss: -63.0575


Training Epoch 8: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.76s/it]


Epoch [8/500000000000], Train Loss: -62.5604, Val Loss: -63.0568


Training Epoch 9: 100%|██████████| 23/23 [01:24<00:00,  3.67s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [9/500000000000], Train Loss: -62.5636, Val Loss: -63.0592


Training Epoch 10: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.73s/it]


Epoch [10/500000000000], Train Loss: -62.6208, Val Loss: -63.0588


Training Epoch 11: 100%|██████████| 23/23 [01:24<00:00,  3.67s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [11/500000000000], Train Loss: -62.5827, Val Loss: -63.0635
Model saved at epoch 11 with validation loss -63.0635


Training Epoch 12: 100%|██████████| 23/23 [01:23<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [12/500000000000], Train Loss: -62.7061, Val Loss: -63.0579


Training Epoch 13: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [13/500000000000], Train Loss: -62.6402, Val Loss: -63.0562


Training Epoch 14: 100%|██████████| 23/23 [01:24<00:00,  3.68s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.78s/it]


Epoch [14/500000000000], Train Loss: -62.6489, Val Loss: -63.0618


Training Epoch 15: 100%|██████████| 23/23 [01:24<00:00,  3.67s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.73s/it]


Epoch [15/500000000000], Train Loss: -62.6648, Val Loss: -63.0598


Training Epoch 16: 100%|██████████| 23/23 [01:24<00:00,  3.67s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.75s/it]


Epoch [16/500000000000], Train Loss: -62.6368, Val Loss: -63.0629


Training Epoch 17: 100%|██████████| 23/23 [01:24<00:00,  3.67s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.73s/it]


Epoch [17/500000000000], Train Loss: -62.5885, Val Loss: -63.0557


Training Epoch 18: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [18/500000000000], Train Loss: -62.6596, Val Loss: -63.0600


Training Epoch 19: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [19/500000000000], Train Loss: -62.6676, Val Loss: -63.0600


Training Epoch 20: 100%|██████████| 23/23 [01:24<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [20/500000000000], Train Loss: -62.6701, Val Loss: -63.0574


Training Epoch 21: 100%|██████████| 23/23 [01:24<00:00,  3.68s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.75s/it]


Epoch [21/500000000000], Train Loss: -62.6404, Val Loss: -63.0605


Training Epoch 22: 100%|██████████| 23/23 [01:24<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.73s/it]


Epoch [22/500000000000], Train Loss: -62.6818, Val Loss: -63.0600


Training Epoch 23: 100%|██████████| 23/23 [01:24<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.77s/it]


Epoch [23/500000000000], Train Loss: -62.6075, Val Loss: -63.0601


Training Epoch 24: 100%|██████████| 23/23 [01:24<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [24/500000000000], Train Loss: -62.6202, Val Loss: -63.0618


Training Epoch 25: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [25/500000000000], Train Loss: -62.6264, Val Loss: -63.0593


Training Epoch 26: 100%|██████████| 23/23 [01:24<00:00,  3.67s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [26/500000000000], Train Loss: -62.5711, Val Loss: -63.0608


Training Epoch 27: 100%|██████████| 23/23 [01:24<00:00,  3.69s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [27/500000000000], Train Loss: -62.6212, Val Loss: -63.0579


Training Epoch 28: 100%|██████████| 23/23 [01:23<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [28/500000000000], Train Loss: -62.6226, Val Loss: -63.0570


Training Epoch 29: 100%|██████████| 23/23 [01:23<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [29/500000000000], Train Loss: -62.6431, Val Loss: -63.0607


Training Epoch 30: 100%|██████████| 23/23 [01:24<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [30/500000000000], Train Loss: -62.6352, Val Loss: -63.0572


Training Epoch 31: 100%|██████████| 23/23 [01:24<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [31/500000000000], Train Loss: -62.6102, Val Loss: -63.0630


Training Epoch 32: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [32/500000000000], Train Loss: -62.6927, Val Loss: -63.0597


Training Epoch 33: 100%|██████████| 23/23 [01:24<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.73s/it]


Epoch [33/500000000000], Train Loss: -62.5702, Val Loss: -63.0618


Training Epoch 34: 100%|██████████| 23/23 [01:23<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [34/500000000000], Train Loss: -62.6826, Val Loss: -63.0599


Training Epoch 35: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.75s/it]


Epoch [35/500000000000], Train Loss: -62.6313, Val Loss: -63.0580


Training Epoch 36: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.72s/it]


Epoch [36/500000000000], Train Loss: -62.6374, Val Loss: -63.0631


Training Epoch 37: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.73s/it]


Epoch [37/500000000000], Train Loss: -62.6294, Val Loss: -63.0571


Training Epoch 38: 100%|██████████| 23/23 [01:23<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [38/500000000000], Train Loss: -62.5476, Val Loss: -63.0610


Training Epoch 39: 100%|██████████| 23/23 [01:24<00:00,  3.65s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [39/500000000000], Train Loss: -62.6124, Val Loss: -63.0605


Training Epoch 40: 100%|██████████| 23/23 [01:25<00:00,  3.70s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.77s/it]


Epoch [40/500000000000], Train Loss: -62.6386, Val Loss: -63.0602


Training Epoch 41: 100%|██████████| 23/23 [01:25<00:00,  3.71s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [41/500000000000], Train Loss: -62.6492, Val Loss: -63.0618


Training Epoch 42: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.89s/it]


Epoch [42/500000000000], Train Loss: -62.7050, Val Loss: -63.0592


Training Epoch 43: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [43/500000000000], Train Loss: -62.6642, Val Loss: -63.0587


Training Epoch 44: 100%|██████████| 23/23 [01:23<00:00,  3.64s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.74s/it]


Epoch [44/500000000000], Train Loss: -62.5821, Val Loss: -63.0584


Training Epoch 45: 100%|██████████| 23/23 [01:24<00:00,  3.66s/it]
Validation: 100%|██████████| 3/3 [00:11<00:00,  3.77s/it]


Epoch [45/500000000000], Train Loss: -62.6029, Val Loss: -63.0621


Training Epoch 46:  70%|██████▉   | 16/23 [00:59<00:25,  3.58s/it]

In [None]:

Epoch [58/500000000000], Train Loss: -62.5544, Val Loss: -62.8132
Train Loss: -62.5104, Val Loss: -62.8134

In [None]:
Epoch [10/500000000000], Train Loss: -62.5241, Val Loss: -62.8182
Model saved at epoch 10 with validation loss -62.8182