# 🧠 Refactored DuoFormer: Multi-Scale Vision Transformer for Medical Imaging

This notebook demonstrates the refactored DuoFormer architecture, a multi-scale vision transformer designed for **general medical image classification**. DuoFormer uses a convolutional neural network (CNN) backbone for feature extraction and applies multi-scale attention to the resulting feature maps, targeting a range of medical imaging modalities.

## 📊 Key Features

- **Multi-Scale Attention**: Processes features at multiple resolutions simultaneously
- **Hybrid Architecture**: Combines a ResNet backbone with a Vision Transformer
- **General Medical Imaging**: Works with histopathology, radiology, dermatology, ophthalmology, and more
- **Flexible Configuration**: Supports various backbone architectures and scales
- **Platform-Agnostic**: Auto-detects hardware (CUDA/MPS/CPU) and optimizes settings

> **Original Work**: [xiaoyatang/duoformer_TCGA](https://github.com/xiaoyatang/duoformer_TCGA) | **Paper**: [arXiv:2506.12982](https://arxiv.org/abs/2506.12982)

---


## 1️⃣ Setup and Installation

First, let's install the required dependencies and set up the environment.

### 📦 Dependency Management

This project uses `pip-tools` for reproducible dependency management:

- `requirements.in`: Lists direct dependencies with flexible version constraints
- `requirements.txt`: Auto-generated lockfile with pinned versions and hashes
- `setup_environment.py`: Automated setup script that handles the entire process

**Benefits:**

- ✅ Reproducible environments across different machines
- ✅ Automatic conflict resolution
- ✅ Security through dependency hashing
- ✅ Easy updates while maintaining compatibility
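
As a rough illustration of the final "validate the environment" step, the sketch below checks which distributions are installed using only the standard library. The helper name `installed_versions` and the package list are assumptions for illustration; `setup_environment.py` may perform its validation differently.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Hypothetical list mirroring some of the notebook's direct dependencies.
required = ["torch", "torchvision", "timm", "einops", "numpy"]
report = installed_versions(required)
missing = [name for name, version in report.items() if version is None]
print("Missing packages:", missing or "none")
```

Running a check like this before the heavier cells makes a missing dependency visible early, before an `ImportError` appears halfway through the notebook.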


In [None]:
# 🔧 Install required packages using automated setup
# This script handles pip-tools, requirements compilation, and installation
# Run this cell once at the start of your session

import sys
import subprocess
from pathlib import Path

# Check if setup_environment.py exists
setup_script = Path("setup_environment.py")

if setup_script.exists():
    print("🚀 Running automated environment setup...")
    print("This will:")
    print("  1. Upgrade pip and install pip-tools")
    print("  2. Compile requirements.in to requirements.txt")
    print("  3. Install all dependencies with conflict resolution")
    print("  4. Validate the environment")
    print("\n" + "=" * 70 + "\n")

    # Run the setup script
    result = subprocess.run(
        [sys.executable, str(setup_script), "--verbose"], capture_output=False
    )

    if result.returncode == 0:
        print("\n✅ Environment setup completed successfully!")
    else:
        print("\n❌ Setup encountered an error. Please check the output above.")
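else:
    print("⚠️  setup_environment.py not found!")
    print("Installing dependencies directly...")
    # scipy and seaborn are imported later in this notebook, so install them too
    fallback_packages = [
        "torch",
        "torchvision",
        "timm",
        "einops",
        "matplotlib",
        "numpy",
        "pillow",
        "tqdm",
        "scikit-learn",
        "scipy",
        "seaborn",
    ]
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", *fallback_packages]
    )
    if result.returncode != 0:
        print("\n❌ pip install failed. Please check the output above.")

# Alternative: manual installation (uncomment if needed)
# !pip install torch torchvision timm einops matplotlib numpy pillow tqdm scikit-learn scipy seaborn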

In [None]:
# Import necessary libraries
import os
import sys
import warnings
from pathlib import Path
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm.notebook import tqdm

warnings.filterwarnings("ignore")


# Set up paths (platform-agnostic)
project_root = Path.cwd()
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))


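# Configure device: prefer CUDA, then Apple MPS, then CPU
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"🔧 Using device: {device}")

if device.type == "cuda":
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"   CUDA version: {torch.version.cuda}")
elif device.type == "mps":
    print("   Running on Apple Silicon GPU (MPS)")
else:
    print("   Running on CPU (no CUDA or MPS device available)")
    print("   Note: Training will be slower without a GPU")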

## 2️⃣ Import DuoFormer Models

Let's import the DuoFormer model builders and the parameter-counting helper from the project's `models` package.


In [None]:
# Import DuoFormer components
from models import (
    build_model,
    build_model_no_extra_params,
    build_hybrid,
    count_parameters,
)

print("✅ DuoFormer modules imported successfully!")

## 3️⃣ Create Sample Medical Image Dataset

For demonstration purposes, we'll create a synthetic medical image dataset. In practice, you would replace this with your actual medical imaging data (e.g., X-rays, CT scans, MRI images).


In [None]:
class MedicalImageDataset(Dataset):
    """
    Synthetic medical image dataset for demonstration.
    In practice, replace this with actual medical imaging data.
    """

    def __init__(
        self,
        num_samples: int = 1000,
        num_classes: int = 10,
        image_size: int = 224,
        transform: Optional[transforms.Compose] = None,
    ):
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.image_size = image_size
        self.transform = transform or self._default_transform()

        # Generate synthetic data (replace with actual data loading)
        np.random.seed(42)
        self.images = np.random.randn(num_samples, 3, image_size, image_size).astype(
            np.float32
        )
        self.labels = np.random.randint(0, num_classes, num_samples)

        # Simulate different medical image patterns
        for i in range(num_samples):
            pattern_type = self.labels[i] % 4
            if pattern_type == 0:  # Circular patterns (tumors)
                self._add_circular_pattern(i)
            elif pattern_type == 1:  # Linear patterns (fractures)
                self._add_linear_pattern(i)
            elif pattern_type == 2:  # Diffuse patterns (inflammation)
                self._add_diffuse_pattern(i)
            else:  # Complex patterns
                self._add_complex_pattern(i)

    def _default_transform(self):
        """Default augmentation pipeline for medical images."""
        return transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def _add_circular_pattern(self, idx: int):
        """Simulate tumor-like circular patterns."""
        # Keep centers away from the border; equals (50, 174) for 224x224 images
        margin = min(50, self.image_size // 4)
        center = np.random.randint(margin, self.image_size - margin, 2)
        radius = np.random.randint(10, 30)
        y, x = np.ogrid[: self.image_size, : self.image_size]
        mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2
        self.images[idx, 0, mask] += np.random.uniform(0.5, 1.5)

    def _add_linear_pattern(self, idx: int):
        """Simulate fracture-like linear patterns."""
        start = np.random.randint(0, self.image_size, 2)
        end = np.random.randint(0, self.image_size, 2)
        rr = np.linspace(start[0], end[0], 100).astype(int)
        cc = np.linspace(start[1], end[1], 100).astype(int)
        valid = (rr >= 0) & (rr < self.image_size) & (cc >= 0) & (cc < self.image_size)
        self.images[idx, 1, rr[valid], cc[valid]] += np.random.uniform(0.5, 1.5)

    def _add_diffuse_pattern(self, idx: int):
        """Simulate inflammation-like diffuse patterns."""
        noise = np.random.randn(self.image_size, self.image_size) * 0.3
        from scipy.ndimage import gaussian_filter

        smoothed = gaussian_filter(noise, sigma=5)
        self.images[idx, 2] += smoothed

    def _add_complex_pattern(self, idx: int):
        """Simulate complex medical patterns."""
        self._add_circular_pattern(idx)
        self._add_linear_pattern(idx)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int):
        image = self.images[idx].transpose(1, 2, 0)  # CHW -> HWC for PIL
        # Min-max scale to uint8; the floor on the range guards against a constant image
        value_range = max(image.max() - image.min(), 1e-8)
        image = ((image - image.min()) / value_range * 255).astype(np.uint8)

        # self.transform is always set (it falls back to the default pipeline)
        image = self.transform(image)

        label = int(self.labels[idx])
        return image, label


print("✅ Medical image dataset class created!")
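
`__getitem__` above min-max scales each image to the 0 to 255 range before handing it to PIL. The sketch below shows the same arithmetic in pure Python, with an explicit guard for a constant image whose value range is zero. The helper name `to_uint8` is hypothetical and not part of the notebook.

```python
def to_uint8(values):
    """Min-max scale a flat list of floats to integers in [0, 255]."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        # Constant image: avoid dividing by zero and return all zeros.
        return [0] * len(values)
    # int() truncates toward zero, as astype(np.uint8) does for in-range values.
    return [int((v - lo) / span * 255) for v in values]


print(to_uint8([-1.0, 0.0, 1.0]))  # [0, 127, 255]
print(to_uint8([0.3, 0.3, 0.3]))   # [0, 0, 0]
```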

In [None]:
# Create dataset instances
full_dataset = MedicalImageDataset(
    num_samples=500, num_classes=10, image_size=224  # Reduced for demo
)

# Split into train, validation, and test sets
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size]
)

# Create data loaders
# num_workers=0 keeps loading in the main process; worker processes may fail to
# pickle a Dataset class defined inside a notebook on platforms that spawn workers
batch_size = 16
num_workers = 0
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)

print("📊 Dataset Statistics:")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Validation samples: {len(val_dataset)}")
print(f"   Test samples: {len(test_dataset)}")
print(f"   Batch size: {batch_size}")
print(f"   Number of classes: {full_dataset.num_classes}")
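
The split sizes above come from truncating the 70% and 15% shares and letting the test set absorb whatever is left, so the three sizes always sum to the dataset length. A small sketch of that arithmetic (the function name `split_sizes` is an illustrative assumption):

```python
def split_sizes(n_samples, train_frac=0.7, val_frac=0.15):
    """Return (train, val, test) sizes; the test split absorbs rounding leftovers."""
    train = int(train_frac * n_samples)
    val = int(val_frac * n_samples)
    test = n_samples - train - val
    return train, val, test


print(split_sizes(500))  # (350, 75, 75)
print(split_sizes(103))  # (72, 15, 16)
```

For a reproducible partition, `random_split` also accepts a `generator` argument, for example `torch.Generator().manual_seed(42)`.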

## 4️⃣ Visualize Sample Medical Images

Let's visualize some sample images from our dataset to understand the data better.


In [None]:
def visualize_samples(dataset, num_samples: int = 6):
    """Visualize sample medical images from the dataset."""
    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    axes = axes.ravel()

    for i in range(min(num_samples, len(dataset))):
        image, label = dataset[i]

        # Undo the ImageNet normalization applied by the dataset transform
        if isinstance(image, torch.Tensor):
            image = image.numpy()

        if image.shape[0] == 3:  # CHW format
            image = image.transpose(1, 2, 0)

        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = np.clip(image * std + mean, 0, 1)

        axes[i].imshow(image)
        axes[i].set_title(f"Class: {label}", fontsize=10)
        axes[i].axis("off")

    plt.suptitle("Sample Medical Images", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.show()


visualize_samples(full_dataset)

## 5️⃣ Initialize DuoFormer Model

Now let's create and configure the DuoFormer model with different settings.


In [None]:
# Model configuration
config = {
    "depth": 12,  # Number of transformer blocks
    "embed_dim": 384,  # Embedding dimension
    "num_heads": 6,  # Number of attention heads
    "num_classes": 10,  # Number of output classes
    "num_layers": 2,  # Number of scales (2, 3, or 4)
    "num_patches": 49,  # Number of patches (7x7)
    "proj_dim": 384,  # Projection dimension
    "mlp_ratio": 4.0,  # MLP expansion ratio
    "attn_drop_rate": 0.1,  # Attention dropout rate
    "proj_drop_rate": 0.1,  # Projection dropout rate
    "freeze_backbone": False,  # Whether to freeze ResNet backbone
    "backbone": "r50",  # Backbone architecture ('r50' or 'r18')
    "pretrained": True,  # Use pretrained backbone weights
}

# Create model
model = build_model_no_extra_params(**config)
model = model.to(device)

# Count parameters
trainable_params, total_params = count_parameters(model)

print("🔍 Model Architecture Summary:")
print(f"   Model type: DuoFormer")
print(f"   Backbone: ResNet-{config['backbone'][1:]}")
print(f"   Number of scales: {config['num_layers']}")
print(f"   Embedding dimension: {config['embed_dim']}")
print(f"   Number of heads: {config['num_heads']}")
print(f"   Trainable parameters: {trainable_params:.2f}M")
print(f"   Total parameters: {total_params:.2f}M")
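
Several values in `config` above are linked by simple arithmetic. The sketch below derives the per-head width and the token grid, assuming standard multi-head attention (the embedding is split evenly across heads) and a coarsest feature map taken from a stride-32 ResNet stage. The stride is typical for ResNet-18/50 but is an assumption here, not something confirmed from the DuoFormer source, and `derived_shapes` is a hypothetical helper.

```python
def derived_shapes(embed_dim, num_heads, image_size=224, backbone_stride=32):
    """Compute the per-head width and token-grid size implied by a configuration."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    if image_size % backbone_stride != 0:
        raise ValueError("image_size must be divisible by backbone_stride")
    head_dim = embed_dim // num_heads
    grid = image_size // backbone_stride
    return {"head_dim": head_dim, "grid": grid, "num_patches": grid * grid}


print(derived_shapes(embed_dim=384, num_heads=6))
# {'head_dim': 64, 'grid': 7, 'num_patches': 49}
```

Under these assumptions, changing the input resolution would also change `num_patches`, so the two should be updated together.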

## 6️⃣ Training Pipeline

Let's implement a training pipeline with proper metrics tracking and visualization.


In [None]:
class Trainer:
    """Training pipeline for DuoFormer."""

    def __init__(
        self,
        model: nn.Module,
        device: torch.device,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-4,
    ):
        self.model = model
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(
            model.parameters(), lr=learning_rate, weight_decay=weight_decay
        )
        self.history = {
            "train_loss": [],
            "train_acc": [],
            "val_loss": [],
            "val_acc": [],
        }

    def train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
        """Train for one epoch."""
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(train_loader, desc="Training", leave=False)
        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(self.device), target.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Metrics
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            # Update progress bar
            progress_bar.set_postfix(
                {"loss": f"{loss.item():.4f}", "acc": f"{100.*correct/total:.2f}%"}
            )

        avg_loss = total_loss / len(train_loader)
        accuracy = 100.0 * correct / total
        return avg_loss, accuracy

    def validate(self, val_loader: DataLoader) -> Tuple[float, float]:
        """Validate the model."""
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc="Validation", leave=False)
            for data, target in progress_bar:
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                progress_bar.set_postfix(
                    {"loss": f"{loss.item():.4f}", "acc": f"{100.*correct/total:.2f}%"}
                )

        avg_loss = total_loss / len(val_loader)
        accuracy = 100.0 * correct / total
        return avg_loss, accuracy

    def fit(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 10):
        """Train the model for multiple epochs."""
        best_val_acc = 0.0

        for epoch in range(epochs):
            print(f"\n📅 Epoch {epoch+1}/{epochs}")

            # Training
            train_loss, train_acc = self.train_epoch(train_loader)
            self.history["train_loss"].append(train_loss)
            self.history["train_acc"].append(train_acc)

            # Validation
            val_loss, val_acc = self.validate(val_loader)
            self.history["val_loss"].append(val_loss)
            self.history["val_acc"].append(val_acc)

            # Print metrics
            print(f"   Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"   Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": self.model.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                        "val_acc": val_acc,
                    },
                    "best_duoformer_model.pt",
                )
                print(f"   ✅ Best model saved! (Val Acc: {val_acc:.2f}%)")

        print(f"\n🎉 Training complete! Best validation accuracy: {best_val_acc:.2f}%")
        return self.history


print("✅ Trainer class created!")
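
`train_epoch` and `validate` report `total_loss / len(loader)`, which averages the per-batch mean losses. When the last batch is smaller than the rest, that figure differs slightly from the true per-sample mean. The sketch below contrasts the two with made-up numbers; both function names are illustrative.

```python
def mean_of_batch_means(batches):
    """Average per-batch mean losses, as total_loss / len(loader) does."""
    return sum(loss for loss, _ in batches) / len(batches)


def sample_weighted_mean(batches):
    """Average loss over samples, weighting each batch by its size."""
    total = sum(loss * size for loss, size in batches)
    count = sum(size for _, size in batches)
    return total / count


# Hypothetical (mean_loss, batch_size) pairs: one full batch, one short final batch.
batches = [(1.0, 16), (3.0, 4)]
print(mean_of_batch_means(batches))   # 2.0
print(sample_weighted_mean(batches))  # 1.4
```

With many batches per epoch the gap is usually negligible, but it can matter for small validation sets like the one in this demo.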

## 7️⃣ Train the Model

Now let's train our DuoFormer model on the medical image dataset.


In [None]:
# Initialize trainer
trainer = Trainer(model=model, device=device, learning_rate=1e-4, weight_decay=1e-4)

# Train the model
print("🚀 Starting training...\n")
history = trainer.fit(
    train_loader=train_loader, val_loader=val_loader, epochs=5  # Reduced for demo
)

## 8️⃣ Visualize Training Progress

Let's plot the training and validation metrics to understand the model's learning progress.


In [None]:
def plot_training_history(history: dict):
    """Plot training and validation metrics."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Plot loss
    epochs = range(1, len(history["train_loss"]) + 1)
    ax1.plot(epochs, history["train_loss"], "b-", label="Training Loss", linewidth=2)
    ax1.plot(epochs, history["val_loss"], "r-", label="Validation Loss", linewidth=2)
    ax1.set_xlabel("Epoch", fontsize=12)
    ax1.set_ylabel("Loss", fontsize=12)
    ax1.set_title("Model Loss", fontsize=14, fontweight="bold")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot accuracy
    ax2.plot(epochs, history["train_acc"], "b-", label="Training Accuracy", linewidth=2)
    ax2.plot(epochs, history["val_acc"], "r-", label="Validation Accuracy", linewidth=2)
    ax2.set_xlabel("Epoch", fontsize=12)
    ax2.set_ylabel("Accuracy (%)", fontsize=12)
    ax2.set_title("Model Accuracy", fontsize=14, fontweight="bold")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Highlight best validation accuracy
    best_epoch = np.argmax(history["val_acc"])
    best_acc = history["val_acc"][best_epoch]
    ax2.scatter(best_epoch + 1, best_acc, color="green", s=100, zorder=5)
    ax2.annotate(
        f"Best: {best_acc:.2f}%",
        xy=(best_epoch + 1, best_acc),
        xytext=(best_epoch + 1, best_acc - 5),
        arrowprops=dict(arrowstyle="->", color="green"),
        fontsize=10,
    )

    plt.tight_layout()
    plt.show()


plot_training_history(history)
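
Alongside the plots, a numeric summary of the `history` dictionary can help spot overfitting. The sketch below reports the best epoch and the train/validation accuracy gap at that epoch; `summarize_history` is a hypothetical helper and the numbers are invented for illustration.

```python
def summarize_history(history):
    """Return the best epoch (1-indexed), its validation accuracy, and the train/val gap."""
    val_acc = history["val_acc"]
    # max() with a key returns the first maximal index, matching np.argmax on ties.
    best_idx = max(range(len(val_acc)), key=val_acc.__getitem__)
    gap = history["train_acc"][best_idx] - val_acc[best_idx]
    return {"best_epoch": best_idx + 1, "best_val_acc": val_acc[best_idx], "gap": gap}


# Hypothetical accuracies in percent, not results from this notebook.
example = {"train_acc": [20.0, 35.0, 50.0, 65.0], "val_acc": [18.0, 30.0, 30.0, 28.0]}
print(summarize_history(example))
# {'best_epoch': 2, 'best_val_acc': 30.0, 'gap': 5.0}
```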

## 9️⃣ Model Evaluation on Test Set

Let's evaluate our trained model on the test set and compute various metrics.


In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns


def evaluate_model(model, test_loader, device):
    """Comprehensive model evaluation."""
    model.eval()
    all_predictions = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for data, target in tqdm(test_loader, desc="Evaluating"):
            data, target = data.to(device), target.to(device)
            output = model(data)

            probs = torch.softmax(output, dim=1)
            _, predicted = output.max(1)

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    all_probs = np.array(all_probs)

    # Calculate metrics
    accuracy = (all_predictions == all_targets).mean() * 100

    print("\n📊 Test Set Performance:")
    print(f"   Overall Accuracy: {accuracy:.2f}%")

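    # Classification report
    # Pass labels explicitly so a class absent from a small test set does not break the report
    print("\n📋 Classification Report:")
    num_classes = all_probs.shape[1]
    labels = list(range(num_classes))
    class_names = [f"Class {i}" for i in labels]
    print(
        classification_report(
            all_targets,
            all_predictions,
            labels=labels,
            target_names=class_names,
            digits=3,
            zero_division=0,
        )
    )

    # Confusion matrix
    cm = confusion_matrix(all_targets, all_predictions, labels=labels)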

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title("Confusion Matrix", fontsize=14, fontweight="bold")
    plt.xlabel("Predicted Label", fontsize=12)
    plt.ylabel("True Label", fontsize=12)
    plt.tight_layout()
    plt.show()

    return accuracy, all_predictions, all_targets, all_probs


# Load best model
checkpoint = torch.load("best_duoformer_model.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

# Evaluate
test_acc, predictions, targets, probabilities = evaluate_model(
    model, test_loader, device
)

## 🔟 Feature Visualization

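Let's inspect what the attention layers produce. The cell below hooks up to three modules whose names contain `attn` and plots their outputs; depending on how the attention module is implemented, these may be the attended features rather than the raw attention weights.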


In [None]:
def visualize_attention_maps(model, image_tensor, device):
    """Visualize attention maps from the model."""
    model.eval()

    # Hook to capture attention weights
    attention_weights = []

    def hook_fn(module, input, output):
        if hasattr(output, "detach"):
            attention_weights.append(output.detach())

    # Register hooks on attention layers
    hooks = []
    for name, module in model.named_modules():
        if "attn" in name.lower() and hasattr(module, "forward"):
            hook = module.register_forward_hook(hook_fn)
            hooks.append(hook)
            if len(hooks) >= 3:  # Limit to 3 attention layers
                break

    # Forward pass
    with torch.no_grad():
        image_tensor = image_tensor.unsqueeze(0).to(device)
        _ = model(image_tensor)

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # Visualize
    if attention_weights:
        fig, axes = plt.subplots(1, min(3, len(attention_weights)), figsize=(15, 5))
        if len(attention_weights) == 1:
            axes = [axes]

        for idx, attn in enumerate(attention_weights[:3]):
            if attn.dim() >= 2:
                # Take mean across heads if multi-head attention
                if attn.dim() > 2:
                    attn_map = attn.mean(dim=0 if attn.shape[0] > 1 else 1)
                else:
                    attn_map = attn

                # Ensure 2D for visualization
                if attn_map.dim() > 2:
                    attn_map = attn_map.view(attn_map.shape[0], -1)

                axes[idx].imshow(
                    attn_map.cpu().numpy(), cmap="hot", interpolation="nearest"
                )
                axes[idx].set_title(f"Attention Layer {idx+1}", fontsize=10)
                axes[idx].axis("off")

        plt.suptitle("Attention Maps Visualization", fontsize=14, fontweight="bold")
        plt.tight_layout()
        plt.show()
    else:
        print("No attention weights captured.")


# Get a sample image
sample_image, sample_label = next(iter(test_loader))
visualize_attention_maps(model, sample_image[0], device)
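
For reference, the attention matrix inside a standard Transformer layer is softmax(QKᵀ/√d). The sketch below computes it in pure Python for three toy tokens. It follows the textbook scaled dot-product formulation and is not taken from the DuoFormer source, whose multi-scale attention may differ.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(queries, keys):
    """Return softmax(Q K^T / sqrt(d)) for lists of d-dimensional token vectors."""
    d = len(queries[0])
    scale = 1.0 / math.sqrt(d)
    weights = []
    for q in queries:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        weights.append(softmax(scores))
    return weights


# Three hypothetical 2-D tokens; each output row shows where one token attends.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
for row in attention_weights(tokens, tokens):
    print([round(w, 3) for w in row])
```

Each row sums to 1, so a heatmap of this matrix shows how one token distributes its attention over all the others.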

## 1️⃣1️⃣ Model Inference

Let's demonstrate how to use the trained model for inference on new medical images.


In [None]:
def predict_single_image(model, image_tensor, device, class_names=None):
    """Make prediction on a single image."""
    model.eval()

    with torch.no_grad():
        # Prepare image
        if image_tensor.dim() == 3:
            image_tensor = image_tensor.unsqueeze(0)
        image_tensor = image_tensor.to(device)

        # Forward pass
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)

        # Get prediction
        confidence, predicted_class = probabilities.max(1)

    predicted_class = predicted_class.item()
    confidence = confidence.item()

    # Get top 3 predictions
    top3_prob, top3_classes = probabilities[0].topk(3)

    if class_names is None:
        class_names = [f"Class {i}" for i in range(probabilities.shape[1])]

    return {
        "predicted_class": predicted_class,
        "class_name": class_names[predicted_class],
        "confidence": confidence,
        "top3": [
            (class_names[c.item()], p.item()) for c, p in zip(top3_classes, top3_prob)
        ],
    }


# Test on sample images
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.ravel()

for i in range(6):
    image, true_label = test_dataset[i]

    # Make prediction
    result = predict_single_image(model, image, device)

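    # Visualize (undo the ImageNet normalization before display)
    if isinstance(image, torch.Tensor):
        img_display = image.numpy().transpose(1, 2, 0)
        img_display = img_display * np.array([0.229, 0.224, 0.225]) + np.array(
            [0.485, 0.456, 0.406]
        )
    else:
        img_display = np.asarray(image) / 255.0

    img_display = np.clip(img_display, 0, 1)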

    axes[i].imshow(img_display)
    axes[i].set_title(
        f"True: Class {true_label}\n"
        f"Pred: {result['class_name']} ({result['confidence']:.2%})",
        fontsize=10,
        color="green" if result["predicted_class"] == true_label else "red",
    )
    axes[i].axis("off")

plt.suptitle("Model Predictions on Test Images", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()
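
`predict_single_image` turns logits into probabilities with softmax and then keeps the three largest. The same steps in pure Python, with a guard for problems that have fewer than `k` classes (`top_k_predictions` is a hypothetical name and the logits are invented):

```python
import math


def top_k_predictions(logits, k=3, class_names=None):
    """Return the k most probable (class_name, probability) pairs from raw logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    names = class_names or [f"Class {i}" for i in range(len(probs))]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    k = min(k, len(probs))  # guard against fewer classes than k
    return [(names[i], probs[i]) for i in ranked[:k]]


# Hypothetical logits for a 4-class problem.
for name, prob in top_k_predictions([2.0, 0.5, 1.0, -1.0]):
    print(f"{name}: {prob:.2%}")
```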

## 1️⃣2️⃣ Model Export and Deployment

Finally, let's prepare the model for deployment.


In [None]:
# Export model to ONNX format for deployment
def export_to_onnx(model, save_path="duoformer_model.onnx"):
    """Export model to ONNX format."""
    model.eval()

    # Create dummy input
    dummy_input = torch.randn(1, 3, 224, 224, device=device)

    # Export
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )
    print(f"✅ Model exported to {save_path}")
    return save_path


# Export the model
# onnx_path = export_to_onnx(model)  # Uncomment to export


# Save model with metadata
model_metadata = {
    "model_state_dict": model.state_dict(),
    "config": config,
    "input_size": (3, 224, 224),
    "num_classes": config["num_classes"],
    "normalization": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
    "performance": {
        # Cast to built-in floats so the file holds no NumPy scalar objects
        "test_accuracy": float(test_acc),
        "best_val_accuracy": float(max(history["val_acc"])),
    },
}

torch.save(model_metadata, "duoformer_final_model.pt")
print("✅ Model saved with metadata for deployment!")
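
A serving stack that consumes the ONNX file from another language cannot read the metadata stored in the `.pt` checkpoint. One option, sketched here as an assumption rather than as part of the original workflow, is to write the non-tensor metadata to a JSON string that can be saved next to the exported model. The helper name `metadata_to_json` is hypothetical.

```python
import json


def metadata_to_json(config, input_size, mean, std):
    """Serialize deployment metadata (everything except tensors) to a JSON string."""
    payload = {
        "config": config,
        "input_size": list(input_size),
        "normalization": {"mean": list(mean), "std": list(std)},
    }
    return json.dumps(payload, indent=2, sort_keys=True)


# Subset of the notebook's configuration, repeated here so the block runs on its own.
example_config = {"backbone": "r50", "num_classes": 10, "num_layers": 2}
text = metadata_to_json(
    example_config, (3, 224, 224), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
)
restored = json.loads(text)
print(restored["normalization"]["mean"])  # [0.485, 0.456, 0.406]
```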

## 📝 Summary and Conclusions

In this notebook, we demonstrated:

1. **🔧 Setup**: Installed the dependencies and selected the compute device
2. **📊 Data**: Created a synthetic medical image dataset with various patterns
3. **🧠 Model**: Initialized DuoFormer with ResNet backbone and Vision Transformer
4. **🎯 Training**: Trained the model with proper metrics tracking
5. **📈 Evaluation**: Evaluated performance with confusion matrix and classification report
6. **🔍 Visualization**: Visualized attention maps and learned features
7. **🚀 Deployment**: Prepared the model for production deployment

### Key Takeaways

- DuoFormer effectively combines CNN and Transformer architectures
- Multi-scale attention enables better feature extraction at different resolutions
- The architecture is designed for medical imaging tasks; this notebook exercises it on synthetic data only
- Flexible configuration allows adaptation to various dataset requirements

### Next Steps

1. **Fine-tune** on your specific medical imaging dataset
2. **Experiment** with the backbone options listed in the configuration (`r18` and `r50`)
3. **Adjust** the number of scales based on your image resolution
4. **Implement** domain-specific augmentations for medical images
5. **Deploy** using ONNX or TorchScript for production inference

---

For questions or contributions, please refer to the [DuoFormer GitHub repository](https://github.com/duoformer/duoformer).
