## Training Demonstration for WiMAE and ContraWiMAE

- This notebook demonstrates how to train both WiMAE and ContraWiMAE models.

#### Imports and Setup

In [1]:
import sys
import torch
import numpy as np
import yaml
import pprint
from pathlib import Path

# For Jupyter notebooks
sys.path.append(str(Path().cwd().parent))

# WiMAE imports
from contrawimae.training.train_wimae import WiMAETrainer
from contrawimae.training.train_contramae import ContraWiMAETrainer

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

PyTorch version: 2.5.1+cu121
CUDA available: True


#### Data Overview

In [2]:
data_path = "../data/pretrain"
npz_files = list(Path(data_path).glob("*.npz"))

print(f"Available datasets: {len(npz_files)} cities")
for file in sorted(npz_files):
    file_size = file.stat().st_size / (1024*1024)  # MB
    print(f"  • {file.name}: {file_size:.1f} MB")

# Load one file to check data structure
with np.load(npz_files[0], allow_pickle=True) as sample_data:
    print(f"\nSample data structure from {npz_files[0].name}:")
    for key, value in sample_data.items():
        if key == "channels":
            print(f"  • {key}: {value.shape} ({value.dtype})")
        else:
            print(f"  • {key}: {value}")

Available datasets: 5 cities
  • boston5g_3p5_bs000.npz: 22.1 MB
  • city_44_lisboa_3p5_bs000.npz: 66.6 MB
  • city_50_edinburgh_3p5_bs000.npz: 63.8 MB
  • city_5_philadelphia_3p5_bs001.npz: 60.7 MB
  • city_66_bruxelles_3p5_bs000.npz: 42.8 MB

Sample data structure from city_44_lisboa_3p5_bs000.npz:
  • channels: (9283, 1, 32, 32) (complex64)
  • rx_pos: [[ -11.6456 -112.351     1.5   ]
 [ -11.6456 -111.351     1.5   ]
 [ -11.6456 -110.351     1.5   ]
 ...
 [ -18.6456  131.649     1.5   ]
 [ -19.6456  132.649     1.5   ]
 [ -20.6456  133.649     1.5   ]]
  • tx_pos: [[ 6.29055e-03 -8.95592e-02  1.00000e+01]]
  • los: [0 0 0 ... 0 0 0]
  • scenario_name: city_44_lisboa_3p5
  • bs_index: 0
  • scaling_factor: 1000000.0
  • active_mask_original_indices: [14267 14529 14791 ... 78188 78449 78710]
  • total_users_original: 87246
  • active_users_count: 9283
  • ch_params_info: {'subcarriers': 32, 'bandwidth': 960000, 'num_paths': 20, 'bs_antenna_shape': array([32,  1]), 'bs_antenna_rotation

#### Configuration Setup

In [3]:
config_path = "../configs/default_training.yaml"

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Adjust config for demo (shorter training)
config['model']['encoder_layers'] = 2
config['model']['encoder_nhead'] = 4
config['model']['decoder_layers'] = 1
config['model']['decoder_nhead'] = 2
config['model']['mask_ratio'] = 0.9

config['training']['epochs'] = 1
config['training']['batch_size'] = 64
config['training']['device'] = "cpu"

config['data']['data_dir'] = data_path

config['logging']['exp_name'] = "demo_experiment"

pprint.pprint(config, sort_dicts=False)

# where to save checkpoints
checkpoint_path = Path(f"./{config['logging']['log_dir']}") / f"{config['model']['type']}_{config['logging']['exp_name']}"


{'model': {'type': 'wimae',
           'patch_size': [16, 1],
           'encoder_dim': 64,
           'encoder_layers': 2,
           'encoder_nhead': 4,
           'decoder_layers': 1,
           'decoder_nhead': 2,
           'mask_ratio': 0.9,
           'contrastive_dim': 64,
           'temperature': 0.2,
           'snr_min': 5.0,
           'snr_max': 40.0},
 'data': {'data_dir': '../data/pretrain',
          'normalize': True,
          'val_split': 0.1,
          'debug_size': None,
          'calculate_statistics': True,
          'statistics': {'real_mean': 0.021121172234416008,
                         'real_std': 30.7452392578125,
                         'imag_mean': -0.01027622725814581,
                         'imag_std': 30.70543670654297}},
 'training': {'batch_size': 64,
              'epochs': 1,
              'num_workers': 4,
              'device': 'cpu',
              'optimizer': {'type': 'adam',
                            'lr': 0.0003,
                     

#### WiMAE Training Setup

In [4]:
print("Setting up WiMAE training...")

# Create WiMAE trainer (model will be created during initialization)
wimae_trainer = WiMAETrainer(config=config)

print("WiMAE trainer initialized")

# Get model information
wimae_info = wimae_trainer.model.get_model_info()
print(f"WiMAE model created:")
for key, value in wimae_info.items():
    print(f"  • {key}: {value}")


Setting up WiMAE training...
WiMAE trainer initialized
WiMAE model created:
  • model_type: WiMAE
  • patch_size: (16, 1)
  • encoder_dim: 64
  • encoder_layers: 2
  • encoder_nhead: 4
  • decoder_layers: 1
  • decoder_nhead: 2
  • mask_ratio: 0.9
  • total_parameters: 118992
  • trainable_parameters: 118992


####  WiMAE Training Execution

In [5]:
print("Starting WiMAE training...")

try:
    # Start training (dataloaders will be set up automatically)
    wimae_trainer.train()
    print("WiMAE training completed successfully!")
    
except Exception as e:
    print(f"Training failed: {e}")

Starting WiMAE training...


INFO - Computing statistics from training dataset...
INFO - Calculated statistics: {'real_mean': -0.06437671929597855, 'real_std': 26.93267822265625, 'imag_mean': -0.012862714007496834, 'imag_std': 26.949743270874023}
INFO - Train samples: 32064
INFO - Validation samples: 3562
INFO - Starting training for 1 epochs...
INFO - Model: wimae
INFO - Device: cpu
INFO - Log directory: runs/wimae_demo_experiment
Training Epoch 0: 100%|██████████| 501/501 [00:12<00:00, 38.93it/s, loss=0.2823, avg_loss=0.8797]
Validation: 100%|██████████| 56/56 [00:00<00:00, 187.02it/s]
INFO - Epoch 0:
INFO -   train_loss: 0.8797
INFO -   val_loss: 0.8189
INFO - Training completed!


WiMAE training completed successfully!


#### Checkpoints and Model Loading

In [6]:
print("Checkpoint Management")
print("=" * 30)

# Check available checkpoints
wimae_checkpoints = list(checkpoint_path.glob("*.pt")) if checkpoint_path.exists() else []

print(f"WiMAE checkpoints: {len(wimae_checkpoints)}")
for ckpt in wimae_checkpoints:
    print(f"  • {ckpt.name}")

Checkpoint Management
WiMAE checkpoints: 2
  • last_checkpoint.pt
  • best_checkpoint.pt


In [7]:
# Example of loading a checkpoint into WiMAE
best_checkpoint_path = checkpoint_path / "best_checkpoint.pt"
wimae_trainer.load_checkpoint(best_checkpoint_path)


# continue training
print("Continuing training...")

try:
    wimae_trainer.config["training"]["epochs"] = 1
    wimae_trainer.train()
    print("WiMAE training completed successfully!")
    
except Exception as e:
    print(f"Training failed: {e}")()

  checkpoint = torch.load(checkpoint_path, map_location=self.device)
INFO - Loaded full training state from epoch 0


Continuing training...


INFO - Computing statistics from training dataset...
INFO - Calculated statistics: {'real_mean': -0.06437671929597855, 'real_std': 26.93267822265625, 'imag_mean': -0.012862714007496834, 'imag_std': 26.949743270874023}
INFO - Train samples: 32064
INFO - Validation samples: 3562
INFO - Starting training for 1 epochs...
INFO - Model: wimae
INFO - Device: cpu
INFO - Log directory: runs/wimae_demo_experiment
Training Epoch 0: 100%|██████████| 501/501 [00:13<00:00, 35.82it/s, loss=0.4818, avg_loss=0.8368]
Validation: 100%|██████████| 56/56 [00:00<00:00, 142.48it/s]
INFO - Epoch 0:
INFO -   train_loss: 0.8368
INFO -   val_loss: 0.8116
INFO - Training completed!


WiMAE training completed successfully!


#### ContraWiMAE Trainer Setup

In [8]:
print("Setting up ContraWiMAE training...")

# adjust config for contra wimae
config['model']['type'] = "contrawimae"

# Create WiMAE trainer (model will be created during initialization)
contra_wimae_trainer = ContraWiMAETrainer(config=config)

# load wimae encoder and decoder weights
# strict=False because we are not loading the contrastive head
# model_only=True because we are not loading the training state
contra_wimae_trainer.load_checkpoint(best_checkpoint_path, model_only=True, strict=False)

print("ContraWiMAE trainer initialized")

# Get model information
contra_wimae_info = contra_wimae_trainer.model.get_model_info()
print(f"ContraWiMAE model created:")
for key, value in contra_wimae_info.items():
    print(f"  • {key}: {value}")


INFO - Loaded model weights only (training state not restored)


Setting up ContraWiMAE training...
ContraWiMAE trainer initialized
ContraWiMAE model created:
  • model_type: ContraWiMAE
  • patch_size: (16, 1)
  • encoder_dim: 64
  • encoder_layers: 2
  • encoder_nhead: 4
  • decoder_layers: 1
  • decoder_nhead: 2
  • mask_ratio: 0.9
  • total_parameters: 135568
  • trainable_parameters: 135568
  • contrastive_dim: 64
  • temperature: 0.2
  • snr_min: 5.0
  • snr_max: 40.0


In [9]:
# training contra wimae with wimae encoder and decoder weights
contra_wimae_trainer.config["training"]["epochs"] = 1
contra_wimae_trainer.train()
print("ContraWiMAE training completed successfully!")

INFO - Computing statistics from training dataset...
INFO - Calculated statistics: {'real_mean': -0.06437671929597855, 'real_std': 26.93267822265625, 'imag_mean': -0.012862714007496834, 'imag_std': 26.949743270874023}
INFO - Train samples: 32064
INFO - Validation samples: 3562
INFO - Starting training for 1 epochs...
INFO - Model: contrawimae
INFO - Device: cpu
INFO - Log directory: runs/contrawimae_demo_experiment
Training Epoch 0: 100%|██████████| 501/501 [00:23<00:00, 21.14it/s, recon_loss=0.7048, contrastive_loss=2.3269, total_loss=0.8670, avg_total_loss=1.0037]
Validation: 100%|██████████| 56/56 [00:00<00:00, 75.00it/s]
INFO - Epoch 0:
INFO -   train_recon_loss: 0.8320
INFO -   train_contrastive_loss: 2.5486
INFO -   train_total_loss: 1.0037
INFO -   val_recon_loss: 0.8120
INFO -   val_contrastive_loss: 2.3479
INFO -   val_loss: 0.9656
INFO - Training completed!


ContraWiMAE training completed successfully!
