In [None]:
import sys
sys.path.append('..')
import torch
from src.models.hybrid_autoencoder import RadarHybridAutoencoder
from src.data.dataset import RadarDataset
from src.utils.training_validation import train_autoencoder, process_and_save_results
from src.utils.evaluation import evaluate_reconstruction
from src.models.model_utils import get_model_info
import mat73
import numpy as np
from torch.utils.data import DataLoader, random_split

In [None]:
# 2. Load and Prepare Data
# Read the .mat file through mat73 and pull out the two arrays the rest of
# the notebook works with.
data_path = '../data/raw/dataset_2t1.mat'
data = mat73.loadmat(data_path)
radar_cube, clean_data = data['radar_cube'], data['clean_signals']

# Quick shape check so a mismatched file is visible before any training starts.
print(
    f"Radar cube shape: {radar_cube.shape}",
    f"Clean data shape: {clean_data.shape}",
    sep="\n",
)

In [None]:
# 3. Prepare DataLoader
# Wrap the loaded arrays in the project dataset and hold out a fraction of it
# for validation.
VAL_FRACTION = 0.2  # share of samples reserved for validation
SPLIT_SEED = 42     # fixed seed so Restart & Run All reproduces the same split

full_dataset = RadarDataset(radar_cube, clean_data)
val_size = int(len(full_dataset) * VAL_FRACTION)
# The training set takes the remainder, so the two sizes always sum to
# len(full_dataset) as random_split requires.
train_size = len(full_dataset) - val_size

# A dedicated, seeded generator keeps the train/validation membership stable
# across kernel restarts instead of depending on the global RNG state.
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create dataloaders
batch_size = 4   # small batch size considering cpu memory
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    # Seeded separately so the shuffle order of the training batches is
    # reproducible as well.
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Validation order does not matter for averaged metrics, so no shuffling.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# 4. Initialize and Analyze Model
# Build the autoencoder with its default constructor arguments and ask the
# project helper for a summary dictionary describing it.
model = RadarHybridAutoencoder()
model_info = get_model_info(model)

# Report the three summary fields read from model_info; counts use thousands
# separators for readability.
print(
    "\nModel Architecture Information:",
    f"Model Type: {model_info['model_type']}",
    f"Total Parameters: {model_info['total_params']:,}",
    f"Trainable Parameters: {model_info['trainable_params']:,}",
    sep="\n",
)

In [None]:
# 5. Train Model
# Run the project training helper for a single epoch (num_epochs=1) on the
# train/validation loaders built above.
training_outputs = train_autoencoder(model, train_loader, val_loader, num_epochs=1)

# The helper returns four values; what each one contains is defined in
# src.utils.training_validation and is not visible from this notebook.
trained_model, history, memory_usage, training_time = training_outputs

In [None]:
# 6. Evaluate Model
# Score the trained model's reconstruction with the project evaluation helper.
# NOTE(review): the call receives the full radar_cube / clean_data arrays and
# frame_idx=0, not val_dataset, so the frame scored here may belong to the
# training split. Confirm against evaluate_reconstruction whether these are
# truly validation metrics.
metrics = evaluate_reconstruction(
    trained_model,
    radar_cube,
    clean_data,
    frame_idx=0,
)

# MSE is shown in scientific notation; PSNR and SIR improvement in dB with
# two decimals.
print(
    "\nValidation Metrics:",
    f"MSE: {metrics['mse']:.2e}",
    f"PSNR: {metrics['psnr']:.2f} dB",
    f"SIR Improvement: {metrics['sir_improvement']:.2f} dB",
    sep="\n",
)