# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model

def inception_block_3d(x, filters):
    """Multi-scale 3D Inception module.

    Runs four parallel branches over ``x`` and concatenates them along the
    channel axis:
      * a 1x1x1 convolution,
      * a 1x1x1 reduction followed by one 3x3x3 convolution,
      * a 1x1x1 reduction followed by two stacked 3x3x3 convolutions
        (a 5x5x5 receptive field at lower cost),
      * a stride-1 3x3x3 max-pool followed by a 1x1x1 projection.

    Every convolution uses ``filters`` output channels, 'same' padding and
    ReLU, so spatial size is preserved and the output has ``4 * filters``
    channels.
    """
    def conv(tensor, size):
        # Cubic kernel of edge `size`; shared settings for every branch conv.
        return layers.Conv3D(filters, (size,) * 3, padding='same',
                             activation='relu')(tensor)

    # Inner calls run first, so layers are created in the same order as a
    # branch-by-branch, top-to-bottom construction.
    pointwise = conv(x, 1)
    single_3x3 = conv(conv(x, 1), 3)
    double_3x3 = conv(conv(conv(x, 1), 3), 3)

    pooled = layers.MaxPooling3D((3, 3, 3), strides=(1, 1, 1), padding='same')(x)
    pooled = conv(pooled, 1)

    return layers.Concatenate()([pointwise, single_3x3, double_3x3, pooled])

def residual_block_3d(x, filters):
    """Two-convolution 3D residual unit with an optional projection shortcut.

    Main path: Conv3D(3x3x3, ReLU) -> BatchNorm -> Conv3D(3x3x3, linear)
    -> BatchNorm. When the input's channel count differs from ``filters``,
    the identity is projected with a 1x1x1 Conv3D + BatchNorm so the two
    paths can be summed. The sum is passed through a final ReLU.
    """
    identity = x

    out = x
    # First conv carries the ReLU; the second stays linear until after the add.
    for conv_activation in ('relu', None):
        out = layers.Conv3D(filters, (3, 3, 3), padding='same',
                            activation=conv_activation)(out)
        out = layers.BatchNormalization()(out)

    if identity.shape[-1] != filters:
        identity = layers.Conv3D(filters, (1, 1, 1), padding='same')(identity)
        identity = layers.BatchNormalization()(identity)

    merged = layers.Add()([out, identity])
    return layers.Activation('relu')(merged)

def squeeze_excitation_block_3d(x, ratio=16):
    """3D Squeeze-and-Excitation block for channel attention.

    Squeezes ``x`` to one value per channel with global average pooling,
    passes that through a bottleneck MLP (ReLU, then sigmoid), and rescales
    each channel of ``x`` by the resulting gate.

    Args:
        x: 5D tensor whose last axis is channels (statically known).
        ratio: Channel reduction factor for the bottleneck layer.

    Returns:
        Tensor with the same shape as ``x``, channel-wise reweighted.
    """
    channels = x.shape[-1]
    # With fewer channels than `ratio`, channels // ratio would be 0: a
    # zero-unit Dense is either rejected or yields a gate that ignores the
    # input. Keep at least one bottleneck unit.
    bottleneck_units = max(1, channels // ratio)

    # Squeeze: global average pooling -> (batch, channels)
    se = layers.GlobalAveragePooling3D()(x)

    # Excitation: bottleneck FC layers producing per-channel gates in (0, 1)
    se = layers.Dense(bottleneck_units, activation='relu')(se)
    se = layers.Dense(channels, activation='sigmoid')(se)

    # Reshape to broadcast over the three spatial axes, then scale
    se = layers.Reshape((1, 1, 1, channels))(se)
    return layers.Multiply()([x, se])

def _augment_volume_3d(volume):
    """Randomly flip, rotate, zoom and contrast-adjust a 5D volume tensor.

    RandomFlip, RandomRotation, RandomZoom and RandomContrast are 2D image
    layers that expect (batch, height, width, channels) inputs; they cannot
    consume the 5D (batch, depth, height, width, channels) tensors this
    network uses. The volume is therefore viewed as a 2D image whose channel
    axis stacks the last spatial axis with the original channels, augmented,
    and reshaped back to its original 5D shape. Each 2D layer applies one
    spatial transform to all channels of an image, so the flip and rotation
    act rigidly on the whole volume (in the depth/height plane).

    As with the underlying layers, the random transforms are meant to be
    active only in training; otherwise the two reshapes cancel out.
    """
    depth, height, width, channels = volume.shape[1:]
    # Fold the last spatial axis into channels: (B, D, H, W, C) -> (B, D, H, W*C)
    x = layers.Reshape((depth, height, width * channels))(volume)
    x = layers.RandomFlip("horizontal")(x)
    x = layers.RandomRotation(0.1, fill_mode='nearest')(x)
    x = layers.RandomZoom(0.05, fill_mode='nearest')(x)
    x = layers.RandomContrast(0.1)(x)
    # Restore the 5D layout expected by the Conv3D stack
    return layers.Reshape((depth, height, width, channels))(x)


def MyNet3D(num_classes=2):
    """
    Build and compile a 3D CNN for 28x28x28 single-channel volumes
    (intended for VesselMNIST3D).

    This architecture adapts the 2D hybrid design for 3D medical imaging:
    - Inception modules for multi-scale feature extraction
    - Residual connections for deep learning
    - Squeeze-and-Excitation for channel attention
    - An additive skip connection from stage 1 into stage 2
    - A spatial-attention gate combined with a 1x1x1 convolution path

    Args:
        num_classes: Number of output classes (check your dataset).

    Returns:
        A Keras ``Model`` mapping inputs of shape (batch, 28, 28, 28, 1) to
        softmax probabilities of shape (batch, num_classes), compiled with
        Adam (learning rate 0.001), sparse categorical cross-entropy (so
        labels are integer class ids) and accuracy.
    """

    inputs = layers.Input(shape=(28, 28, 28, 1))

    # Data augmentation. The Keras Random* layers are 2D-only, so they are
    # applied to a 4D view of the volume (see _augment_volume_3d).
    x = _augment_volume_3d(inputs)

    # Initial feature extraction - reduced filters for 3D
    x = layers.Conv3D(16, (3, 3, 3), padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)

    # Stage 1: Multi-scale feature extraction at 28x28x28 resolution
    inception1 = inception_block_3d(x, 8)
    inception1 = layers.BatchNormalization()(inception1)
    inception1 = layers.Dropout(0.2)(inception1)

    residual1 = residual_block_3d(inception1, 32)
    se1 = squeeze_excitation_block_3d(residual1)

    # Reduce spatial dimensions to 14x14x14
    x = layers.MaxPooling3D((2, 2, 2))(se1)

    # Stage 2: Deeper feature learning at 14x14x14 resolution
    inception2 = inception_block_3d(x, 12)
    inception2 = layers.BatchNormalization()(inception2)
    inception2 = layers.Dropout(0.3)(inception2)

    residual2 = residual_block_3d(inception2, 48)
    se2 = squeeze_excitation_block_3d(residual2)

    # Cross-stage skip: pool stage-1 features to 14^3, project them to 48
    # channels, and add them to the stage-2 output.
    se1_pooled = layers.MaxPooling3D((2, 2, 2))(se1)
    se1_adjusted = layers.Conv3D(48, (1, 1, 1), padding='same')(se1_pooled)
    stage2_merged = layers.Add()([se2, se1_adjusted])

    # Reduce spatial dimensions to 7x7x7
    x = layers.MaxPooling3D((2, 2, 2))(stage2_merged)

    # Stage 3: Deep feature processing at 7x7x7 resolution
    residual3a = residual_block_3d(x, 64)
    residual3a = layers.Dropout(0.35)(residual3a)

    residual3b = residual_block_3d(residual3a, 64)
    se3 = squeeze_excitation_block_3d(residual3b)

    # Path A: spatial attention - a single-channel sigmoid map gates every voxel
    spatial_attention = layers.Conv3D(1, (7, 7, 7), padding='same', activation='sigmoid')(se3)
    spatial_features = layers.Multiply()([se3, spatial_attention])

    # Path B: channel-wise (1x1x1) transformation of the same features
    channel_features = layers.Conv3D(64, (1, 1, 1), activation='relu')(se3)

    # Combine both paths and fuse to 128 channels
    x = layers.Concatenate()([spatial_features, channel_features])
    x = layers.Conv3D(128, (1, 1, 1), activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.4)(x)

    # Final feature compression
    x = layers.Conv3D(96, (3, 3, 3), padding='same', activation='relu')(x)
    x = layers.BatchNormalization()(x)

    # Global average + max pooling, concatenated into one feature vector
    gap = layers.GlobalAveragePooling3D()(x)
    gmp = layers.GlobalMaxPooling3D()(x)
    x = layers.Concatenate()([gap, gmp])

    # Classification head
    x = layers.Dense(192, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)

    x = layers.Dense(96, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.4)(x)

    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)

    # Softmax outputs + sparse CE: labels must be integer class indices
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

# Example usage:
# model = MyNet3D(num_classes=2)  # Adjust num_classes based on VesselMNIST3D
# model.summary()