<a href="https://colab.research.google.com/github/VRSFXECE/VRS/blob/main/swtransret.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install SwinImageProcessor
#!pip install AutoImageProcessor
import warnings
from collections import OrderedDict
from typing import List, Optional, Dict, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import SwinForImageClassification, SwinImageProcessor
from transformers import AutoImageProcessor, AutoModelForImageClassification

class SystemicDiseaseSwinTransformer(nn.Module):
    """
    Swin Transformer for Systemic Disease Detection from Retinal Fundus Images

    Wraps a Hugging Face ``SwinForImageClassification`` backbone, replaces its
    classification head with a custom multi-layer head and optionally adds
    auxiliary heads on intermediate hidden states.

    Advantages over ViT for medical imaging:
    - Hierarchical feature extraction
    - Shifted window attention
    - Better at capturing local and global features
    - Often outperforms ViT on medical image tasks
    """
    def __init__(self,
                 diseases: List[str],
                 model_name: str = "microsoft/swin-tiny-patch4-window7-224",
                 pretrained_path: Optional[str] = None,
                 dropout_rate: float = 0.4,
                 use_multi_label: bool = True,
                 use_auxiliary_classifier: bool = False):
        """
        Initialize Swin Transformer for disease detection

        Args:
            diseases: List of systemic diseases to detect
            model_name: Pretrained Swin model name
            pretrained_path: Path to custom pretrained weights
            dropout_rate: Dropout probability
            use_multi_label: Multi-label classification (True) or multi-class (False)
            use_auxiliary_classifier: Use auxiliary classifiers from intermediate stages
        """
        super().__init__()

        self.diseases = diseases
        self.num_diseases = len(diseases)
        self.model_name = model_name
        self.use_multi_label = use_multi_label
        self.use_auxiliary_classifier = use_auxiliary_classifier

        print(f"Initializing Swin Transformer for {self.num_diseases} systemic diseases:")
        for i, disease in enumerate(diseases):
            print(f"  {i}: {disease}")

        # Load pre-trained Swin Transformer; fall back to a randomly
        # initialised model built from the same config if that fails.
        try:
            self.swin_model = SwinForImageClassification.from_pretrained(
                model_name,
                num_labels=self.num_diseases,
                ignore_mismatched_sizes=True
            )
            print(f"‚úì Loaded pretrained Swin model: {model_name}")
        except Exception as e:
            warnings.warn(f"Could not load pretrained model: {e}")
            from transformers import SwinConfig
            config = SwinConfig.from_pretrained(model_name)
            config.num_labels = self.num_diseases
            self.swin_model = SwinForImageClassification(config)
            print(f"‚úì Initialized new Swin model with config: {model_name}")

        # Get model configuration
        self.config = self.swin_model.config
        hidden_size = self.config.hidden_size

        # Store original classifier for reference.
        # NOTE: this keeps the stock head registered as a submodule, so its
        # parameters appear in state_dict() and in the parameter counts.
        self.original_classifier = self.swin_model.classifier

        # Replace classification head with custom head
        print(f"Original hidden size: {hidden_size}")

        # Enhanced classification head for medical imaging
        self.classifier = self._build_enhanced_classifier(
            hidden_size=hidden_size,
            dropout_rate=dropout_rate
        )

        # Replace model's classifier
        self.swin_model.classifier = nn.Identity()  # Remove original

        # Optional: Auxiliary classifiers from different stages
        if use_auxiliary_classifier:
            self.auxiliary_classifiers = self._build_auxiliary_classifiers()
        else:
            self.auxiliary_classifiers = None

        # Load custom pretrained weights if provided
        if pretrained_path is not None:
            self.load_custom_weights(pretrained_path)

        # Initialize image processor for fundus images
        self.processor = self._initialize_image_processor()

        # Setup disease mapping
        self._setup_disease_mapping()

        # Initialize loss function
        self._initialize_loss_function()

        # Count parameters
        self._count_parameters()

    def _build_enhanced_classifier(self, hidden_size: int, dropout_rate: float) -> nn.Module:
        """
        Build the custom classification head.

        Layout: LayerNorm -> Dropout -> Linear(h, h/2) -> GELU
        -> [extra bottleneck when there are more than 10 diseases]
        -> BatchNorm1d -> Dropout -> Linear(h/2, num_diseases).

        Args:
            hidden_size: Width of the pooled backbone feature fed to the head
            dropout_rate: Base dropout probability (later layers use fractions of it)

        Returns:
            ``nn.Sequential`` mapping ``(batch, hidden_size)`` to
            ``(batch, num_diseases)`` logits
        """
        print(f"Building enhanced classifier with dropout={dropout_rate}")

        half = hidden_size // 2
        quarter = hidden_size // 4

        classifier_layers = [
            nn.LayerNorm(hidden_size),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, half),
            nn.GELU(),
        ]

        # For many diseases, add more capacity. The extra bottleneck consumes
        # and produces `half` features, so it must sit AFTER the first
        # projection; placing it earlier feeds it `hidden_size` features and
        # fails with a shape mismatch.
        if self.num_diseases > 10:
            classifier_layers += [
                nn.Linear(half, quarter),
                nn.GELU(),
                nn.Dropout(dropout_rate / 3),
                nn.Linear(quarter, half),
                nn.GELU(),
            ]

        # NOTE: BatchNorm1d needs more than one sample per batch in train mode.
        classifier_layers += [
            nn.BatchNorm1d(half),
            nn.Dropout(dropout_rate / 2),
            nn.Linear(half, self.num_diseases),
        ]

        classifier = nn.Sequential(*classifier_layers)

        # Initialize weights
        self._initialize_classifier_weights(classifier)

        return classifier

    def _build_auxiliary_classifiers(self) -> nn.ModuleDict:
        """
        Build auxiliary classifiers for the intermediate hidden states.

        ``forward`` feeds head ``stage_{i+1}`` with ``hidden_states[i + 1]``.
        In the Hugging Face Swin implementation ``config.hidden_size`` is the
        width of the LAST stage, and the hidden state emitted after stage
        ``k`` (post patch-merging) has ``embed_dim * 2**k`` channels, so the
        heads are sized from ``config.embed_dim`` rather than ``hidden_size``.

        Returns:
            ``nn.ModuleDict`` with one head per stage except the last
            (the last stage is served by the main classifier)
        """
        print("Building auxiliary classifiers for multi-stage features")

        embed_dim = self.config.embed_dim
        num_stages = len(self.config.depths)

        auxiliary_classifiers = nn.ModuleDict()

        for i in range(num_stages - 1):  # Exclude last stage (main classifier)
            dim = embed_dim * 2 ** (i + 1)
            aux_classifier = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.LayerNorm(dim),
                nn.Linear(dim, dim // 4),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(dim // 4, self.num_diseases)
            )
            auxiliary_classifiers[f'stage_{i+1}'] = aux_classifier

        return auxiliary_classifiers

    def _initialize_classifier_weights(self, classifier: nn.Module):
        """Xavier-initialise Linear layers and reset LayerNorm to identity."""
        for module in classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _initialize_image_processor(self):
        """
        Initialize image processor for retinal fundus images.

        Tries ``SwinImageProcessor``, then ``AutoImageProcessor``, and finally
        a locally configured 224x224 processor with ImageNet mean/std.
        """
        try:
            processor = SwinImageProcessor.from_pretrained(self.model_name)
            print(f"‚úì Loaded SwinImageProcessor for {self.model_name}")
        except Exception:
            try:
                processor = AutoImageProcessor.from_pretrained(self.model_name)
                print(f"‚úì Loaded AutoImageProcessor for {self.model_name}")
            except Exception:
                processor = SwinImageProcessor(
                    size={"height": 224, "width": 224},
                    do_resize=True,
                    do_rescale=True,
                    do_normalize=True,
                    image_mean=[0.485, 0.456, 0.406],
                    image_std=[0.229, 0.224, 0.225],
                )
                print("‚úì Created default SwinImageProcessor")

        # Medical imaging adjustments: always rescale 0-255 inputs to 0-1
        processor.do_rescale = True
        processor.rescale_factor = 1.0 / 255.0

        return processor

    def _setup_disease_mapping(self):
        """Build disease<->index maps and mirror them onto the backbone config."""
        self.disease_to_id = {disease: i for i, disease in enumerate(self.diseases)}
        self.id_to_disease = {i: disease for i, disease in enumerate(self.diseases)}

        # Update model config
        self.swin_model.config.id2label = self.id_to_disease
        self.swin_model.config.label2id = self.disease_to_id

    def _initialize_loss_function(self):
        """Pick BCE-with-logits for multi-label, cross-entropy for multi-class."""
        if self.use_multi_label:
            self.criterion = nn.BCEWithLogitsLoss()
            print("Loss: BCEWithLogitsLoss (multi-label)")
        else:
            self.criterion = nn.CrossEntropyLoss()
            print("Loss: CrossEntropyLoss (multi-class)")

    def _parameter_counts(self) -> Tuple[int, int, float]:
        """Return (total, trainable, trainable percentage) over all parameters."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        # Guard against a parameter-free module to avoid dividing by zero.
        percent = 100 * trainable_params / total_params if total_params else 0.0
        return total_params, trainable_params, percent

    def _count_parameters(self):
        """Count and display model parameters"""
        total_params, trainable_params, percent = self._parameter_counts()

        print(f"\nParameter Count:")
        print(f"  Total parameters: {total_params:,}")
        print(f"  Trainable parameters: {trainable_params:,}")
        print(f"  Trainable %: {percent:.2f}%")

    def forward(self, pixel_values: torch.Tensor) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass through Swin Transformer

        Args:
            pixel_values: Batch of preprocessed images passed straight to the
                Swin backbone

        Returns:
            - If not using auxiliary: logits tensor
            - If using auxiliary: dict with ``'main'`` logits and an
              ``'auxiliary'`` dict of per-stage logits
        """
        # Get features from Swin backbone
        outputs = self.swin_model.swin(
            pixel_values=pixel_values,
            output_hidden_states=self.use_auxiliary_classifier,
            output_attentions=False,
            return_dict=True
        )

        # Get pooled output (usually from last hidden state)
        pooled_output = outputs.pooler_output

        # Main classifier
        main_logits = self.classifier(pooled_output)

        if not self.use_auxiliary_classifier:
            return main_logits

        # Get auxiliary outputs if enabled
        hidden_states = outputs.hidden_states

        auxiliary_outputs = {}
        for i, (stage_name, aux_classifier) in enumerate(self.auxiliary_classifiers.items()):
            # Get hidden state for this stage (skip first which is embeddings)
            stage_features = hidden_states[i + 1]

            # For Swin, we need to reshape from (B, L, C) to (B, C, H, W) for pooling
            batch_size, seq_len, channels = stage_features.shape
            # Assuming square features: sqrt(seq_len) should be integer
            h = w = int(seq_len ** 0.5)
            stage_features = stage_features.transpose(1, 2).reshape(batch_size, channels, h, w)

            aux_logits = aux_classifier(stage_features)
            auxiliary_outputs[stage_name] = aux_logits

        return {
            'main': main_logits,
            'auxiliary': auxiliary_outputs
        }

    def predict_proba(self, pixel_values: torch.Tensor,
                     temperature: float = 1.0) -> torch.Tensor:
        """
        Get probability predictions with optional temperature scaling

        Does not switch the module to eval mode; callers that want
        deterministic outputs should call ``.eval()`` first.

        Args:
            pixel_values: Input images
            temperature: Positive divisor applied to the logits (for calibration)

        Returns:
            Probability tensor (sigmoid for multi-label, softmax otherwise)

        Raises:
            ValueError: If ``temperature`` is not strictly positive
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        with torch.no_grad():
            output = self.forward(pixel_values)

            # With auxiliary heads enabled only the main logits are used.
            logits = output['main'] if isinstance(output, dict) else output

            if self.use_multi_label:
                # Multi-label: independent sigmoid
                probs = torch.sigmoid(logits / temperature)
            else:
                # Multi-class: softmax
                probs = F.softmax(logits / temperature, dim=-1)

            return probs

    @staticmethod
    def _to_pil_image(array) -> 'Image.Image':
        """
        Convert an ``(H, W)``, ``(H, W, 1)`` or ``(H, W, 3)`` array to a PIL image.

        uint8 input is used as-is. Floating-point input whose maximum is
        <= 1.0 is treated as 0-1 data and scaled by 255; everything else is
        clipped to the 0-255 range before the uint8 cast.
        """
        from PIL import Image
        import numpy as np

        array = np.asarray(array)
        if array.dtype != np.uint8:
            is_float = np.issubdtype(array.dtype, np.floating)
            array = array.astype(np.float32)
            if is_float and array.size and array.max() <= 1.0:
                array = array * 255.0
            array = np.clip(array, 0, 255).astype(np.uint8)

        # PIL cannot build an image from a trailing singleton channel axis.
        if array.ndim == 3 and array.shape[2] == 1:
            array = array[:, :, 0]

        return Image.fromarray(array)

    def preprocess_fundus_image(self, image: Union[torch.Tensor, 'Image.Image', np.ndarray],
                               apply_clahe: bool = False) -> Dict[str, torch.Tensor]:
        """
        Preprocess retinal fundus image for Swin Transformer

        Args:
            image: Input image (PIL, Tensor, or numpy array). 3-D tensors whose
                first axis has size 1 or 3 are treated as channel-first.
            apply_clahe: Apply CLAHE enhancement (for low contrast fundus images)

        Returns:
            Dictionary with processed pixel values
        """
        import numpy as np

        # Convert to PIL Image if needed
        if isinstance(image, torch.Tensor):
            tensor = image.detach().cpu()
            if tensor.dim() == 3 and tensor.shape[0] in (1, 3):
                tensor = tensor.permute(1, 2, 0)
            image = self._to_pil_image(tensor.numpy())
        elif isinstance(image, np.ndarray):
            image = self._to_pil_image(image)

        # Optional: Apply CLAHE for contrast enhancement (best effort)
        if apply_clahe:
            try:
                image = self._apply_clahe(image)
            except Exception as e:
                warnings.warn(f"CLAHE failed, using original image: {e}")

        # Apply Swin Transformer preprocessing
        inputs = self.processor(images=image, return_tensors="pt")

        # Additional medical-specific normalization (optional hook; only used
        # when a `normalize_fundus` attribute has been attached to the model)
        if hasattr(self, 'normalize_fundus'):
            inputs['pixel_values'] = self.normalize_fundus(inputs['pixel_values'])

        return inputs

    def _apply_clahe(self, image: 'Image.Image') -> 'Image.Image':
        """
        Apply CLAHE contrast enhancement.

        RGB images are enhanced on the L channel of LAB space; other inputs
        are treated as single-channel. Requires OpenCV (``cv2``).
        """
        import cv2
        import numpy as np
        # Imported here because PIL is not imported at module level; without
        # it the final Image.fromarray call raised NameError on every call.
        from PIL import Image

        # Normalise exotic modes (RGBA, P, ...) so only the two handled
        # layouts reach OpenCV.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")

        # Convert PIL to numpy
        img_array = np.array(image)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

        if len(img_array.shape) == 3 and img_array.shape[2] == 3:
            # Convert to LAB color space and enhance lightness only
            lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
            l, a, b = cv2.split(lab)
            cl = clahe.apply(l)

            # Merge channels and convert back to RGB
            limg = cv2.merge((cl, a, b))
            enhanced = cv2.cvtColor(limg, cv2.COLOR_LAB2RGB)
        else:
            # Grayscale image
            enhanced = clahe.apply(img_array)

        return Image.fromarray(enhanced)

    def freeze_backbone_layers(self, unfreeze_last_n: int = 2):
        """
        Freeze Swin Transformer backbone layers

        Everything under ``swin_model`` is frozen first (including the patch
        embeddings and final norm), then the blocks and patch-merging layers
        of the last ``unfreeze_last_n`` stages are re-enabled. The custom and
        auxiliary heads are always left trainable.

        Args:
            unfreeze_last_n: Number of Swin stages to unfreeze, clamped to
                            [1, number of stages]
                            1: Only last stage trainable
                            4: All stages trainable (for a 4-stage model)
        """
        stages = self.swin_model.swin.encoder.layers
        num_stages = len(stages)
        unfreeze_last_n = max(1, min(num_stages, unfreeze_last_n))

        # Freeze entire model first
        for param in self.swin_model.parameters():
            param.requires_grad = False

        # Unfreeze specified number of stages
        stages_to_unfreeze = list(range(num_stages - unfreeze_last_n, num_stages))

        for stage_idx in stages_to_unfreeze:
            stage = stages[stage_idx]

            # Unfreeze attention blocks in this stage
            for param in stage.blocks.parameters():
                param.requires_grad = True

            # Unfreeze patch merging where the stage has one
            patch_merge = getattr(stage, 'downsample', None)
            if patch_merge is not None:
                for param in patch_merge.parameters():
                    param.requires_grad = True

        # Always unfreeze classifier
        for param in self.classifier.parameters():
            param.requires_grad = True

        if self.auxiliary_classifiers is not None:
            for param in self.auxiliary_classifiers.parameters():
                param.requires_grad = True

        print(f"Unfroze last {unfreeze_last_n} Swin stages (stages {stages_to_unfreeze})")
        print("Classifier and auxiliary heads are trainable")

    def get_attention_maps(self, pixel_values: torch.Tensor,
                          stage_idx: int = -1,
                          head_idx: Optional[int] = None) -> torch.Tensor:
        """
        Extract attention maps from Swin Transformer

        Side effect: switches the module to eval mode and leaves it there.

        Args:
            pixel_values: Input images
            stage_idx: Stage index (0-3, -1 for last)
            head_idx: Specific attention head (None for average)

        Returns:
            Attention maps with the head axis (dim 1) selected or averaged
        """
        self.eval()

        if stage_idx == -1:
            stage_idx = len(self.swin_model.swin.encoder.layers) - 1

        with torch.no_grad():
            # Get attention outputs
            outputs = self.swin_model.swin(
                pixel_values=pixel_values,
                output_attentions=True,
                return_dict=True
            )

            # Attention for the requested stage; dim 1 is indexed as heads.
            # NOTE(review): presumably (batch * num_windows, num_heads,
            # tokens, tokens) - confirm against the installed transformers.
            attentions = outputs.attentions[stage_idx]

            if head_idx is not None:
                # Specific head
                attention_maps = attentions[:, head_idx]
            else:
                # Average across heads
                attention_maps = attentions.mean(dim=1)

            return attention_maps

    def load_custom_weights(self, path: str, strict: bool = False):
        """
        Load custom pretrained weights (best effort).

        Accepts a raw state dict or a checkpoint dict holding it under
        ``'model_state_dict'`` or ``'state_dict'``. Each entry is matched
        first against this module's own keys and then against the same key
        under ``swin_model.`` (so plain Swin checkpoints also load).
        Entries with no match or a different shape are skipped. Any failure
        is reported as a warning rather than raised.

        Args:
            path: Checkpoint file to read
            strict: Forwarded to ``load_state_dict`` for the filtered entries
        """
        try:
            # SECURITY: torch.load unpickles the file unless the installed
            # PyTorch defaults to weights-only loading; only pass trusted paths.
            checkpoint = torch.load(path, map_location='cpu')

            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            elif 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            own_state = self.state_dict()
            filtered_dict = {}

            for key, value in state_dict.items():
                # Exact key first so this wrapper's own checkpoints (custom
                # and auxiliary heads included) round-trip; the prefixed form
                # covers checkpoints saved from the bare HF model.
                for candidate in (key, f"swin_model.{key}"):
                    target = own_state.get(candidate)
                    if target is not None and getattr(value, 'shape', None) == target.shape:
                        filtered_dict[candidate] = value
                        break

            # Load filtered weights
            self.load_state_dict(filtered_dict, strict=strict)

            print(f"‚úì Loaded custom weights from {path}")
            print(f"  Loaded {len(filtered_dict)}/{len(state_dict)} parameters")

        except Exception as e:
            warnings.warn(f"Failed to load custom weights: {e}")

    def get_feature_maps(self, pixel_values: torch.Tensor,
                        stage_idx: int = -1) -> torch.Tensor:
        """
        Extract feature maps from specific Swin stage

        Useful for visualization and interpretability. Side effect: switches
        the module to eval mode.

        Args:
            pixel_values: Input images
            stage_idx: Index into the backbone's ``hidden_states`` tuple
                (-1 for the last entry)

        Returns:
            The selected hidden-state tensor
        """
        self.eval()

        with torch.no_grad():
            outputs = self.swin_model.swin(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True
            )

            hidden_states = outputs.hidden_states

            if stage_idx == -1:
                # Last hidden state
                features = hidden_states[-1]
            else:
                features = hidden_states[stage_idx]

            return features

    def summary(self):
        """Print model summary"""
        print("\n" + "="*70)
        print("SYSTEMIC DISEASE SWIN TRANSFORMER SUMMARY")
        print("="*70)
        print(f"Model: {self.model_name}")
        print(f"Diseases: {self.num_diseases}")
        print(f"Multi-label: {self.use_multi_label}")
        print(f"Auxiliary classifiers: {self.use_auxiliary_classifier}")

        # Print layer information
        print("\nSwin Transformer Architecture:")
        print(f"  Hidden size: {self.config.hidden_size}")
        print(f"  Layers: {self.config.num_hidden_layers}")
        print(f"  Heads: {self.config.num_attention_heads}")
        print(f"  Window size: {self.config.window_size}")
        print(f"  MLP ratio: {self.config.mlp_ratio}")

        # Print classifier info (index 1 of the head is its first Dropout)
        print(f"\nClassifier Architecture:")
        print(f"  Input dim: {self.config.hidden_size}")
        print(f"  Output dim: {self.num_diseases}")
        print(f"  Dropout: {self.classifier[1].p if len(self.classifier) > 1 else 'N/A'}")

        # Count parameters
        total_params, trainable_params, percent = self._parameter_counts()

        print(f"\nParameters:")
        print(f"  Total: {total_params:,}")
        print(f"  Trainable: {trainable_params:,}")
        print(f"  Percentage trainable: {percent:.2f}%")
        print("="*70)


def create_swin_disease_detector(
    diseases: List[str],
    model_variant: str = "tiny",
    use_pretrained: bool = True,
    device: Optional[str] = None,
    **kwargs
) -> SystemicDiseaseSwinTransformer:
    """
    Factory function to create Swin Transformer disease detector

    Args:
        diseases: List of systemic diseases
        model_variant: 'tiny', 'small', 'base', 'large', 'v2_tiny' or
            'v2_base'; unknown values fall back to 'tiny' with a printed warning
        use_pretrained: Use pretrained weights. When False the backbone is
            re-created with random weights from the variant's config.
        device: Target device (defaults to CUDA when available, else CPU)
        **kwargs: Additional arguments for SystemicDiseaseSwinTransformer

    Returns:
        Initialized Swin Transformer model, moved to ``device``
    """
    # Map model variants to HuggingFace model names.
    # NOTE(review): the v2_* entries name Swin V2 checkpoints, while the
    # wrapper loads everything through SwinForImageClassification - confirm
    # these variants actually load as intended.
    model_map = {
        'tiny': 'microsoft/swin-tiny-patch4-window7-224',
        'small': 'microsoft/swin-small-patch4-window7-224',
        'base': 'microsoft/swin-base-patch4-window7-224',
        'large': 'microsoft/swin-large-patch4-window7-224',
        'v2_tiny': 'microsoft/swinv2-tiny-patch4-window8-256',
        'v2_base': 'microsoft/swinv2-base-patch4-window12to24-192to384-22kto1k',
    }

    if model_variant not in model_map:
        print(f"Warning: Model variant '{model_variant}' not found, using 'tiny'")
        model_variant = 'tiny'

    model_name = model_map[model_variant]

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"üöÄ Creating Swin Transformer ({model_variant}) for Systemic Disease Detection")
    print(f"üìä Device: {device}")
    print(f"üéØ Diseases: {', '.join(diseases)}")

    # Always build the wrapper for the requested variant so that its heads,
    # `config` and `model_name` match the backbone. Previously the scratch
    # path omitted `model_name`, producing a swin-tiny head on top of a
    # different-width backbone.
    model = SystemicDiseaseSwinTransformer(
        diseases=diseases,
        model_name=model_name,
        **kwargs
    )

    if not use_pretrained:
        # Initialize from scratch (not recommended): swap the loaded backbone
        # for a randomly initialised one with the same architecture.
        from transformers import SwinConfig
        config = SwinConfig.from_pretrained(model_name)
        config.num_labels = len(diseases)
        # Re-apply the label maps the wrapper configured on the old backbone.
        config.id2label = dict(model.id_to_disease)
        config.label2id = dict(model.disease_to_id)

        swin_model = SwinForImageClassification(config)
        swin_model.classifier = nn.Identity()

        model.swin_model = swin_model
        model.config = swin_model.config

    # Move to device
    model = model.to(device)

    # Print summary
    model.summary()

    return model


# Training utilities for Swin Transformer
class SwinDiseaseTrainer:
    """Training utilities for Swin Transformer disease detection"""

    @staticmethod
    def train_epoch(model, dataloader, optimizer, device,
                   scheduler=None, grad_clip: float = 1.0,
                   aux_weight: float = 0.3):
        """
        Train for one epoch.

        Args:
            model: Module exposing ``criterion`` and ``use_multi_label``; its
                forward returns logits or a dict with ``'main'`` (and
                optionally ``'auxiliary'``) logits
            dataloader: Iterable of ``(images, labels)`` batches
            optimizer: Optimizer stepped once per batch
            device: Device the batches are moved to
            scheduler: Optional LR scheduler, stepped once per batch
            grad_clip: Max gradient norm; values <= 0 disable clipping
            aux_weight: Weight of the mean auxiliary loss

        Returns:
            Mean loss over the processed batches (0.0 if there were none)
        """
        model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (images, labels) in enumerate(dataloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)

            # Calculate loss
            if isinstance(outputs, dict):
                # With auxiliary classifiers
                loss = SwinDiseaseTrainer._calculate_loss(
                    outputs['main'], labels, model.criterion, model.use_multi_label
                )

                # An absent or empty auxiliary dict contributes nothing
                # (and must not trigger a division by zero).
                aux_outputs = outputs.get('auxiliary') or {}
                if aux_outputs:
                    aux_loss = sum(
                        SwinDiseaseTrainer._calculate_loss(
                            aux_logits, labels, model.criterion, model.use_multi_label
                        )
                        for aux_logits in aux_outputs.values()
                    ) / len(aux_outputs)
                    loss = loss + aux_weight * aux_loss
            else:
                # Without auxiliary
                loss = SwinDiseaseTrainer._calculate_loss(
                    outputs, labels, model.criterion, model.use_multi_label
                )

            # Backward pass
            loss.backward()

            # Gradient clipping
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            if scheduler is not None:
                scheduler.step()

            total_loss += loss.item()
            num_batches += 1

            # Progress update
            if (batch_idx + 1) % 10 == 0:
                print(f"  Batch {batch_idx+1}/{len(dataloader)}, Loss: {loss.item():.4f}")

        # Empty dataloader: report 0.0 instead of raising ZeroDivisionError.
        return total_loss / num_batches if num_batches else 0.0

    @staticmethod
    def _calculate_loss(logits, labels, criterion, is_multi_label):
        """Apply ``criterion``; multi-label targets are cast to float first."""
        if is_multi_label:
            return criterion(logits, labels.float())
        return criterion(logits, labels)

    @staticmethod
    def evaluate(model, dataloader, device,
                threshold: float = 0.5):
        """
        Evaluate model performance.

        Args:
            model: Module exposing ``predict_proba`` and ``use_multi_label``
            dataloader: Iterable of ``(images, labels)`` batches
            device: Device the batches are moved to
            threshold: Cut-off applied to probabilities in multi-label mode

        Returns:
            Dict with CPU tensors ``'predictions'``, ``'labels'`` and
            ``'probabilities'`` concatenated over all batches (empty tensors
            when the dataloader yields nothing)
        """
        model.eval()
        all_preds = []
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)

                # Get predictions
                probs = model.predict_proba(images)

                if model.use_multi_label:
                    preds = (probs > threshold).float()
                else:
                    preds = torch.argmax(probs, dim=-1)

                all_preds.append(preds.cpu())
                all_labels.append(labels.cpu())
                all_probs.append(probs.cpu())

        def _concat(chunks):
            # torch.cat rejects an empty list, so fall back to an empty tensor.
            return torch.cat(chunks, dim=0) if chunks else torch.empty(0)

        return {
            'predictions': _concat(all_preds),
            'labels': _concat(all_labels),
            'probabilities': _concat(all_probs)
        }


# Example usage with medical dataset
if __name__ == "__main__":
    # Demo script: build the detector, run a dummy batch through it, extract
    # attention maps and write a checkpoint to the working directory.

    # Define systemic diseases detectable from retinal fundus images
    systemic_diseases = [
        "Diabetes_Mellitus",        # Diabetic retinopathy
        "Hypertension",             # Hypertensive retinopathy
        "Anemia",                   # Retinal pallor, hemorrhages
        "Atherosclerosis",          # AV nicking, vascular changes
        "Sickle_Cell_Disease",      # Sea fan neovascularization
        "HIV_Retinopathy",          # Cotton wool spots, hemorrhages
        "Multiple_Sclerosis",       # Optic neuritis
        "Leukemia",                 # Roth spots, retinal infiltrates
        "Sarcoidosis",              # Candle wax drippings
        "Systemic_Lupus",           # Retinal vasculitis
        "Carotid_Artery_Disease",   # Ocular ischemic syndrome
        "Hyperthyroidism",          # Thyroid eye disease signs
        "Renal_Disease",            # Hypertensive changes
        "Blood_Disorders",          # Various hemorrhagic signs
    ]

    # Create Swin Transformer model
    print("\n" + "="*80)
    print("CREATING SWIN TRANSFORMER FOR SYSTEMIC DISEASE DETECTION")
    print("="*80)

    model = create_swin_disease_detector(
        diseases=systemic_diseases,
        model_variant="base",  # Options: tiny, small, base, large, v2_tiny, v2_base
        use_pretrained=True,
        dropout_rate=0.4,
        use_multi_label=True,  # Multiple diseases can co-exist
        use_auxiliary_classifier=True,  # Use features from multiple stages
    )

    # Configure for fine-tuning
    print("\nüîß Configuring for fine-tuning...")
    model.freeze_backbone_layers(unfreeze_last_n=2)  # Unfreeze last 2 stages

    # Test with dummy fundus image (random noise, shape check only)
    print("\nüß™ Testing with dummy fundus image batch...")
    batch_size = 4
    dummy_fundus = torch.randn(batch_size, 3, 224, 224)
    device = next(model.parameters()).device
    dummy_fundus = dummy_fundus.to(device)

    # Forward pass
    with torch.no_grad():
        outputs = model(dummy_fundus)

        if isinstance(outputs, dict):
            logits = outputs['main']
            print(f"‚úì Main logits shape: {logits.shape}")
            if 'auxiliary' in outputs:
                for stage, aux_logits in outputs['auxiliary'].items():
                    print(f"  {stage} logits shape: {aux_logits.shape}")
        else:
            logits = outputs
            print(f"‚úì Logits shape: {logits.shape}")

        # Get probabilities
        probs = model.predict_proba(dummy_fundus)
        print(f"‚úì Probabilities shape: {probs.shape}")

        # Example interpretation
        print(f"\nüìä Example prediction for first image:")
        if model.use_multi_label:
            threshold = 0.3  # Lower threshold for medical screening
            pred_mask = (probs[0] > threshold).cpu().numpy()

            print(f"  Threshold: {threshold}")
            print(f"  Predicted diseases:")
            for i, (disease, pred) in enumerate(zip(systemic_diseases, pred_mask)):
                if pred:
                    prob = probs[0, i].item()
                    print(f"    ‚Ä¢ {disease}: {prob:.3f}")
        else:
            pred_idx = torch.argmax(probs[0]).item()
            pred_prob = probs[0, pred_idx].item()
            print(f"  Predicted: {systemic_diseases[pred_idx]} ({pred_prob:.3f})")

    # Test attention maps
    print("\nüëÅÔ∏è Testing attention map extraction...")
    attention_maps = model.get_attention_maps(dummy_fundus[:1], stage_idx=-1)
    print(f"  Attention maps shape: {attention_maps.shape}")

    # Save model
    print("\nüíæ Saving model configuration...")
    model_config = {
        'diseases': systemic_diseases,
        'model_name': model.model_name,
        'num_diseases': model.num_diseases,
        'use_multi_label': model.use_multi_label,
        'use_auxiliary_classifier': model.use_auxiliary_classifier,
        'input_size': 224,
        'frozen_stages': 'first 2 stages frozen',
        'timestamp': '2024'
    }

    # Save full model. The backbone config is stored as a plain dict rather
    # than the live config object, so the checkpoint contains only tensors
    # and builtin containers and can be read back with torch.load's safe
    # (weights-only) mode, independent of the installed transformers version.
    save_path = "swin_systemic_disease_detector.pth"
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': model_config,
        'swin_config': model.config.to_dict(),
    }, save_path)

    print(f"‚úì Model saved to {save_path}")
    print("="*80)

ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

To install Python libraries in a Colab notebook, you can use `!pip install <package_name>`. For example, to install `numpy`:

In [2]:
!pip install numpy



You can also specify a particular version of a library:

In [3]:
!pip install pandas==1.3.5

Collecting pandas==1.3.5
  Downloading pandas-1.3.5.tar.gz (4.7 MB)
[?25l     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/4.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m‚îÅ‚îÅ‚îÅ‚îÅ[0m[90m‚ï∫[0m[90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.5/4.7 MB[0m [31m15.6 MB/s[0m eta [36m0:00:01[0m[2K     [91m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[91m‚ï∏[0m [32m4.7/4.7 MB[0m [31m60.4 MB/s[0m eta [36m0:00:01[0m[2K     [91m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[91m‚ï∏[0m [32m4.7/4.7 MB[0m [31m60.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚î