## Training Demonstration for WiMAE and ContraWiMAE

- This notebook demonstrates how to train both WiMAE and ContraWiMAE models on real wireless channel data from the `../data/pretrain` folder.

#### Imports and Setup

In [1]:
import sys
import traceback
from pathlib import Path

import numpy as np
import torch
import yaml

# Add parent directory to path for imports
try:
    # Running as a Python script: locate the project root relative to this file.
    # resolve() makes the path absolute so later working-directory changes don't break it.
    PROJECT_ROOT = Path(__file__).resolve().parent.parent
except NameError:
    # Running in Jupyter: `__file__` is undefined, so fall back to the kernel's
    # working directory (assumed to be the notebook's own folder).
    PROJECT_ROOT = Path.cwd().parent

# Append only once, so re-running this cell does not keep growing sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

# WiMAE imports
from wimae.training.train_wimae import WiMAETrainer
# from wimae.training.train_contramae import ContraWiMAETrainer
# from wimae.models.base import WiMAE
# from wimae.models.contramae import ContraWiMAE

print("All imports successful!")
# Report the runtime environment so results can be tied to a PyTorch/CUDA setup.
for label, detail in (
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {detail}")

All imports successful!
PyTorch version: 2.5.1
CUDA available: True


#### Data Overview

In [2]:
# Directory holding the per-city pretraining channel files (*.npz).
data_path = "../data/pretrain"

# Sort once so that both the listing and the sample file inspected below are
# deterministic; Path.glob yields entries in unspecified filesystem order.
npz_files = sorted(Path(data_path).glob("*.npz"))
if not npz_files:
    # Fail with an actionable message instead of an IndexError on npz_files[0].
    raise FileNotFoundError(
        f"No .npz files found in {Path(data_path).resolve()}; check `data_path`."
    )

print(f"Available datasets: {len(npz_files)} cities")
for npz_file in npz_files:
    size_mb = npz_file.stat().st_size / (1024 * 1024)
    print(f"  • {npz_file.name}: {size_mb:.1f} MB")

# Load one file to show which arrays it stores, with their shapes and dtypes.
sample_file = npz_files[0]
with np.load(sample_file) as sample_data:
    print(f"\nSample data structure from {sample_file.name}:")
    for key, value in sample_data.items():
        if hasattr(value, 'shape'):
            print(f"  • {key}: {value.shape} ({value.dtype})")
        else:
            print(f"  • {key}: {value}")

Available datasets: 12 cities
  • city_0_newyork_channels.npz: 10.0 MB
  • city_10_austin_channels.npz: 14.5 MB
  • city_13_columbus_channels.npz: 11.2 MB
  • city_17_seattle_channels.npz: 11.5 MB
  • city_1_losangeles_channels.npz: 5.8 MB
  • city_2_chicago_channels.npz: 2.2 MB
  • city_3_houston_channels.npz: 20.1 MB
  • city_4_phoenix_channels.npz: 21.1 MB
  • city_5_philadelphia_channels.npz: 4.9 MB
  • city_6_miami_channels.npz: 13.1 MB
  • city_8_dallas_channels.npz: 19.6 MB
  • city_9_sanfrancisco_channels.npz: 13.0 MB

Sample data structure from city_0_newyork_channels.npz:
  • channels: (1283, 1, 32, 32) (complex64)
  • name: () (<U14)
  • n_rows: (2,) (int64)
  • n_per_row: () (int64)
  • active_bs: (1,) (int64)
  • n_ant_bs: () (int64)
  • n_ant_ue: () (int64)
  • n_subcarriers: () (int64)
  • scs: () (float64)
  • data_folder: () (<U16)
  • bs_rotation: (3,) (int64)
  • enable_bs2bs: () (bool)
  • num_paths: () (int64)


#### Configuration Setup

In [3]:
# Load the default training configuration.
config_path = "../configs/default_training.yaml"
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Configuration loaded successfully!")
print(f"Model type: {config['model']['type']}")
print(f"Encoder dimensions: {config['model']['encoder_dim']}")
print(f"Batch size: {config['training']['batch_size']}")
print(f"Training epochs: {config['training']['epochs']}")
print(f"Learning rate: {config['training']['optimizer']['lr']}")

# Demo overrides: one short epoch on a small subset so the notebook finishes
# quickly. Remove these overrides to train with the defaults printed above.
DEMO_EPOCHS = 1
DEMO_BATCH_SIZE = 768
DEMO_DEBUG_BATCHES = 5  # debug_size = this many full batches
DEMO_MASK_RATIO = 0.2
DEMO_EXP_NAME = "demo_exp"

config['training']['epochs'] = DEMO_EPOCHS
config['training']['batch_size'] = DEMO_BATCH_SIZE
# Reuse the directory listed in the data-overview cell (single source of truth).
config['data']['data_dir'] = str(data_path)
# NOTE(review): presumably caps the number of samples used for training — confirm in the trainer.
config['data']['debug_size'] = DEMO_BATCH_SIZE * DEMO_DEBUG_BATCHES
config['model']['mask_ratio'] = DEMO_MASK_RATIO
config['logging']['exp_name'] = DEMO_EXP_NAME

print("\nAdjusted for demo:")
print(f"Epochs: {config['training']['epochs']}")
print(f"Batch size: {config['training']['batch_size']}")
print(f"Data directory: {config['data']['data_dir']}")
print(f"Mask ratio: {config['model']['mask_ratio']}")
print(f"Experiment folder name: {config['logging']['exp_name']}")
print(f"Debug data size: {config['data']['debug_size']}")

Configuration loaded successfully!
Model type: wimae
Encoder dimensions: 64
Batch size: 64
Training epochs: 3000
Learning rate: 0.0003

Adjusted for demo:
Epochs: 1
Batch size: 768
Data directory: ../data/pretrain
Mask ratio: 0.2
Experiment folder name: demo_exp
Debug data size: 3840


#### WiMAE Training Setup

In [4]:
print("Setting up WiMAE training...")
wimae_trainer = WiMAETrainer(config=config)  # the trainer builds the model in its constructor
print("WiMAE trainer initialized")

# Summarize the model the trainer just built, one setting per line.
wimae_info = wimae_trainer.model.get_model_info()
print("WiMAE model created:")
for field, setting in wimae_info.items():
    print(f"  • {field}: {setting}")


Setting up WiMAE training...
WiMAE trainer initialized
WiMAE model created:
  • model_type: WiMAE
  • patch_size: (1, 16)
  • encoder_dim: 64
  • encoder_layers: 12
  • encoder_nhead: 16
  • decoder_layers: 4
  • decoder_nhead: 8
  • mask_ratio: 0.2
  • total_parameters: 554128
  • trainable_parameters: 554128


#### WiMAE Training Execution

In [None]:
print("Starting WiMAE training...")

# Training is best-effort in this demo: a failure is reported but does not stop
# "Run All". The full traceback is printed so the cause isn't reduced to one line.
# `wimae_training_succeeded` records the outcome for inspection in later cells.
wimae_training_succeeded = False
try:
    # The trainer sets up its own dataloaders.
    wimae_trainer.train()
except Exception as e:
    print(f"Training failed: {e}")
    traceback.print_exc()
    print("This is expected in a demo - check your data paths and configuration")
else:
    # Runs only if train() returned normally.
    wimae_training_succeeded = True
    print("WiMAE training completed successfully!")

Starting WiMAE training...
Total samples: 18824, dimensions: 32x32
Loading file 1/12: city_0_newyork_channels.npz
Loading file 2/12: city_1_losangeles_channels.npz
Loading file 3/12: city_10_austin_channels.npz
Loading file 4/12: city_9_sanfrancisco_channels.npz
Loading file 5/12: city_6_miami_channels.npz


#### Checkpoints and Model Loading

In [None]:
print("Checkpoint Management")
print("=" * 30)

# NOTE(review): assumes the trainer saves checkpoints under
# <log_dir>/<model type>_<exp_name>/ — confirm against WiMAETrainer.
# Path() is built from log_dir directly: an f"./{log_dir}" prefix would turn an
# absolute log_dir such as "/runs" into the relative path ".//runs".
checkpoint_path = Path(config['logging']['log_dir']) / f"{config['model']['type']}_{config['logging']['exp_name']}"

# Sorted for a stable listing; glob order is filesystem-dependent.
wimae_checkpoints = sorted(checkpoint_path.glob("*.pt")) if checkpoint_path.is_dir() else []

print(f"WiMAE checkpoints: {len(wimae_checkpoints)}")
for ckpt in wimae_checkpoints:
    print(f"  • {ckpt.name}")

# Example of loading a checkpoint (if available): pick the most recently
# written one rather than whichever file glob happened to return first.
if wimae_checkpoints:
    print("\nExample: Loading WiMAE checkpoint...")
    latest_ckpt = max(wimae_checkpoints, key=lambda p: p.stat().st_mtime)
    try:
        # weights_only=False is passed explicitly: these files come from this
        # notebook's own trainer and could hold non-tensor objects, and the default
        # changed to True in PyTorch 2.6. Never use this on untrusted checkpoints
        # (it unpickles arbitrary objects).
        checkpoint = torch.load(latest_ckpt, map_location='cpu', weights_only=False)
    except Exception as e:
        print(f"Failed to load checkpoint: {e}")
    else:
        # Only dict-style checkpoints carry metadata; don't misreport others as load failures.
        if isinstance(checkpoint, dict):
            print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
            print(f"Validation loss: {checkpoint.get('best_val_loss', 'unknown')}")
        else:
            print(f"Loaded checkpoint of type {type(checkpoint).__name__} (no metadata dict)")