# Swin Transformer

In [None]:
# pip install xflow-py
from xflow import ConfigManager, FileProvider, PyTorchPipeline, show_model_info
from xflow.data import build_transforms_from_config
from xflow.trainers import build_callbacks_from_config, build_callbacks_from_config
from xflow.utils import load_validated_config, plot_image

from TM import TransmissionMatrix

# ---- Configuration ----
# Load and validate SwinT.yaml, wrap it in a ConfigManager, and register the
# files listed under `extra_files` with the manager.
config_manager = ConfigManager(load_validated_config("SwinT.yaml"))
config = config_manager.get()
config_manager.add_files(config["extra_files"])

# ---- Data pipeline ----
# Two-stage split: full -> (train, temp) using `train_val_split`, then
# temp -> (val, test) using `val_test_split`. Both splits receive the same
# configured seed. `fraction=1` presumably keeps every sample (lower it for
# quick experiments) — confirm against FileProvider.subsample.
provider = FileProvider(config["paths"]["dataset"]).subsample(fraction=1)
data_cfg = config["data"]
split_seed = config["seed"]
train_provider, temp_provider = provider.split(
    ratio=data_cfg["train_val_split"], seed=split_seed
)
val_provider, test_provider = temp_provider.split(
    ratio=data_cfg["val_test_split"], seed=split_seed
)
transforms = build_transforms_from_config(data_cfg["transforms"]["torch"])

def make_dataset(provider):
    """Build an in-memory dataset for one data split.

    Wraps ``provider`` in a ``PyTorchPipeline`` with the module-level
    ``transforms`` and materialises it via ``to_memory_dataset`` using the
    ``config["data"]["dataset_ops"]`` settings.
    """
    pipeline = PyTorchPipeline(provider, transforms)
    dataset_ops = config["data"]["dataset_ops"]
    return pipeline.to_memory_dataset(dataset_ops)

# Materialise every split, in train/val/test order.
train_dataset, val_dataset, test_dataset = (
    make_dataset(split) for split in (train_provider, val_provider, test_provider)
)

# Report raw sample counts per split, then dataset lengths (the training loop
# treats len(dataset) as the number of batches).
print("Samples: ", *(len(p) for p in (train_provider, val_provider, test_provider)))
print("Batch: ", *(len(d) for d in (train_dataset, val_dataset, test_dataset)))

# Preview the first batch of the test split. Each item is unpacked as
# (inputs, targets), the same convention the training loop uses.
# NOTE(review): confirm against the split_width transform which image half is
# the model input and which is the target.
preview_batch = next(iter(test_dataset), None)  # None if the split is empty
if preview_batch is not None:
    preview_inputs, preview_targets = preview_batch
    print(f"Batch shapes: {preview_inputs.shape}, {preview_targets.shape}")
    plot_image(preview_inputs[0])
    plot_image(preview_targets[0])


In [None]:
# ====================================
# Model building
# ====================================
import torch
from SwinT import SwinV2AutoEncoder


# ----------------------------
# Explicit model construction
# ----------------------------
# Every SwinV2AutoEncoder keyword argument is taken verbatim from the config key
# of the same name under `model`; a missing key raises KeyError before the
# model is built.
model_config = config["model"]
_SWIN_KWARG_KEYS = (
    "img_size",
    "in_chans",
    "out_chans",
    "patch_size",
    "embed_dim",
    "depths",
    "num_heads",
    "window_size",
    "mlp_ratio",
    "qkv_bias",
    "drop_path",
    "decoder_channels",
    "final_activation",
)
model = SwinV2AutoEncoder(**{key: model_config[key] for key in _SWIN_KWARG_KEYS})

show_model_info(model)


# ====================================
# Training loop prepare
# ====================================
import xflow.extensions.physics

# Loss: pixel-wise L1 to start with. torch.nn.MSELoss() (pixel-level MSE) is
# the alternative that was considered.
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config['training']['learning_rate'],
)

# Callbacks come from the config. The LAST callback is handed the test split as
# its dataset closure, so the order of entries in config["callbacks"] matters.
callback_cfg = config["callbacks"]
callbacks = build_callbacks_from_config(config=callback_cfg, framework=config["framework"])
last_callback = callbacks[-1]
last_callback.set_dataset(test_dataset)

## Training

In [4]:
# debugging
import os
import matplotlib.pyplot as plt

def save_feature_map(features, epoch, output_dir):
    """Save ``features[0, 0]`` as a viridis-coloured PNG for debugging.

    The image is written to ``<output_dir>/epoch_<epoch>_featuremap.png``;
    ``output_dir`` is created first if it does not exist. ``features[0, 0]``
    must yield a 2-D array for ``plt.imsave``. Presumably ``features`` is a
    [B, C, H, W] tensor, so this is channel 0 of the first sample (TODO confirm
    against the model's bottleneck output).
    """
    os.makedirs(output_dir, exist_ok=True)
    target_path = os.path.join(output_dir, f'epoch_{epoch}_featuremap.png')
    # Drop the autograd link and move to host memory before handing to matplotlib.
    channel_map = features[0, 0].detach().cpu().numpy()
    plt.imsave(target_path, channel_map, cmap='viridis')

In [None]:



# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

for cb in callbacks:
    cb.on_train_begin(epochs=config['training']['epochs'])  # Pass total epochs

for epoch in range(config['training']['epochs']):
    for cb in callbacks:
        cb.on_epoch_begin(epoch, model=model, total_batches=len(train_dataset))

    model.train()
    train_loss_sum, n_train = 0.0, 0

    for batch_idx, batch in enumerate(train_dataset):
        for cb in callbacks:
            cb.on_batch_begin(batch_idx)

        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        # ----------------- debugging -----------------
        # Run only on the first batch of each epoch. The filename depends only
        # on `epoch`, so doing this on every batch just overwrote the same PNG
        # and doubled the forward passes. no_grad() avoids building an unused
        # autograd graph for the extra forward.
        if batch_idx == 0:
            with torch.no_grad():
                # comment from original author: shape [B, 768, 8, 8] — TODO confirm
                features = model.get_bottleneck_features(inputs)
            save_feature_map(features, epoch, config["paths"]["output"])
            # Inspect decoder weights statistics (once per epoch, before the step)
            model.print_decoder_weights_stats()
        # --------------------------------------------

        optimizer.zero_grad()                      # 1. clear gradients
        outputs = model(inputs)                    # 2. forward pass
        loss = criterion(outputs, targets)         # 3. compute loss
        loss.backward()                            # 4. backprop
        optimizer.step()                           # 5. update weights

        # Convert to a Python float once (a single device sync) and reuse it
        # for callback logging and the epoch average.
        batch_loss = loss.item()
        for cb in callbacks:
            cb.on_batch_end(batch_idx, logs={"loss": batch_loss})

        # accumulate for epoch average
        train_loss_sum += batch_loss
        n_train += 1

    avg_train_loss = train_loss_sum / max(1, n_train)

    # ----- validation pass (no grad) -----
    model.eval()
    val_loss_sum, n_val = 0.0, 0
    with torch.no_grad():
        for val_inputs, val_targets in val_dataset:       # assume val_dataset is ready
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
            val_outputs = model(val_inputs)
            val_loss_sum += criterion(val_outputs, val_targets).item()
            n_val += 1
    # Mean of per-batch losses (not weighted by batch size).
    avg_val_loss = val_loss_sum / max(1, n_val)

    # Add epoch-level metrics if available
    for cb in callbacks:
        cb.on_epoch_end(epoch, logs={
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss
        })

# Notify every callback once that training has finished.
for callback in callbacks:
    callback.on_train_end()

print("Training complete.")

# Persist the trained weights (state_dict only) to <output>/swin_model.pth and
# let the ConfigManager save its contents alongside them. Presumably this
# includes the config and the extra files registered via add_files; confirm.
import os
output_dir = config["paths"]["output"]
os.makedirs(output_dir, exist_ok=True)
checkpoint_path = os.path.join(output_dir, 'swin_model.pth')
torch.save(model.state_dict(), checkpoint_path)

config_manager.save(output_dir=output_dir)

## Evaluation

In [None]:
import os

from SwinT import build_model
import torch
from xflow.utils import plot_image

# Rebuild model and load weights. The checkpoint path is derived from
# config["paths"]["output"] exactly as in the save step. A hard-coded
# 'results/' path silently diverges (missing or stale file) when the
# configured output directory changes.
# NOTE(review): training builds SwinV2AutoEncoder explicitly; confirm that
# build_model(config['model']) yields the same architecture, otherwise
# load_state_dict will fail on mismatched keys.
checkpoint_path = os.path.join(config["paths"]["output"], 'swin_model.pth')
model = build_model(config['model'])
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)
model.eval()

# Test reconstruction: accumulate squared error over the WHOLE test split, not
# just the first batch, so the reported MSE is the element-wise mean over all
# test data (batch-size weighted, same per-element semantics as MSELoss).
# Keep the first batch's first 4 samples for visualisation.
sq_err_sum, n_elems = 0.0, 0
preview = None
with torch.no_grad():
    for inputs, targets in test_dataset:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)

        sq_err = torch.nn.functional.mse_loss(outputs, targets, reduction="none")
        sq_err_sum += sq_err.sum().item()
        n_elems += sq_err.numel()

        if preview is None:
            preview = (inputs[:4].cpu(), outputs[:4].cpu(), targets[:4].cpu())

if n_elems:
    print(f"Test MSE Loss: {sq_err_sum / n_elems:.6f}")

# Plot some examples: input images, reconstructed images, target images.
if preview is not None:
    for images in preview:
        plot_image(images)