# TorchModelWrapper Example

This notebook explains how to use ExECG's `TorchModelWrapper` with various examples.

## Contents
1. Basic Usage
2. Using Registered Models (afib_binary, potassium_regression)
3. Using Preprocess (input shape transformation for special-case models)
4. Using Postprocess (output shape transformation for special-case models)
5. Extracting Gradients and Layer Gradients

In [None]:
import sys

sys.path.append("../")

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from execg.models.wrapper import TorchModelWrapper

## 1. Basic Usage

### Input/Output Convention
- **Input**: `(1, n_leads, seq_length)` e.g., `(1, 12, 2500)`
- **Output**: `(1, N)` where N is:
  - Regression: N=1
  - Binary: N=2 (probabilities)
  - Multiclass/Multilabel: N=num_classes (probabilities)

In [None]:
# Define a simple test model
class SimpleBinaryModel(nn.Module):
    """Toy binary classifier that already follows the wrapper convention.

    Pipeline: Conv1d (12 -> 32 channels) -> ReLU -> global average pool
    -> Linear (32 -> 2) -> softmax. Maps ``(1, 12, length)`` input to
    class probabilities of shape ``(1, 2)``.
    """

    def __init__(self):
        super().__init__()
        # kernel_size=7 with padding=3 keeps the time dimension unchanged
        self.conv = nn.Conv1d(in_channels=12, out_channels=32, kernel_size=7, padding=3)
        # Collapse the time axis to a single value per channel
        self.pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.fc = nn.Linear(in_features=32, out_features=2)

    def forward(self, x):
        feature_maps = F.relu(self.conv(x))           # (1, 32, length)
        pooled = self.pool(feature_maps).squeeze(-1)  # (1, 32)
        logits = self.fc(pooled)                      # (1, 2)
        return F.softmax(logits, dim=-1)              # rows sum to 1


# Create and wrap model.
# Seed torch's RNG first so the randomly initialised weights below (and the
# random ECGs drawn in later cells, when run top to bottom) are reproducible.
SEED = 0
torch.manual_seed(SEED)

model = SimpleBinaryModel()
# No preprocess/postprocess needed: the model already maps
# (1, 12, length) -> (1, 2) probabilities.
wrapper = TorchModelWrapper(model)

print(f"Wrapper: {wrapper}")
print(f"Device: {wrapper.device}")

In [None]:
# Prediction test: a random 12-lead ECG with 2500 samples in the standard
# (1, lead, length) layout.
ecg = torch.randn(1, 12, 2500)
output = wrapper.predict(ecg)

for caption, value in (
    ("Input shape", ecg.shape),
    ("Output shape", output.shape),
    ("Output", output),
):
    print(f"{caption}: {value}")

## 2. Using Registered Models

Examples of loading pre-trained models from the sample registry (`samples.models.registry.MODEL_REGISTRY`) with `get_model`.

In [None]:
from execg.misc import get_model
from samples.models.registry import MODEL_REGISTRY

available_models = list(MODEL_REGISTRY.keys())
print(f"Available models: {available_models}")

### 2.1 AFib Binary Classification Model

In [None]:
# Load AFib classification model
afib_model = get_model(
    name="afib_binary",
    model_dir="../tmp/models/afib_binary/",
    registry=MODEL_REGISTRY,
    download=True,
)

# Wrap without any preprocess/postprocess (none is passed here).
# NOTE(review): the next cell treats the output as (1, 2) probabilities.
# If this model returns logits rather than probabilities, wrap it with
# `postprocess=lambda x: F.softmax(x, dim=-1)` instead. Confirm against the model.
afib_wrapper = TorchModelWrapper(afib_model)

print(f"Model: {afib_wrapper}")

In [None]:
# Prediction test on a random ECG
ecg_afib = torch.randn(1, 12, 2500)
pred = afib_wrapper.predict(ecg_afib)

print(f"Input shape: {ecg_afib.shape}")
print(f"Output shape: {pred.shape}")
print(f"Prediction (prob): {pred}")

# Most probable class index: 1 -> AFib, anything else -> Normal
pred_class = pred.argmax().item()
pred_label = "AFib" if pred_class == 1 else "Normal"
print(f"Predicted class: {pred_class} ({pred_label})")

### 2.2 Potassium Regression Model

In [None]:
# Load Potassium regression model from a keyword-argument spec
potassium_spec = dict(
    name="potassium_regression",
    model_dir="../tmp/models/potassium_regression/",
    registry=MODEL_REGISTRY,
    download=True,
)
potassium_model = get_model(**potassium_spec)

# No pre/postprocess: assumes the model already emits the (1, 1)
# regression output described above. TODO confirm.
potassium_wrapper = TorchModelWrapper(potassium_model)

print(f"Model: {potassium_wrapper}")

In [None]:
# Prediction test on a random ECG
ecg_k = torch.randn(1, 12, 2500)
pred_k = potassium_wrapper.predict(ecg_k)

print(f"Input shape: {ecg_k.shape}")
print(f"Output shape: {pred_k.shape}")

k_level = pred_k.item()  # the single regression value as a Python float
print(f"Predicted potassium level: {k_level:.2f} mEq/L")

## 3. Using Preprocess

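Use `preprocess` when the model expects a different input layout than the standard `(1, n_leads, seq_length)`. Callers still pass standard input, and `preprocess` reshapes it for the model.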

### 3.1 Case: Model expects (1, length, lead) instead of (1, lead, length)

In [None]:
class TransposedInputModel(nn.Module):
    """Binary classifier whose input is time-major: ``(1, length, 12)``.

    A single-layer LSTM reads the sequence of 12-lead samples. Its final
    hidden state goes through a Linear layer, and softmax yields ``(1, 2)``
    probabilities. Because this layout differs from the wrapper's
    ``(1, lead, length)`` convention, the model needs a ``preprocess`` transpose.
    """

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=12, hidden_size=64, batch_first=True)
        self.fc = nn.Linear(in_features=64, out_features=2)

    def forward(self, x):
        # x: (1, length, 12); only the final hidden state h_n is used
        _, (h_n, _) = self.lstm(x)  # h_n: (num_layers=1, batch, 64)
        last_hidden = h_n[-1]       # (batch, 64), same as squeezing the layer axis
        return F.softmax(self.fc(last_hidden), dim=-1)


# Create model
transposed_model = TransposedInputModel()


def to_length_major(x):
    """Reorder the standard (1, lead, length) input to (1, length, lead)."""
    return x.transpose(1, 2)


transposed_wrapper = TorchModelWrapper(transposed_model, preprocess=to_length_major)

# Test: feed the standard (1, 12, 2500) layout; preprocess adapts it for the LSTM
ecg = torch.randn(1, 12, 2500)
output = transposed_wrapper.predict(ecg)

print(f"Standard input shape: {ecg.shape}")
print(f"Output shape: {output.shape}")
print(f"Output: {output}")

## 4. Using Postprocess

Use `postprocess` to convert the model's raw output into the standard `(1, N)` format.

### 4.1 Case: Binary model that outputs a single logit

In [None]:
class SingleLogitModel(nn.Module):
    """Binary classifier that emits one raw logit of shape ``(1, 1)``.

    Conv1d -> ReLU -> global average pool -> Linear(32 -> 1), with no final
    activation. A postprocess is needed to reach the wrapper's ``(1, 2)``
    probability convention.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=12, out_channels=32, kernel_size=7, padding=3)
        self.pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.fc = nn.Linear(in_features=32, out_features=1)  # single logit

    def forward(self, x):
        hidden = F.relu(self.conv(x))
        hidden = self.pool(hidden).squeeze(-1)
        return self.fc(hidden)  # (1, 1) unnormalized logit


# Instantiate the single-logit model; its (1, 1) output is adapted by a postprocess.
single_logit_model = SingleLogitModel()


# Use postprocess to transform single logit -> (1, 2) binary probs
def logit_to_binary_probs(x):
    """Turn a single binary logit into two-class probabilities.

    For a ``(..., 1)`` logit tensor, returns ``(..., 2)`` holding
    ``[P(negative), P(positive)]`` with ``P(positive) = sigmoid(x)``,
    e.g. ``(1, 1) -> (1, 2)``.
    """
    positive = x.sigmoid()
    return torch.cat((1 - positive, positive), dim=-1)


single_logit_wrapper = TorchModelWrapper(
    single_logit_model,
    postprocess=logit_to_binary_probs,  # (1, 1) logit -> (1, 2) probabilities
)

# Test on a random ECG
ecg = torch.randn(1, 12, 2500)
output = single_logit_wrapper.predict(ecg)

print(f"Output shape: {output.shape}")
print(f"Output (binary probs): {output}")
prob_total = output.sum().item()
print(f"Sum of probs: {prob_total:.4f} (should be 1.0)")

### 4.2 Case: Model returns a (probabilities, auxiliary output) tuple

In [None]:
class MultilabelModel(nn.Module):
    """Multilabel classifier that returns a ``(probs, aux)`` tuple instead of a tensor.

    ``probs`` has shape ``(1, num_labels)`` and holds independent per-label
    sigmoid probabilities. ``aux`` is a dict (empty here) standing in for
    auxiliary outputs. Because the model returns a tuple, the wrapper needs a
    postprocess that selects ``probs``.

    Args:
        num_labels: number of independent labels (output width).
    """

    def __init__(self, num_labels=4):
        super().__init__()
        self.conv = nn.Conv1d(12, 32, kernel_size=7, padding=3)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(32, num_labels)

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = self.pool(x).squeeze(-1)
        # Sigmoid is applied here, so the first element is already
        # probabilities (not logits); the empty dict mimics extra outputs.
        return torch.sigmoid(self.fc(x)), {}  # ((1, num_labels) probs, aux dict)


labels = ["AFib", "LBBB", "RBBB", "PVC"]

# One output unit per label name, so the printout below lines up
multilabel_model = MultilabelModel(num_labels=len(labels))

# The model returns (probs, aux); postprocess keeps only probs, which the model
# has already passed through a sigmoid (independent probability per label).
multilabel_wrapper = TorchModelWrapper(multilabel_model, postprocess=lambda x: x[0])

# Test
ecg = torch.randn(1, 12, 2500)
output = multilabel_wrapper.predict(ecg)
# zip() below would silently drop labels or probs on a width mismatch
assert output.shape[-1] == len(labels), "label names must match model outputs"

print(f"Output shape: {output.shape}")
print(f"Output (multilabel probs): {output}")
print("\nPredicted labels:")
for label, prob in zip(labels, output[0]):
    # '*' marks labels whose probability exceeds 0.5
    print(f"  {label}: {prob.item():.3f} {'*' if prob > 0.5 else ''}")

## 5. Extracting Gradients and Layer Gradients

In [None]:
# Test with a simple model
class GradTestModel(nn.Module):
    """Two-layer convolutional binary classifier used to demo gradient extraction.

    Conv1d(12 -> 32) -> ReLU -> Conv1d(32 -> 64) -> ReLU -> global average
    pool -> Linear(64 -> 2) -> softmax. Maps ``(1, 12, length)`` to
    ``(1, 2)`` probabilities. ``conv2`` is the last convolution, a typical
    Grad-CAM target layer.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=12, out_channels=32, kernel_size=7, padding=3)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=7, padding=3)
        self.pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.fc = nn.Linear(in_features=64, out_features=2)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))  # padding=3 preserves the length
        pooled = self.pool(x).squeeze(-1)  # (1, 64)
        return F.softmax(self.fc(pooled), dim=-1)


# Wrap the gradient-test model; no preprocess/postprocess is passed because it
# already outputs (1, 2) softmax probabilities.
grad_model = GradTestModel()
grad_wrapper = TorchModelWrapper(grad_model)

### 5.1 get_layer_names()

In [None]:
layer_names = grad_wrapper.get_layer_names()
print("Available layers:")
for name in layer_names:
    # An empty name is shown as the root module
    display_name = name or "(root)"
    print(f"  - {display_name}")

### 5.2 get_gradients()

In [None]:
ecg = torch.randn(1, 12, 2500)

# Gradient for a specific class (target_class=1)
grads = grad_wrapper.get_gradients(ecg, target_class=1)

print(f"Input shape: {ecg.shape}")
print(f"Gradient shape: {grads.shape}")
grad_lo, grad_hi = grads.min(), grads.max()
print(f"Gradient range: [{grad_lo:.6f}, {grad_hi:.6f}]")

In [None]:
# If target_class=None, uses argmax class.
# Reuses the `ecg` tensor defined in the previous cell.
grads_auto = grad_wrapper.get_gradients(ecg, target_class=None)
print(f"Auto target gradient shape: {grads_auto.shape}")

### 5.3 get_layer_gradients() - for Grad-CAM

In [None]:
# Extract activations and gradients from conv2 (the model's last convolution,
# a typical Grad-CAM target layer) for class 1.
activations, gradients = grad_wrapper.get_layer_gradients(
    ecg, target_class=1, layer_name="conv2"
)

print(f"Activations shape: {activations.shape}")
print(f"Gradients shape: {gradients.shape}")
# Plain string: nothing to interpolate, so no f-prefix
print("\nThese can be used for Grad-CAM computation!")

### 5.4 Using output_idx

In [None]:
ecg = torch.randn(1, 12, 2500)

# Full output
full_output = grad_wrapper.predict(ecg)
print(f"Full output: {full_output}")

# Specific index only: one prediction per output_idx
output_0, output_1 = (grad_wrapper.predict(ecg, output_idx=idx) for idx in (0, 1))

for idx, value in enumerate((output_0, output_1)):
    print(f"Output[{idx}]: {value}")