## Training Demonstration for WiMAE and ContraWiMAE

- This notebook demonstrates how to train the WiMAE and ContraWiMAE models on real wireless channel data from the `pretrain` data folder.

#### Imports and Setup

In [1]:
import sys
import torch
import numpy as np
import yaml
from pathlib import Path

# Add parent directory to path for imports
try:
    # For Python scripts
    sys.path.append(str(Path(__file__).parent.parent))
except NameError:
    # For Jupyter notebooks
    sys.path.append(str(Path().cwd().parent))

# WiMAE imports
from wimae.training.train_wimae import WiMAETrainer
# from wimae.training.train_contramae import ContraWiMAETrainer
# from wimae.models.base import WiMAE
# from wimae.models.contramae import ContraWiMAE

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

All imports successful!
PyTorch version: 2.5.1
CUDA available: True


#### Data Overview

In [2]:
data_path = "../data/pretrain"
npz_files = list(Path(data_path).glob("*.npz"))

print(f"Available datasets: {len(npz_files)} cities")
for file in sorted(npz_files):
    file_size = file.stat().st_size / (1024*1024)  # MB
    print(f"  • {file.name}: {file_size:.1f} MB")

# Load one file to check data structure
with np.load(npz_files[0]) as sample_data:
    print(f"\nSample data structure from {npz_files[0].name}:")
    for key, value in sample_data.items():
        if hasattr(value, 'shape'):
            print(f"  • {key}: {value.shape} ({value.dtype})")
        else:
            print(f"  • {key}: {value}")

Available datasets: 12 cities
  • city_0_newyork_channels.npz: 10.0 MB
  • city_10_austin_channels.npz: 14.5 MB
  • city_13_columbus_channels.npz: 11.2 MB
  • city_17_seattle_channels.npz: 11.5 MB
  • city_1_losangeles_channels.npz: 5.8 MB
  • city_2_chicago_channels.npz: 2.2 MB
  • city_3_houston_channels.npz: 20.1 MB
  • city_4_phoenix_channels.npz: 21.1 MB
  • city_5_philadelphia_channels.npz: 4.9 MB
  • city_6_miami_channels.npz: 13.1 MB
  • city_8_dallas_channels.npz: 19.6 MB
  • city_9_sanfrancisco_channels.npz: 13.0 MB

Sample data structure from city_0_newyork_channels.npz:
  • channels: (1283, 1, 32, 32) (complex64)
  • name: () (<U14)
  • n_rows: (2,) (int64)
  • n_per_row: () (int64)
  • active_bs: (1,) (int64)
  • n_ant_bs: () (int64)
  • n_ant_ue: () (int64)
  • n_subcarriers: () (int64)
  • scs: () (float64)
  • data_folder: () (<U16)
  • bs_rotation: (3,) (int64)
  • enable_bs2bs: () (bool)
  • num_paths: () (int64)
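Later in this notebook, the training log reports normalization statistics with the keys `real_mean`, `real_std`, `imag_mean` and `imag_std`. Below is a minimal sketch of how comparable numbers could be computed from the `channels` arrays listed above. It assumes every file stores `channels` with shape `(N, 1, 32, 32)` and that the statistics are plain means and standard deviations of the real and imaginary parts. `load_all_channels` and `channel_statistics` are hypothetical helpers, not part of the `wimae` package. The trainer reports computing its statistics on the training split only, so running this sketch over all files would give slightly different values.

```python
from pathlib import Path

import numpy as np


def load_all_channels(data_dir):
    """Hypothetical helper: concatenate the `channels` array of every .npz file in data_dir."""
    arrays = []
    for path in sorted(Path(data_dir).glob("*.npz")):
        with np.load(path) as data:
            # Assumed layout (from the cell above): (N, 1, 32, 32), complex64.
            arrays.append(data["channels"])
    if not arrays:
        raise FileNotFoundError(f"No .npz files found in {data_dir}")
    return np.concatenate(arrays, axis=0)


def channel_statistics(channels):
    """Hypothetical helper: mean/std of the real and imaginary parts, computed separately."""
    return {
        "real_mean": float(channels.real.mean()),
        "real_std": float(channels.real.std()),
        "imag_mean": float(channels.imag.mean()),
        "imag_std": float(channels.imag.std()),
    }


# Usage with the data directory used in this notebook:
# channels = load_all_channels("../data/pretrain")
# print(channels.shape, channel_statistics(channels))
```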


#### Configuration Setup

In [6]:
config_path = "../configs/default_training.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("Configuration loaded successfully!")
print(f"Model type: {config['model']['type']}")
print(f"Encoder dimensions: {config['model']['encoder_dim']}")
print(f"Batch size: {config['training']['batch_size']}")
print(f"Training epochs: {config['training']['epochs']}")
print(f"Learning rate: {config['training']['optimizer']['lr']}")

# Adjust config for demo (shorter training)
config['training']['epochs'] = 10
config['training']['batch_size'] = 768
config['data']['data_dir'] = "../data/pretrain"
print("\nAdjusted for demo:")
print(f"Epochs: {config['training']['epochs']}")
print(f"Batch size: {config['training']['batch_size']}")
print(f"Data directory: {config['data']['data_dir']}")

Configuration loaded successfully!
Model type: wimae
Encoder dimensions: 64
Batch size: 64
Training epochs: 3000
Learning rate: 0.0003

Adjusted for demo:
Epochs: 10
Batch size: 768
Data directory: ../data/pretrain
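The cell above overwrites the loaded values in place, so the original settings (3000 epochs, batch size 64) are no longer available in the session. One alternative, sketched here and not part of the `wimae` API, is to leave the loaded YAML untouched and derive a separate demo configuration with `copy.deepcopy` and a small recursive merge. `with_overrides` is a hypothetical helper.

```python
import copy


def with_overrides(base, overrides):
    """Hypothetical helper: return a deep copy of `base` with nested `overrides` merged in."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Recurse so sibling keys (e.g. training.optimizer) are preserved.
            merged[key] = with_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Values mirror the demo adjustments made above.
demo_overrides = {
    "training": {"epochs": 10, "batch_size": 768},
    "data": {"data_dir": "../data/pretrain"},
}

# demo_config = with_overrides(config, demo_overrides)
# wimae_trainer = WiMAETrainer(config=demo_config)
```

With this pattern, `config` still holds the full-length settings for a later production run.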


#### WiMAE Training Setup

In [7]:
print("Setting up WiMAE training...")

# Create WiMAE trainer (model will be created during initialization)
wimae_trainer = WiMAETrainer(config=config)

print("WiMAE trainer initialized")

# Get model information
wimae_info = wimae_trainer.model.get_model_info()
print("WiMAE model created:")
for key, value in wimae_info.items():
    print(f"  • {key}: {value}")


Setting up WiMAE training...
WiMAE trainer initialized
WiMAE model created:
  • model_type: WiMAE
  • patch_size: (1, 16)
  • encoder_dim: 64
  • encoder_layers: 12
  • encoder_nhead: 16
  • decoder_layers: 4
  • decoder_nhead: 8
  • mask_ratio: 0.6
  • total_parameters: 554128
  • trainable_parameters: 554128
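With `patch_size` (1, 16), each of the 32 rows of a single 32x32 plane splits into two patches of 16 values, giving 32 x 2 = 64 patches. The sketch below shows that patching together with MAE-style random masking at `mask_ratio` 0.6. It rests on two assumptions not confirmed by the source: that WiMAE patchifies each real-valued 32x32 plane this way (how the real and imaginary parts are arranged is not shown here), and that it follows the original MAE convention of keeping `int(num_patches * (1 - mask_ratio))` patches visible, which gives 25 visible and 39 masked patches. `patchify` and `random_mask` are hypothetical helpers.

```python
import numpy as np


def patchify(x, patch_size=(1, 16)):
    """Split an (H, W) array into non-overlapping patches, each flattened to ph * pw values."""
    ph, pw = patch_size
    h, w = x.shape
    assert h % ph == 0 and w % pw == 0, "input must be divisible by the patch size"
    patches = x.reshape(h // ph, ph, w // pw, pw).transpose(0, 2, 1, 3)
    return patches.reshape(-1, ph * pw)  # (num_patches, ph * pw), row-major patch order


def random_mask(num_patches, mask_ratio, rng):
    """MAE-style masking: return indices of visible patches and a binary mask (1 = masked)."""
    num_keep = int(num_patches * (1 - mask_ratio))
    ids_shuffle = np.argsort(rng.random(num_patches))  # random permutation via sorted noise
    ids_keep = np.sort(ids_shuffle[:num_keep])
    mask = np.ones(num_patches, dtype=np.int64)
    mask[ids_keep] = 0
    return ids_keep, mask


rng = np.random.default_rng(0)
x = rng.standard_normal((32, 32)).astype(np.float32)  # stand-in for one real-valued 32x32 plane
patches = patchify(x, (1, 16))
ids_keep, mask = random_mask(len(patches), mask_ratio=0.6, rng=rng)
visible = patches[ids_keep]  # only these patches would be passed to the encoder
print(patches.shape, visible.shape, int(mask.sum()))  # (64, 16) (25, 16) 39
```

Encoding only the visible patches is what makes masked autoencoders cheap to pretrain: the encoder sees 25 of 64 patches, and the decoder reconstructs the rest.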


#### WiMAE Training Execution

In [8]:
print("Starting WiMAE training...")

try:
    # Start training (dataloaders will be set up automatically)
    wimae_trainer.train()
    print("WiMAE training completed successfully!")
    
except Exception as e:
    print(f"Training failed: {e}")
    print("Check the data paths and configuration, then re-run this cell.")

Starting WiMAE training...
Total samples: 18824, dimensions: 32x32
Loading file 1/12: city_0_newyork_channels.npz
Loading file 2/12: city_1_losangeles_channels.npz
Loading file 3/12: city_10_austin_channels.npz
Loading file 4/12: city_9_sanfrancisco_channels.npz
Loading file 5/12: city_6_miami_channels.npz
Loading file 6/12: city_13_columbus_channels.npz
Loading file 7/12: city_2_chicago_channels.npz
Loading file 8/12: city_5_philadelphia_channels.npz
Loading file 9/12: city_17_seattle_channels.npz
Loading file 10/12: city_3_houston_channels.npz
Loading file 11/12: city_8_dallas_channels.npz
Loading file 12/12: city_4_phoenix_channels.npz
Successfully loaded all 18824 samples
Computing statistics from training dataset...
Calculated statistics: {'real_mean': 0.046293627470731735, 'real_std': 30.634143829345703, 'imag_mean': 8.498757961206138e-05, 'imag_std': 30.665428161621094}
Total samples: 18824, dimensions: 32x32
Loading file 1/12: city_0_newyork_channels.npz
Loading file 2/12: city

Training Epoch 0: 100%|██████████| 20/20 [00:09<00:00,  2.18it/s, loss=1.0563, avg_loss=1.0420]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.37it/s]


Epoch 0:
  train_loss: 1.0420
  val_masked_loss: 1.0257
  val_full_loss: 1.0524
  val_loss: 1.0257


Training Epoch 1: 100%|██████████| 20/20 [00:08<00:00,  2.26it/s, loss=0.9924, avg_loss=1.0091]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.54it/s]


Epoch 1:
  train_loss: 1.0091
  val_masked_loss: 1.0233
  val_full_loss: 1.0441
  val_loss: 1.0233


Training Epoch 2: 100%|██████████| 20/20 [00:08<00:00,  2.26it/s, loss=1.0394, avg_loss=1.0051]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.53it/s]


Epoch 2:
  train_loss: 1.0051
  val_masked_loss: 1.0226
  val_full_loss: 1.0415
  val_loss: 1.0226


Training Epoch 3: 100%|██████████| 20/20 [00:08<00:00,  2.26it/s, loss=0.7471, avg_loss=1.0001]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.54it/s]


Epoch 3:
  train_loss: 1.0001
  val_masked_loss: 1.0265
  val_full_loss: 1.0408
  val_loss: 1.0265


Training Epoch 4: 100%|██████████| 20/20 [00:08<00:00,  2.26it/s, loss=1.1554, avg_loss=1.0061]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.55it/s]


Epoch 4:
  train_loss: 1.0061
  val_masked_loss: 1.0225
  val_full_loss: 1.0408
  val_loss: 1.0225


Training Epoch 5: 100%|██████████| 20/20 [00:08<00:00,  2.26it/s, loss=0.9776, avg_loss=1.0007]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.55it/s]


Epoch 5:
  train_loss: 1.0007
  val_masked_loss: 1.0244
  val_full_loss: 1.0425
  val_loss: 1.0244


Training Epoch 6: 100%|██████████| 20/20 [00:08<00:00,  2.25it/s, loss=0.9933, avg_loss=0.9997]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.57it/s]


Epoch 6:
  train_loss: 0.9997
  val_masked_loss: 1.0248
  val_full_loss: 1.0447
  val_loss: 1.0248


Training Epoch 7: 100%|██████████| 20/20 [00:08<00:00,  2.25it/s, loss=1.0601, avg_loss=1.0009]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.53it/s]


Epoch 7:
  train_loss: 1.0009
  val_masked_loss: 1.0209
  val_full_loss: 1.0459
  val_loss: 1.0209


Training Epoch 8: 100%|██████████| 20/20 [00:08<00:00,  2.25it/s, loss=1.1602, avg_loss=1.0022]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.55it/s]


Epoch 8:
  train_loss: 1.0022
  val_masked_loss: 1.0235
  val_full_loss: 1.0478
  val_loss: 1.0235


Training Epoch 9: 100%|██████████| 20/20 [00:08<00:00,  2.25it/s, loss=0.8612, avg_loss=0.9964]
Validation: 100%|██████████| 5/5 [00:01<00:00,  3.52it/s]

Epoch 9:
  train_loss: 0.9964
  val_masked_loss: 1.0204
  val_full_loss: 1.0495
  val_loss: 1.0204
Training completed!
WiMAE training completed successfully!
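The log reports both `val_masked_loss` and `val_full_loss`, and `val_loss` equals `val_masked_loss` at every epoch. The sketch below shows one common way to compute the two quantities, assuming (not confirmed by the source) that they are mean-squared reconstruction errors on normalized data, averaged over masked patches only and over all patches respectively, as in the original MAE objective. It also gives a useful reference point: on standardized targets (zero mean, unit variance), a model that always predicts zero scores a loss of about 1.0. The losses in this run stay close to that value (0.996 to 1.052), consistent with training for 10 of the configured 3000 epochs. `reconstruction_losses` is a hypothetical helper.

```python
import numpy as np


def reconstruction_losses(pred, target, mask):
    """MSE per patch, averaged over masked patches and over all patches.

    pred, target: (batch, num_patches, patch_dim); mask: (batch, num_patches), 1 = masked.
    """
    per_patch = ((pred - target) ** 2).mean(axis=-1)  # (batch, num_patches)
    masked_loss = (per_patch * mask).sum() / mask.sum()
    full_loss = per_patch.mean()
    return float(masked_loss), float(full_loss)


rng = np.random.default_rng(0)
target = rng.standard_normal((256, 64, 16))  # standardized targets: mean 0, variance 1
mask = np.zeros((256, 64))
mask[:, :39] = 1  # 39 of 64 patches masked, as with mask_ratio 0.6
zero_pred = np.zeros_like(target)  # trivial baseline: always predict the mean
print(reconstruction_losses(zero_pred, target, mask))  # both values close to 1.0
```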




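Each epoch in the log runs 20 training batches and 5 validation batches. That is consistent with, for example, an 80/20 train/validation split of the 18,824 loaded samples at `batch_size` 768, with the final partial batch kept. The split ratio is an assumption here, since the configuration printed above does not show it; `batches_per_epoch` is a hypothetical helper for checking such numbers.

```python
import math


def batches_per_epoch(num_samples, batch_size, val_fraction):
    """Hypothetical helper: (train, val) batch counts for a split, keeping partial final batches."""
    num_val = round(num_samples * val_fraction)
    num_train = num_samples - num_val
    return math.ceil(num_train / batch_size), math.ceil(num_val / batch_size)


print(batches_per_epoch(18824, 768, 0.2))  # (20, 5)
```

Under the same assumed split, the configured batch size of 64 would give `batches_per_epoch(18824, 64, 0.2) == (236, 59)`, so each full-length epoch would take roughly 12 times as many steps as in this demo.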