## Training Demonstration for WiMAE and ContraWiMAE

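This notebook demonstrates how to train both WiMAE and ContraWiMAE models using real wireless channel data from the `data/pretrain` folder. It runs a short WiMAE training demo, resumes WiMAE training from a saved checkpoint, and then initializes ContraWiMAE from the WiMAE encoder/decoder weights and trains it.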

#### Imports and Setup

In [1]:
import copy
import pprint
import sys
from pathlib import Path

import numpy as np
import torch
import yaml

# Make the repository root importable so the local `wimae` package resolves.
# `__file__` exists when this runs as a script; inside a Jupyter kernel it does
# not (NameError), so fall back to the parent of the current working directory.
try:
    project_root = Path(__file__).parent.parent
except NameError:
    project_root = Path().cwd().parent
sys.path.append(str(project_root))

# Project trainers (resolvable only after the sys.path entry above)
from wimae.training.train_wimae import WiMAETrainer
from wimae.training.train_contramae import ContraWiMAETrainer

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

All imports successful!
PyTorch version: 2.5.1
CUDA available: True


#### Data Overview

In [2]:
# Directory holding one pre-generated channel file (.npz) per city.
data_path = "../data/pretrain"

# Sort once so both the listing and the sample file below are deterministic:
# Path.glob yields entries in filesystem-dependent order.
npz_files = sorted(Path(data_path).glob("*.npz"))
if not npz_files:
    raise FileNotFoundError(f"No .npz files found in {Path(data_path).resolve()}")

print(f"Available datasets: {len(npz_files)} cities")
for npz_file in npz_files:
    size_mb = npz_file.stat().st_size / (1024 * 1024)
    print(f"  • {npz_file.name}: {size_mb:.1f} MB")

# Load one file to check data structure: show shape/dtype for the channel
# array and the raw value for every other (metadata) entry.
sample_file = npz_files[0]
with np.load(sample_file) as sample_data:
    print(f"\nSample data structure from {sample_file.name}:")
    for key, value in sample_data.items():
        if key == "channels":
            print(f"  • {key}: {value.shape} ({value.dtype})")
        else:
            print(f"  • {key}: {value}")

Available datasets: 12 cities
  • city_0_newyork_channels.npz: 10.0 MB
  • city_10_austin_channels.npz: 14.5 MB
  • city_13_columbus_channels.npz: 11.2 MB
  • city_17_seattle_channels.npz: 11.5 MB
  • city_1_losangeles_channels.npz: 5.8 MB
  • city_2_chicago_channels.npz: 2.2 MB
  • city_3_houston_channels.npz: 20.1 MB
  • city_4_phoenix_channels.npz: 21.1 MB
  • city_5_philadelphia_channels.npz: 4.9 MB
  • city_6_miami_channels.npz: 13.1 MB
  • city_8_dallas_channels.npz: 19.6 MB
  • city_9_sanfrancisco_channels.npz: 13.0 MB

Sample data structure from city_0_newyork_channels.npz:
  • channels: (1283, 1, 32, 32) (complex64)
  • name: city_0_newyork
  • n_rows: [ 0 44]
  • n_per_row: 117
  • active_bs: [1]
  • n_ant_bs: 32
  • n_ant_ue: 1
  • n_subcarriers: 32
  • scs: 30000.0
  • data_folder: ./data/scenarios
  • bs_rotation: [   0    0 -135]
  • enable_bs2bs: False
  • num_paths: 20


#### Configuration Setup

In [3]:
config_path = "../configs/default_training.yaml"

with open(config_path, 'r', encoding="utf-8") as f:
    config = yaml.safe_load(f)

print("Default configuration loaded successfully:")
pprint.pprint(config, sort_dicts=False)


# Demo overrides: a small model trained for a few epochs on CPU so the whole
# notebook finishes quickly. Each section is merged in place into `config`.
demo_overrides = {
    "model": {
        "encoder_layers": 2,
        "encoder_nhead": 4,
        "decoder_layers": 1,
        "decoder_nhead": 4,
        "mask_ratio": 0.6,
    },
    "training": {
        "epochs": 3,
        "batch_size": 768,
        "device": "cpu",
    },
    "data": {
        "data_dir": data_path,
    },
    "logging": {
        "exp_name": "demo_exp",
    },
}
for section, overrides in demo_overrides.items():
    config[section].update(overrides)

# Five batches' worth of samples; with val_split 0.2 the demo run shows
# 4 training batches and 1 validation batch per epoch.
config['data']['debug_size'] = config['training']['batch_size'] * 5

# (label, section, key) for the summary printout below.
summary_fields = [
    ("Encoder dim", "model", "encoder_dim"),
    ("Encoder layers", "model", "encoder_layers"),
    ("Encoder nhead", "model", "encoder_nhead"),
    ("Decoder layers", "model", "decoder_layers"),
    ("Decoder nhead", "model", "decoder_nhead"),
    ("Mask ratio", "model", "mask_ratio"),
    ("Epochs", "training", "epochs"),
    ("Batch size", "training", "batch_size"),
    ("Device", "training", "device"),
    ("Data directory", "data", "data_dir"),
    ("Debug data size", "data", "debug_size"),
    ("Experiment folder name", "logging", "exp_name"),
]
print("\nAdjusted for demo:")
for label, section, key in summary_fields:
    print(f"{label}: {config[section][key]}")

# Where to look for checkpoints: <log_dir>/<model type>_<exp_name>.
# Path(log_dir), rather than an f"./{log_dir}" string, keeps an absolute
# log_dir absolute ("./" + "/abs/logs" would collapse to relative "abs/logs").
checkpoint_path = Path(config['logging']['log_dir']) / f"{config['model']['type']}_{config['logging']['exp_name']}"


Default configuration loaded successfully:
{'model': {'type': 'wimae',
           'patch_size': [1, 16],
           'encoder_dim': 64,
           'encoder_layers': 12,
           'encoder_nhead': 16,
           'decoder_layers': 4,
           'decoder_nhead': 8,
           'mask_ratio': 0.6,
           'contrastive_dim': 64,
           'temperature': 0.1,
           'snr_min': 0.0,
           'snr_max': 30.0},
 'data': {'data_dir': 'data/pretrain',
          'normalize': True,
          'val_split': 0.2,
          'debug_size': None,
          'calculate_statistics': True,
          'statistics': {'real_mean': 0.021121172234416008,
                         'real_std': 30.7452392578125,
                         'imag_mean': -0.01027622725814581,
                         'imag_std': 30.70543670654297}},
 'training': {'batch_size': 64,
              'epochs': 3000,
              'num_workers': 4,
              'device': 'cuda:0',
              'optimizer': {'type': 'adam',
               

#### WiMAE Training Setup

In [4]:
print("Setting up WiMAE training...")

# Constructing the trainer also builds the WiMAE model from `config`.
wimae_trainer = WiMAETrainer(config=config)
print("WiMAE trainer initialized")

# Summarize the architecture the trainer just built.
wimae_info = wimae_trainer.model.get_model_info()
print("WiMAE model created:")
for name, val in wimae_info.items():
    print(f"  • {name}: {val}")


Setting up WiMAE training...
WiMAE trainer initialized
WiMAE model created:
  • model_type: WiMAE
  • patch_size: (1, 16)
  • encoder_dim: 64
  • encoder_layers: 2
  • encoder_nhead: 4
  • decoder_layers: 1
  • decoder_nhead: 4
  • mask_ratio: 0.6
  • total_parameters: 118992
  • trainable_parameters: 118992


#### WiMAE Training Execution

In [5]:
print("Starting WiMAE training...")

# Dataloaders are built inside train(). On failure, report the error and
# re-raise so "Run All" stops here instead of continuing into cells that
# expect checkpoints from this run.
try:
    wimae_trainer.train()
except Exception as e:
    print(f"Training failed: {e}")
    raise
print("WiMAE training completed successfully!")

Starting WiMAE training...
Total samples: 18824, dimensions: 32x32
Loading file 1/12: city_0_newyork_channels.npz
Loading file 2/12: city_1_losangeles_channels.npz
Loading file 3/12: city_10_austin_channels.npz
Loading file 4/12: city_9_sanfrancisco_channels.npz
Loading file 5/12: city_6_miami_channels.npz
Loading file 6/12: city_13_columbus_channels.npz
Loading file 7/12: city_2_chicago_channels.npz
Loading file 8/12: city_5_philadelphia_channels.npz
Loading file 9/12: city_17_seattle_channels.npz
Loading file 10/12: city_3_houston_channels.npz
Loading file 11/12: city_8_dallas_channels.npz
Loading file 12/12: city_4_phoenix_channels.npz
Successfully loaded all 18824 samples
Computing statistics from training dataset...
Calculated statistics: {'real_mean': 0.0013652583584189415, 'real_std': 30.900094985961914, 'imag_mean': -0.0803627148270607, 'imag_std': 30.591354370117188}
Total samples: 18824, dimensions: 32x32
Loading file 1/12: city_0_newyork_channels.npz
Loading file 2/12: city_

Training Epoch 0: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=1.2166, avg_loss=1.1771]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.42it/s]


Epoch 0:
  train_loss: 1.1771
  val_masked_loss: 1.1310
  val_full_loss: 1.2507
  val_loss: 1.1310


Training Epoch 1: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=1.0080, avg_loss=1.0571]
Validation: 100%|██████████| 1/1 [00:00<00:00,  4.01it/s]


Epoch 1:
  train_loss: 1.0571
  val_masked_loss: 1.1170
  val_full_loss: 1.2334
  val_loss: 1.1170


Training Epoch 2: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s, loss=0.8479, avg_loss=1.0528]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.91it/s]

Epoch 2:
  train_loss: 1.0528
  val_masked_loss: 1.1000
  val_full_loss: 1.2167
  val_loss: 1.1000
Training completed!
WiMAE training completed successfully!





#### Checkpoints and Model Loading

In [6]:
print("Checkpoint Management")
print("=" * 30)

# Saved checkpoint files, sorted so the listing is deterministic
# (Path.glob order depends on the filesystem).
wimae_checkpoints = sorted(checkpoint_path.glob("*.pt")) if checkpoint_path.exists() else []

print(f"WiMAE checkpoints: {len(wimae_checkpoints)}")
for ckpt in wimae_checkpoints:
    print(f"  • {ckpt.name}")

Checkpoint Management
WiMAE checkpoints: 2
  • last_checkpoint.pt
  • best_checkpoint.pt


In [7]:
# Example: restore a saved checkpoint (full training state) into the existing
# WiMAE trainer and keep training.
# NOTE(review): for resuming, last_checkpoint.pt is the usual choice. In this
# demo best and last coincide (validation loss improved every epoch), and
# best_checkpoint_path is reused below to initialize ContraWiMAE.
best_checkpoint_path = checkpoint_path / "best_checkpoint.pt"
wimae_trainer.load_checkpoint(best_checkpoint_path)


# continue training
print("Continuing training...")

try:
    # Raise the total epoch budget.
    # NOTE(review): if the trainer keeps a reference to `config` rather than a
    # copy, this also changes config["training"]["epochs"] for later cells.
    # NOTE(review): the log shows training restarting at the restored epoch
    # index (2), i.e. that epoch is repeated. This may be an off-by-one in the
    # trainer's resume logic; confirm in wimae.training.
    wimae_trainer.config["training"]["epochs"] = 10
    wimae_trainer.train()
except Exception as e:
    # The original `print(...)()` called print's None return value, raising
    # TypeError and masking the real error; report it, then re-raise.
    print(f"Training failed: {e}")
    raise
print("WiMAE training completed successfully!")

  checkpoint = torch.load(checkpoint_path, map_location=self.device)


Loaded full training state from epoch 2
Continuing training...
Total samples: 18824, dimensions: 32x32
Loading file 1/12: city_0_newyork_channels.npz
Loading file 2/12: city_1_losangeles_channels.npz
Loading file 3/12: city_10_austin_channels.npz
Loading file 4/12: city_9_sanfrancisco_channels.npz
Loading file 5/12: city_6_miami_channels.npz
Loading file 6/12: city_13_columbus_channels.npz
Loading file 7/12: city_2_chicago_channels.npz
Loading file 8/12: city_5_philadelphia_channels.npz
Loading file 9/12: city_17_seattle_channels.npz
Loading file 10/12: city_3_houston_channels.npz
Loading file 11/12: city_8_dallas_channels.npz
Loading file 12/12: city_4_phoenix_channels.npz
Successfully loaded all 18824 samples
Computing statistics from training dataset...
Calculated statistics: {'real_mean': 0.0013652583584189415, 'real_std': 30.900094985961914, 'imag_mean': -0.0803627148270607, 'imag_std': 30.591354370117188}
Total samples: 18824, dimensions: 32x32
Loading file 1/12: city_0_newyork_c

Training Epoch 2: 100%|██████████| 4/4 [00:02<00:00,  1.36it/s, loss=1.0797, avg_loss=1.0401]
Validation: 100%|██████████| 1/1 [00:00<00:00,  2.53it/s]


Epoch 2:
  train_loss: 1.0401
  val_masked_loss: 1.0922
  val_full_loss: 1.2120
  val_loss: 1.0922


Training Epoch 3: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s, loss=1.2286, avg_loss=1.0374]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.94it/s]


Epoch 3:
  train_loss: 1.0374
  val_masked_loss: 1.1004
  val_full_loss: 1.2109
  val_loss: 1.1004


Training Epoch 4: 100%|██████████| 4/4 [00:02<00:00,  1.43it/s, loss=0.9134, avg_loss=1.0327]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.97it/s]


Epoch 4:
  train_loss: 1.0327
  val_masked_loss: 1.0929
  val_full_loss: 1.2077
  val_loss: 1.0929


Training Epoch 5: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=1.0557, avg_loss=1.0312]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.96it/s]


Epoch 5:
  train_loss: 1.0312
  val_masked_loss: 1.0931
  val_full_loss: 1.2038
  val_loss: 1.0931


Training Epoch 6: 100%|██████████| 4/4 [00:02<00:00,  1.42it/s, loss=0.8846, avg_loss=1.0295]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.94it/s]


Epoch 6:
  train_loss: 1.0295
  val_masked_loss: 1.0871
  val_full_loss: 1.2013
  val_loss: 1.0871


Training Epoch 7: 100%|██████████| 4/4 [00:02<00:00,  1.43it/s, loss=1.0135, avg_loss=1.0180]
Validation: 100%|██████████| 1/1 [00:00<00:00,  4.04it/s]


Epoch 7:
  train_loss: 1.0180
  val_masked_loss: 1.0911
  val_full_loss: 1.1994
  val_loss: 1.0911


Training Epoch 8: 100%|██████████| 4/4 [00:02<00:00,  1.47it/s, loss=1.0084, avg_loss=1.0222]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.31it/s]


Epoch 8:
  train_loss: 1.0222
  val_masked_loss: 1.0857
  val_full_loss: 1.1986
  val_loss: 1.0857


Training Epoch 9: 100%|██████████| 4/4 [00:02<00:00,  1.36it/s, loss=1.0364, avg_loss=1.0229]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.86it/s]

Epoch 9:
  train_loss: 1.0229
  val_masked_loss: 1.0953
  val_full_loss: 1.1970
  val_loss: 1.0953
Training completed!
WiMAE training completed successfully!





#### ContraWiMAE Trainer Setup

In [10]:
print("Setting up ContraWiMAE training...")

# Derive the ContraWiMAE config from a deep copy. The shared `config`
# (possibly still referenced by wimae_trainer) then keeps model type "wimae",
# and re-running earlier cells stays idempotent.
contra_config = copy.deepcopy(config)
contra_config['model']['type'] = "contramae"

# Create ContraWiMAE trainer (model will be created during initialization)
contra_wimae_trainer = ContraWiMAETrainer(config=contra_config)

# Initialize from the WiMAE encoder and decoder weights.
# strict=False: the WiMAE checkpoint has no contrastive-head weights, so those
#   keys are reported as missing and keep their fresh initialization.
# model_only=True: restore weights only, not the training state.
contra_wimae_trainer.load_checkpoint(best_checkpoint_path, model_only=True, strict=False)

print("ContraWiMAE trainer initialized")

# Get model information
contra_wimae_info = contra_wimae_trainer.model.get_model_info()
print("ContraWiMAE model created:")
for key, value in contra_wimae_info.items():
    print(f"  • {key}: {value}")


Setting up ContraWiMAE training...
  - contrastive_head.proj_head.0.weight
  - contrastive_head.proj_head.0.bias
  - contrastive_head.proj_head.2.weight
  - contrastive_head.proj_head.2.bias
Loaded model weights only (training state not restored)
ContraWiMAE trainer initialized
ContraWiMAE model created:
  • model_type: ContraWiMAE
  • patch_size: (1, 16)
  • encoder_dim: 64
  • encoder_layers: 2
  • encoder_nhead: 4
  • decoder_layers: 1
  • decoder_nhead: 4
  • mask_ratio: 0.6
  • total_parameters: 135568
  • trainable_parameters: 135568
  • contrastive_dim: 64
  • temperature: 0.1
  • snr_min: 0.0
  • snr_max: 30.0


In [11]:
# Train ContraWiMAE starting from the WiMAE encoder/decoder weights loaded above.
CONTRA_EPOCHS = 5
contra_wimae_trainer.config["training"]["epochs"] = CONTRA_EPOCHS
contra_wimae_trainer.train()
print("ContraWiMAE training completed successfully!")

Total samples: 18824, dimensions: 32x32
Loading file 1/12: city_0_newyork_channels.npz
Loading file 2/12: city_1_losangeles_channels.npz
Loading file 3/12: city_10_austin_channels.npz
Loading file 4/12: city_9_sanfrancisco_channels.npz
Loading file 5/12: city_6_miami_channels.npz
Loading file 6/12: city_13_columbus_channels.npz
Loading file 7/12: city_2_chicago_channels.npz
Loading file 8/12: city_5_philadelphia_channels.npz
Loading file 9/12: city_17_seattle_channels.npz
Loading file 10/12: city_3_houston_channels.npz
Loading file 11/12: city_8_dallas_channels.npz
Loading file 12/12: city_4_phoenix_channels.npz
Successfully loaded all 18824 samples
Computing statistics from training dataset...
Calculated statistics: {'real_mean': 0.0013652583584189415, 'real_std': 30.900094985961914, 'imag_mean': -0.0803627148270607, 'imag_std': 30.591354370117188}
Total samples: 18824, dimensions: 32x32
Loading file 1/12: city_0_newyork_channels.npz
Loading file 2/12: city_1_losangeles_channels.npz
L

Training Epoch 0: 100%|██████████| 4/4 [00:04<00:00,  1.21s/it, recon_loss=0.8723, contrastive_loss=6.8901, total_loss=1.4741, avg_total_loss=1.6252]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.35it/s]


Epoch 0:
  train_recon_loss: 1.0233
  train_contrastive_loss: 7.0419
  train_total_loss: 1.6252
  val_masked_recon_loss: 1.0880
  val_full_recon_loss: 1.2089
  val_contrastive_loss: 6.7856
  val_masked_loss: 1.6578
  val_full_loss: 1.7666
  val_loss: 1.6578


Training Epoch 1: 100%|██████████| 4/4 [00:05<00:00,  1.34s/it, recon_loss=1.0353, contrastive_loss=6.4769, total_loss=1.5794, avg_total_loss=1.5819]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.81it/s]


Epoch 1:
  train_recon_loss: 1.0226
  train_contrastive_loss: 6.6157
  train_total_loss: 1.5819
  val_masked_recon_loss: 1.0980
  val_full_recon_loss: 1.2098
  val_contrastive_loss: 6.3382
  val_masked_loss: 1.6220
  val_full_loss: 1.7227
  val_loss: 1.6220


Training Epoch 2: 100%|██████████| 4/4 [00:04<00:00,  1.24s/it, recon_loss=0.9224, contrastive_loss=6.2361, total_loss=1.4538, avg_total_loss=1.5478]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s]


Epoch 2:
  train_recon_loss: 1.0204
  train_contrastive_loss: 6.2942
  train_total_loss: 1.5478
  val_masked_recon_loss: 1.0902
  val_full_recon_loss: 1.2053
  val_contrastive_loss: 6.1315
  val_masked_loss: 1.5943
  val_full_loss: 1.6979
  val_loss: 1.5943


Training Epoch 3: 100%|██████████| 4/4 [00:04<00:00,  1.23s/it, recon_loss=1.0305, contrastive_loss=6.0666, total_loss=1.5341, avg_total_loss=1.5263]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.88it/s]


Epoch 3:
  train_recon_loss: 1.0149
  train_contrastive_loss: 6.1283
  train_total_loss: 1.5263
  val_masked_recon_loss: 1.0943
  val_full_recon_loss: 1.1994
  val_contrastive_loss: 5.9995
  val_masked_loss: 1.5848
  val_full_loss: 1.6794
  val_loss: 1.5848


Training Epoch 4: 100%|██████████| 4/4 [00:04<00:00,  1.19s/it, recon_loss=1.0315, contrastive_loss=5.9764, total_loss=1.5260, avg_total_loss=1.5172]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s]

Epoch 4:
  train_recon_loss: 1.0180
  train_contrastive_loss: 6.0101
  train_total_loss: 1.5172
  val_masked_recon_loss: 1.0965
  val_full_recon_loss: 1.1957
  val_contrastive_loss: 5.9016
  val_masked_loss: 1.5770
  val_full_loss: 1.6663
  val_loss: 1.5770
Training completed!
ContraWiMAE training completed successfully!



