<a href="https://colab.research.google.com/github/Junkai03/kat/blob/main/KAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Print the CUDA compiler version and the GPU / driver status of this runtime.
!nvcc --version
!nvidia-smi


nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
Sun Nov 24 20:01:16 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8              11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                      

In [None]:
# Install the necessary libraries
!pip install timm
!pip install wandb

# Clone the project code
!git clone https://github.com/Adamdad/rational_kat_cu.git
%cd rational_kat_cu
!pip install -e .


fatal: destination path 'rational_kat_cu' already exists and is not an empty directory.
/content/rational_kat_cu
Obtaining file:///content/rational_kat_cu
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: kat_rational
  Attempting uninstall: kat_rational
    Found existing installation: kat_rational 0.3
    Uninstalling kat_rational-0.3:
      Successfully uninstalled kat_rational-0.3
  Running setup.py develop for kat_rational
Successfully installed kat_rational-0.3


In [None]:
# Sanity check: confirm that the kat_rational package installed above can be imported
from kat_rational import KAT_Group
print("KAT_Group imported successfully!")

KAT_Group imported successfully!


  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
  def backward(ctx, grad_output):


In [None]:
import torch
import torch.nn as nn
from katransformer import Block

In [None]:
class PatchEmbedding(nn.Module):
    """
    Turn an image batch into a sequence of patch tokens.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to ``embed_dim`` channels. A learnable
    positional embedding, initialised to zeros, is then added to the tokens.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size, self.patch_size, self.embed_dim = img_size, patch_size, embed_dim

        # Whole patches along one side of the image, squared.
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        # Patch projection followed by one positional vector per patch.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, x):
        # [Batch, Channels, Height, Width] -> [Batch, Embed Dim, Patches High, Patches Wide]
        patch_grid = self.proj(x)
        # Collapse the spatial grid and move it in front of the channels:
        # -> [Batch, Num Patches, Embed Dim]
        tokens = patch_grid.flatten(2).transpose(1, 2)
        return tokens + self.pos_embed


In [None]:
class StackBlocks(nn.Module):
    """
    Two identically configured Transformer blocks applied back to back.

    ``block`` is the block class (or any callable factory). It is invoked twice
    with the same keyword arguments, producing ``block1`` and ``block2``.
    """
    def __init__(self, block, embed_dim=768, num_heads=8, mlp_ratio=4., proj_drop=0.1,
                 attn_drop=0.1, drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # One shared configuration; every call below builds a separate instance.
        block_kwargs = dict(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            proj_drop=proj_drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        self.block1 = block(**block_kwargs)
        self.block2 = block(**block_kwargs)

    def forward(self, x):
        # block1 feeds block2; only the second block's output is returned.
        return self.block2(self.block1(x))


In [None]:
class FeatureFusion(nn.Module):
    """
    Feature fusion module for combining outputs from different Blocks.

    Supported ``fusion_type`` values:
        'concat':       concatenate the features along the last dimension and
                        project them back to ``embed_dim`` with a linear layer.
        'weighted_sum': sum the features using one learnable scalar weight per
                        block, normalised with a softmax.

    Args:
        embed_dim: Size of the last dimension of every input feature.
        fusion_type: 'concat' or 'weighted_sum'.
        num_blocks: Number of feature tensors that will be fused.

    Raises:
        ValueError: If ``fusion_type`` is not one of the supported values.
    """
    _FUSION_TYPES = ('concat', 'weighted_sum')

    def __init__(self, embed_dim, fusion_type='concat', num_blocks=2):
        super().__init__()
        # Reject a bad configuration at construction time instead of on the
        # first forward pass.
        if fusion_type not in self._FUSION_TYPES:
            raise ValueError(f"Unsupported fusion_type: {fusion_type}")
        self.fusion_type = fusion_type
        self.num_blocks = num_blocks
        if fusion_type == 'concat':
            self.fc = nn.Linear(embed_dim * num_blocks, embed_dim)  # Reduce concatenated features
        elif fusion_type == 'weighted_sum':
            self.weights = nn.Parameter(torch.ones(num_blocks))  # Learnable weights for each block

    def forward(self, features):
        """
        Args:
            features: Iterable of ``num_blocks`` features from different Blocks,
                each with shape [Batch, Num Patches, Embed Dim].
        Returns:
            Fused features with shape [Batch, Num Patches, Embed Dim].
        Raises:
            ValueError: If the number of features differs from ``num_blocks``,
                or if ``fusion_type`` was changed to an unsupported value.
        """
        features = list(features)
        # Without this check zip() below would silently drop surplus features
        # or weights, leaving the applied softmax weights not summing to one.
        if len(features) != self.num_blocks:
            raise ValueError(
                f"Expected {self.num_blocks} feature tensors, got {len(features)}"
            )

        if self.fusion_type == 'concat':
            # [Batch, Num Patches, Embed Dim * Num Blocks] -> [Batch, Num Patches, Embed Dim]
            fused = self.fc(torch.cat(features, dim=-1))
        elif self.fusion_type == 'weighted_sum':
            weights = torch.softmax(self.weights, dim=0)  # Normalize weights
            fused = sum(w * f for w, f in zip(weights, features))  # [Batch, Num Patches, Embed Dim]
        else:
            # Only reachable if the attribute was modified after construction.
            raise ValueError(f"Unsupported fusion_type: {self.fusion_type}")
        return fused


In [None]:
class StackBlocksWithFusion(nn.Module):
    """
    Run two Transformer blocks in sequence and fuse both of their outputs.

    ``block`` is called twice with the same keyword arguments. The output of
    each block is kept and handed to a ``FeatureFusion`` module configured
    with ``fusion_type``.
    """
    def __init__(self, block, embed_dim=768, num_heads=8, mlp_ratio=4., proj_drop=0.1,
                 attn_drop=0.1, drop_path=0.1, fusion_type='concat', act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        # One shared configuration; every call below builds a separate instance.
        block_kwargs = dict(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            proj_drop=proj_drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        self.block1 = block(**block_kwargs)
        self.block2 = block(**block_kwargs)
        # Combines the two block outputs into a single feature tensor.
        self.fusion = FeatureFusion(embed_dim, fusion_type=fusion_type, num_blocks=2)

    def forward(self, x):
        # block2 consumes block1's output; both outputs take part in the fusion.
        first = self.block1(x)
        second = self.block2(first)
        return self.fusion([first, second])


In [None]:
class KAN(nn.Module):
    """
    Two-layer feed-forward head: Linear -> activation -> Dropout -> Linear -> Dropout.

    NOTE(review): the class is named after Kolmogorov-Arnold Networks, but the
    layers built here are plain ``nn.Linear`` layers with a fixed activation.
    NOTE(review): the second dropout is applied to the final output, so it
    also acts on the logits when this module is used as a classifier.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.1):
        super().__init__()
        # A falsy width (None or 0) falls back to the input width.
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))


In [None]:
class MedicalImageClassifier(nn.Module):
    """
    Image classifier assembled from the building blocks defined in this notebook.

    Pipeline: PatchEmbedding -> StackBlocksWithFusion (built from ``Block``)
    -> mean over the patch dimension -> KAN head producing ``num_classes`` logits.

    ``act_layer`` and ``norm_layer`` are forwarded to the Transformer blocks
    only; the KAN head is built with its own default activation and dropout.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=2,
                 embed_dim=768, num_heads=8, mlp_ratio=4., proj_drop=0.1,
                 attn_drop=0.1, drop_path=0.1, fusion_type='concat', act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        # Image -> sequence of patch tokens.
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        # Two Transformer blocks whose outputs are fused.
        encoder_kwargs = dict(
            block=Block,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            proj_drop=proj_drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            fusion_type=fusion_type,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        self.stack_blocks = StackBlocksWithFusion(**encoder_kwargs)

        # Classification head mapping embed_dim -> num_classes.
        self.classifier = KAN(
            in_features=embed_dim,
            hidden_features=embed_dim,
            out_features=num_classes,
        )

    def forward(self, x):
        tokens = self.patch_embed(x)       # [Batch, Num Patches, Embed Dim]
        fused = self.stack_blocks(tokens)  # [Batch, Num Patches, Embed Dim]
        pooled = fused.mean(dim=1)         # average over patches -> [Batch, Embed Dim]
        return self.classifier(pooled)     # [Batch, Num Classes]


### Test the StackBlocks class

In [None]:
if __name__ == "__main__":
    # Smoke test: StackBlocks must keep the [batch, patches, embed_dim] shape.
    # Run on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the StackBlocks model
    stack_model = StackBlocks(
        block=Block,
        embed_dim=768,
        num_heads=8,
        mlp_ratio=4.,
        proj_drop=0.1,
        attn_drop=0.1,
        drop_path=0.1,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ).to(device)

    # Dummy input: Batch size = 8, Patches = 196 (14x14), Embedding dim = 768
    input_tensor = torch.randn(8, 196, 768).to(device)

    # Only the shape is inspected, so do not record an autograd graph.
    with torch.no_grad():
        output = stack_model(input_tensor)
    print("Output shape:", output.shape)

    # Enforce the expectation instead of leaving it in a comment.
    assert output.shape == (8, 196, 768), f"unexpected output shape {tuple(output.shape)}"


Output shape: torch.Size([8, 196, 768])


### Connect Patch Embedding to StackBlocks and test the full data flow

In [None]:
if __name__ == "__main__":
    # Smoke test: image batch -> PatchEmbedding -> StackBlocks.
    # Run on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Patch Embedding
    patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_chans=3, embed_dim=768).to(device)

    # StackBlocks
    stack_model = StackBlocks(
        block=Block,
        embed_dim=768,
        num_heads=8,
        mlp_ratio=4.,
        proj_drop=0.1,
        attn_drop=0.1,
        drop_path=0.1,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ).to(device)

    # Input: dummy batch of 8 RGB images, size 224x224
    input_image = torch.randn(8, 3, 224, 224).to(device)

    # Only shapes are inspected, so do not record an autograd graph.
    with torch.no_grad():
        # Forward pass through Patch Embedding
        patch_embeddings = patch_embed(input_image)  # [Batch, Num Patches, Embed Dim]
        # Forward pass through StackBlocks
        output = stack_model(patch_embeddings)
    print("Patch Embedding shape:", patch_embeddings.shape)
    print("StackBlocks output shape:", output.shape)

    # 224 / 16 = 14 patches per side -> 196 tokens of width 768.
    assert patch_embeddings.shape == (8, patch_embed.num_patches, 768), (
        f"unexpected patch embedding shape {tuple(patch_embeddings.shape)}"
    )
    assert output.shape == patch_embeddings.shape, (
        f"unexpected StackBlocks output shape {tuple(output.shape)}"
    )


Patch Embedding shape: torch.Size([8, 196, 768])
StackBlocks output shape: torch.Size([8, 196, 768])


### Test Feature Fusion Module

In [None]:
if __name__ == "__main__":
    # Smoke test: image batch -> PatchEmbedding -> StackBlocksWithFusion.
    # Run on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize Patch Embedding
    patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_chans=3, embed_dim=768).to(device)

    # Initialize StackBlocks with Fusion
    stack_model = StackBlocksWithFusion(
        block=Block,
        embed_dim=768,
        num_heads=8,
        mlp_ratio=4.,
        proj_drop=0.1,
        attn_drop=0.1,
        drop_path=0.1,
        fusion_type='concat',  # Options: 'concat' or 'weighted_sum'
    ).to(device)

    # Dummy batch of 8 RGB images, size 224x224
    input_image = torch.randn(8, 3, 224, 224).to(device)

    # Only shapes are inspected, so do not record an autograd graph.
    with torch.no_grad():
        # Forward pass through Patch Embedding
        patch_embeddings = patch_embed(input_image)  # [Batch, Num Patches, Embed Dim]
        # Forward pass through StackBlocks with Fusion
        fused_features = stack_model(patch_embeddings)
    print("Patch Embedding shape:", patch_embeddings.shape)
    print("Fused Features shape:", fused_features.shape)

    # Fusion must return [Batch, Num Patches, Embed Dim], same as its input.
    assert fused_features.shape == patch_embeddings.shape, (
        f"unexpected fused feature shape {tuple(fused_features.shape)}"
    )


Patch Embedding shape: torch.Size([8, 196, 768])
Fused Features shape: torch.Size([8, 196, 768])


### Test the whole model

In [None]:
if __name__ == "__main__":
    # Smoke test of the complete classifier on a dummy image batch.
    # Run on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the full model
    model = MedicalImageClassifier(
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=2,  # Binary classification (e.g., diseased vs healthy)
        embed_dim=768,
        num_heads=8,
        mlp_ratio=4.,
        proj_drop=0.1,
        attn_drop=0.1,
        drop_path=0.1,
        fusion_type='concat',  # Options: 'concat' or 'weighted_sum'
    ).to(device)

    # Dummy batch of 8 RGB images, size 224x224
    input_image = torch.randn(8, 3, 224, 224).to(device)

    # Only the shape is inspected, so do not record an autograd graph.
    with torch.no_grad():
        logits = model(input_image)
    print("Logits shape:", logits.shape)

    # One row of num_classes logits per image.
    assert logits.shape == (8, 2), f"unexpected logits shape {tuple(logits.shape)}"


Logits shape: torch.Size([8, 2])
