In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
import math

In [4]:
""""

class SelectiveScan(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        
        # Parameters for state space model
        self.A = nn.Parameter(torch.randn(d_model, d_state, d_state) / d_state)
        self.B = nn.Parameter(torch.randn(d_model, d_state) / math.sqrt(d_state))
        self.C = nn.Parameter(torch.randn(d_model, d_state) / math.sqrt(d_state))
        
        # Convolution for local context
        self.conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=d_conv,
            padding=d_conv-1,
            groups=d_model
        )
        
        # Delta and gamma parameters for selective scanning
        self.dt = nn.Parameter(torch.randn(d_model))
        self.gamma = nn.Parameter(torch.randn(d_model))

    def forward(self, x):
        B, L, D = x.shape
        
        # Local context processing
        x_conv = rearrange(x, 'b l d -> b d l')
        x_conv = self.conv(x_conv)[..., :L]
        x_conv = rearrange(x_conv, 'b d l -> b l d')
        
        # Initialize state
        h = torch.zeros(B, self.d_state, device=x.device)
        outputs = []
        
        # Selective scan
        for t in range(L):
            # Current input
            u = x[:, t]
            
            # Update state with selective scan
            delta = torch.sigmoid(self.dt)
            A_hat = torch.exp(self.A * delta.unsqueeze(-1).unsqueeze(-1))
            h = torch.einsum('bm,mdh->bh', u, A_hat) + h
            
            # Compute output
            y = torch.einsum('bh,mh->bm', h, self.C)
            y = y * torch.sigmoid(self.gamma).unsqueeze(0)
            
            outputs.append(y)
        
        outputs = torch.stack(outputs, dim=1)
        return outputs + x_conv

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.proj_in = nn.Linear(d_model, d_model * 2)
        self.selective_scan = SelectiveScan(d_model, d_state, d_conv)
        self.proj_out = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        residual = x
        x = self.norm(x)
        
        # Project and split into value and gating branches
        x = self.proj_in(x)
        v, g = x.chunk(2, dim=-1)
        
        # Apply selective scan to value branch
        v = self.selective_scan(v)
        
        # Apply gating
        x = v * torch.sigmoid(g)
        
        # Project out and add residual
        x = self.proj_out(x)
        return x + residual

class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.mha = nn.MultiheadAttention(channels, 8, batch_first=True)
        self.norm = nn.LayerNorm(channels)
        
    def forward(self, x):
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.norm(x)
        attn_out, _ = self.mha(x, x, x)
        x = x + attn_out
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        return x

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.mamba = MambaBlock(out_channels)
        self.attention = AttentionBlock(out_channels)
        self.fusion = nn.Conv2d(out_channels * 2, out_channels, 1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.GELU()
        
    def forward(self, x):
        x = self.conv(x)
        
        # Mamba path
        b, c, h, w = x.shape
        x_mamba = rearrange(x, 'b c h w -> b (h w) c')
        x_mamba = self.mamba(x_mamba)
        x_mamba = rearrange(x_mamba, 'b (h w) c -> b c h w', h=h, w=w)
        
        # Attention path
        x_attn = self.attention(x)
        
        # Fusion
        combined = torch.cat([x_mamba, x_attn], dim=1)
        out = self.fusion(combined)
        out = self.norm(out)
        out = self.act(out)
        return out

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2)
        self.conv = nn.Conv2d(out_channels * 2, out_channels, 3, padding=1)
        self.mamba = MambaBlock(out_channels)
        self.attention = AttentionBlock(out_channels)
        self.fusion = nn.Conv2d(out_channels * 2, out_channels, 1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.GELU()
        
    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        
        # Mamba path
        b, c, h, w = x.shape
        x_mamba = rearrange(x, 'b c h w -> b (h w) c')
        x_mamba = self.mamba(x_mamba)
        x_mamba = rearrange(x_mamba, 'b (h w) c -> b c h w', h=h, w=w)
        
        # Attention path
        x_attn = self.attention(x)
        
        # Fusion
        combined = torch.cat([x_mamba, x_attn], dim=1)
        out = self.fusion(combined)
        out = self.norm(out)
        out = self.act(out)
        return out

class MambaAttentionUNET(nn.Module):
    def __init__(self, in_channels, out_channels, features=[64, 128, 256, 512]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder path
        in_features = in_channels
        for feature in features:
            self.downs.append(DownBlock(in_features, feature))
            in_features = feature

        # Bottleneck
        self.bottleneck = DownBlock(features[-1], features[-1] * 2)

        # Decoder path
        for feature in reversed(features):
            self.ups.append(UpBlock(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Encoder
        down_out = x
        for down in self.downs:
            down_out = down(down_out)
            skip_connections.append(down_out)
            down_out = self.pool(down_out)

        # Bottleneck
        x = self.bottleneck(down_out)

        # Decoder
        skip_connections = skip_connections[::-1]
        for idx in range(len(self.ups)):
            skip = skip_connections[idx]
            x = self.ups[idx](x, skip)

        return self.final_conv(x)

class Trainer:
    def __init__(self, model, device='cuda'):
        self.model = model.to(device)
        self.device = device
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        
    def train_epoch(self, train_loader):
        self.model.train()
        total_loss = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)
            
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f'Batch: {batch_idx}, Loss: {loss.item():.6f}')
                
        return total_loss / len(train_loader)
    
    def validate(self, val_loader):
        self.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                total_loss += loss.item()
                
        return total_loss / len(val_loader)

"""

'"\n\nclass SelectiveScan(nn.Module):\n    def __init__(self, d_model, d_state=16, d_conv=4):\n        super().__init__()\n        self.d_model = d_model\n        self.d_state = d_state\n        self.d_conv = d_conv\n        \n        # Parameters for state space model\n        self.A = nn.Parameter(torch.randn(d_model, d_state, d_state) / d_state)\n        self.B = nn.Parameter(torch.randn(d_model, d_state) / math.sqrt(d_state))\n        self.C = nn.Parameter(torch.randn(d_model, d_state) / math.sqrt(d_state))\n        \n        # Convolution for local context\n        self.conv = nn.Conv1d(\n            in_channels=d_model,\n            out_channels=d_model,\n            kernel_size=d_conv,\n            padding=d_conv-1,\n            groups=d_model\n        )\n        \n        # Delta and gamma parameters for selective scanning\n        self.dt = nn.Parameter(torch.randn(d_model))\n        self.gamma = nn.Parameter(torch.randn(d_model))\n\n    def forward(self, x):\n        B, L, D 

In [5]:
"""def test_model():
    # Create sample input
    batch_size = 2
    channels = 3
    height = width = 256
    
    # Initialize model
    model = MambaAttentionUNET(in_channels=channels, out_channels=1)
    x = torch.randn(batch_size, channels, height, width)
    
    # Forward pass
    output = model(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    
    return model, output

if __name__ == "__main__":
    # Test the model
    print("Testing model...")
    model, output = test_model()
    print("Model test successful!")
    
    # Example of creating a training dataset
    print("\nCreating example dataset...")
    num_samples = 100
    input_data = torch.randn(num_samples, 3, 256, 256)
    target_data = torch.randn(num_samples, 1, 256, 256)
    
    # Create DataLoader
    from torch.utils.data import TensorDataset, DataLoader
    dataset = TensorDataset(input_data, target_data)
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
    
    # Initialize trainer
    print("\nInitializing trainer...")
    trainer = Trainer(model, device='cuda' if torch.cuda.is_available() else 'cpu')
    
    # Train for a few epochs
    print("\nStarting training...")
    num_epochs = 3
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        avg_loss = trainer.train_epoch(train_loader)
        print(f"Average loss: {avg_loss:.6f}")"""

'def test_model():\n    # Create sample input\n    batch_size = 2\n    channels = 3\n    height = width = 256\n    \n    # Initialize model\n    model = MambaAttentionUNET(in_channels=channels, out_channels=1)\n    x = torch.randn(batch_size, channels, height, width)\n    \n    # Forward pass\n    output = model(x)\n    \n    print(f"Input shape: {x.shape}")\n    print(f"Output shape: {output.shape}")\n    \n    return model, output\n\nif __name__ == "__main__":\n    # Test the model\n    print("Testing model...")\n    model, output = test_model()\n    print("Model test successful!")\n    \n    # Example of creating a training dataset\n    print("\nCreating example dataset...")\n    num_samples = 100\n    input_data = torch.randn(num_samples, 3, 256, 256)\n    target_data = torch.randn(num_samples, 1, 256, 256)\n    \n    # Create DataLoader\n    from torch.utils.data import TensorDataset, DataLoader\n    dataset = TensorDataset(input_data, target_data)\n    train_loader = DataLoader(d

In [None]:


class EfficientAttentionBlock(nn.Module):
    def __init__(self, channels, reduction_factor=8):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        reduced_dim = max(32, channels // reduction_factor)
        self.qkv = nn.Linear(channels, reduced_dim * 3)
        self.proj = nn.Linear(reduced_dim, channels)
        self.chunk_size = 256  # Process attention in chunks
        
    def forward(self, x):
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.norm(x)
        
        # Reduce sequence length for efficiency
        qkv = self.qkv(x)
        chunk_size = min(self.chunk_size, qkv.shape[1])
        
        outputs = []
        for chunk_idx in range(0, x.shape[1], chunk_size):
            chunk_end = min(chunk_idx + chunk_size, x.shape[1])
            qkv_chunk = qkv[:, chunk_idx:chunk_end]
            
            q, k, v = qkv_chunk.chunk(3, dim=-1)
            
            # Scaled dot-product attention
            attn = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))
            attn = F.softmax(attn, dim=-1)
            chunk_output = attn @ v
            outputs.append(chunk_output)
        
        output = torch.cat(outputs, dim=1)
        output = self.proj(output)
        output = output + x  # Residual connection
        
        output = rearrange(output, 'b (h w) c -> b c h w', h=h, w=w)
        return output

class SimplifiedMambaBlock(nn.Module):
    def __init__(self, channels, d_state=16):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.conv1 = nn.Conv1d(channels, channels, 3, padding=1, groups=channels)
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.activation = nn.GELU()
        
    def forward(self, x):
        b, c, h, w = x.shape
        residual = x
        
        # Reshape and normalize
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.norm(x)
        x = rearrange(x, 'b n c -> b c n')
        
        # Simplified SSM using convolutions
        x = self.conv1(x)
        x = self.activation(x)
        x = self.conv2(x)
        
        # Reshape back
        x = rearrange(x, 'b c n -> b n c')
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        
        return x + residual

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.mamba = SimplifiedMambaBlock(out_channels)
        self.attention = EfficientAttentionBlock(out_channels)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.GELU()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.mamba(x)
        x = self.attention(x)
        x = self.norm(x)
        x = self.act(x)
        return x

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2)
        self.conv = nn.Conv2d(out_channels * 2, out_channels, 3, padding=1)
        self.mamba = SimplifiedMambaBlock(out_channels)
        self.attention = EfficientAttentionBlock(out_channels)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.GELU()
        
    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        x = self.mamba(x)
        x = self.attention(x)
        x = self.norm(x)
        x = self.act(x)
        return x

class MambaAttentionUNET(nn.Module):
    def __init__(self, in_channels, out_channels, features=[32, 64, 128, 256]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part
        in_features = in_channels
        for feature in features:
            self.downs.append(DownBlock(in_features, feature))
            in_features = feature

        # Bottleneck
        self.bottleneck = DownBlock(features[-1], features[-1] * 2)

        # Up part
        for feature in reversed(features):
            self.ups.append(UpBlock(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Encoder
        down_out = x
        for down in self.downs:
            down_out = down(down_out)
            skip_connections.append(down_out)
            down_out = self.pool(down_out)

        # Bottleneck
        x = self.bottleneck(down_out)

        # Decoder
        skip_connections = skip_connections[::-1]
        for idx in range(len(self.ups)):
            skip = skip_connections[idx]
            x = self.ups[idx](x, skip)

        return self.final_conv(x)

def test_model(input_size=(64, 64)):
    # Create sample input with smaller dimensions
    batch_size = 2
    channels = 3
    height, width = input_size
    
    # Initialize model with reduced feature dimensions
    model = MambaAttentionUNET(
        in_channels=channels, 
        out_channels=1,
        features=[32, 64, 128, 256]  # Reduced feature dimensions
    )
    
    # Move model to GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Create input tensor
    x = torch.randn(batch_size, channels, height, width).to(device)
    
    # Forward pass
    with torch.cuda.amp.autocast() if torch.cuda.is_available() else torch.no_grad():
        output = model(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Model is on: {next(model.parameters()).device}")
    
    return model, output

if __name__ == "__main__":
    # Test with smaller input size first
    print("Testing model with small input...")
    model, output = test_model(input_size=(64, 64))
    print("Small input test successful!")
    
    # If successful, can try with larger input
    try:
        print("\nTesting model with medium input...")
        model, output = test_model(input_size=(128, 128))
        print("Medium input test successful!")
    except RuntimeError as e:
        print("Memory error with medium input, stick with smaller input size")
        print(f"Error: {e}")

Testing model with small input...
Input shape: torch.Size([2, 3, 64, 64])
Output shape: torch.Size([2, 1, 64, 64])
Model is on: cpu
Small input test successful!

Testing model with medium input...
Input shape: torch.Size([2, 3, 128, 128])
Output shape: torch.Size([2, 1, 128, 128])
Model is on: cpu
Medium input test successful!


In [7]:
import torch
from torchsummary import summary
from torchviz import make_dot

# Function to summarize the model
def model_summary(model, input_size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    summary(model, input_size)

# Function to visualize the model
def visualize_model(model, input_tensor):
    model.eval()
    output = model(input_tensor)
    dot = make_dot(output, params=dict(model.named_parameters()))
    return dot


In [8]:

# Define input size
input_size = (3, 64, 64)  # Example: (channels, height, width)

# Create the model
model = MambaAttentionUNET(
    in_channels=input_size[0], 
    out_channels=1, 
    features=[32, 64, 128, 256]
)

# Display the model summary
print("Model Summary:")
model_summary(model, input_size)

# Create a dummy input tensor for visualization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_tensor = torch.randn(1, *input_size).to(device)


Model Summary:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 64, 64]             896
         LayerNorm-2             [-1, 4096, 32]              64
            Conv1d-3             [-1, 32, 4096]             128
              GELU-4             [-1, 32, 4096]               0
            Conv1d-5             [-1, 32, 4096]           1,056
SimplifiedMambaBlock-6           [-1, 32, 64, 64]               0
         LayerNorm-7             [-1, 4096, 32]              64
            Linear-8             [-1, 4096, 96]           3,168
            Linear-9             [-1, 4096, 32]           1,056
EfficientAttentionBlock-10           [-1, 32, 64, 64]               0
      BatchNorm2d-11           [-1, 32, 64, 64]              64
             GELU-12           [-1, 32, 64, 64]               0
        DownBlock-13           [-1, 32, 64, 64]               0
        MaxPool2

In [9]:

# Generate the visualization graph
print("\nGenerating model visualization...")
dot = visualize_model(model, input_tensor)

# Save the visualization as a file (e.g., PNG)
dot.format = "png"
dot.render("mamba_attention_unet_graph")
print("Model visualization saved as 'mamba_attention_unet_graph.png'")



Generating model visualization...


ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH