### Multi-Step Segmentation Model
This notebook demonstrates the `MultiStepSegmentationModel` defined in `multistep_model.py`.
Make sure `multistep_model.py` is in the same directory as this notebook or on your Python path.

#### 1. Model Instantiation
First, we import the model and instantiate it with the necessary parameters.

In [1]:
from multistep_model import MultiStepSegmentationModel
import torch

# Define model parameters (matching the example in multistep_model.py for consistency)
img_channels = 1        # e.g., 1 for grayscale MRIs
init_mask_channels = 1  # e.g., 1 for a binary initial estimate mask
n_classes = 1           # e.g., 1 for binary segmentation (foreground vs. background)
cnn_features = 4        # Hidden features in the BasePatchCNN's CNNBlock
D_orig, H_orig, W_orig = 64, 128, 128  # Original dimensions (Depth, Height, Width)

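# Patch dimensions are one quarter of the original size along each axis (D/4, H/4, W/4);
# the integer division assumes D, H, and W are multiples of 4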
p_d, p_h, p_w = D_orig//4, H_orig//4, W_orig//4

print("Instantiating MultiStepSegmentationModel with:")
print(f"  img_channels={img_channels}, init_mask_channels={init_mask_channels}, n_classes={n_classes}")
print(f"  cnn_features={cnn_features}, D_orig={D_orig}, H_orig={H_orig}, W_orig={W_orig}")
print(f"  patch_d={p_d}, patch_h={p_h}, patch_w={p_w}")

try:
    model = MultiStepSegmentationModel(
        image_channels=img_channels, 
        initial_mask_channels=init_mask_channels,
        num_classes=n_classes,
        base_cnn_hidden_features=cnn_features,
        patch_size_d=p_d, 
        patch_size_h=p_h, 
        patch_size_w=p_w,
        verbose=False  # Set to True for detailed print output from the model, False for cleaner notebook
    )
    print("[SUCCESS] Model instantiated successfully.")
except TypeError as e:
    print(f"[ERROR] TypeError during model instantiation: {e}")
    print("This likely means the __init__ method in multistep_model.py does not match the parameters being passed.")
    print("Please ensure multistep_model.py is saved with the correct MultiStepSegmentationModel class definition and RESTART THE KERNEL.")
except Exception as e:
    print(f"[ERROR] An unexpected error occurred during model instantiation: {e}")

Instantiating MultiStepSegmentationModel with:
  img_channels=1, init_mask_channels=1, n_classes=1
  cnn_features=4, D_orig=64, H_orig=128, W_orig=128
  patch_d=16, patch_h=32, patch_w=32
[BasePatchCNN init] In channels: 3, Hidden features: 4, Num classes: 1
[CNNBlock init] In: 3, Out: 4
[SUCCESS] Model instantiated successfully.
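The patch sizes above come from integer division, so the arithmetic only works cleanly when each volume dimension is a multiple of the patch size. As a quick sanity check, the following sketch counts how many patches tile a volume along each axis. `patch_grid` is a hypothetical helper written for this notebook, not part of `multistep_model.py`.

```python
from math import prod

def patch_grid(volume_dims, patch_dims):
    """Return the number of patches along each axis (hypothetical helper)."""
    counts = []
    for axis, (v, p) in enumerate(zip(volume_dims, patch_dims)):
        # Reject sizes that would leave a partial patch at the border
        if p <= 0 or v % p != 0:
            raise ValueError(f"axis {axis}: size {v} is not a multiple of patch size {p}")
        counts.append(v // p)
    return tuple(counts)

patch_dims = (64 // 4, 128 // 4, 128 // 4)        # (16, 32, 32)
full_grid = patch_grid((64, 128, 128), patch_dims)  # full-resolution volume
half_grid = patch_grid((32, 64, 64), patch_dims)    # half-resolution volume
print(full_grid, prod(full_grid))  # (4, 4, 4) 64
print(half_grid, prod(half_grid))  # (2, 2, 2) 8
```

The half-resolution case matches the `D_n=2, H_n=2, W_n=2` grid that `tensor_to_patches` reports during the forward pass.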


#### 2. Test Forward Pass
Next, we run a forward pass on dummy tensors of shape (batch, channels, D, H, W) to check that the model executes end to end.

In [2]:
if 'model' in globals():
    print("\nTesting forward pass...")
    batch_size = 1
    dummy_image = torch.randn(batch_size, img_channels, D_orig, H_orig, W_orig)
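    # Random binary mask, consistent with the "binary initial estimate mask" described above
    dummy_initial_mask = (torch.rand(batch_size, init_mask_channels, D_orig, H_orig, W_orig) > 0.5).float()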
    
    try:
        # The model returns three sets of logits
        logits_coarse, logits_refine1, logits_final = model(dummy_image, dummy_initial_mask)
        print("[SUCCESS] Dummy forward pass successful.")
        print(f"  Coarse logits shape: {logits_coarse.shape}")
        print(f"  Refine1 logits shape: {logits_refine1.shape}")
        print(f"  Final logits shape: {logits_final.shape}")
    except Exception as e:
        print(f"[ERROR] Error during dummy forward pass: {e}")
else:
    print("Model not instantiated. Please run the previous cell successfully first.")


Testing forward pass...
 [BasePatchCNN fwd] Input patch shape: torch.Size([1, 3, 16, 32, 32])
  [CNNBlock fwd] Input shape: torch.Size([1, 3, 16, 32, 32])
  [CNNBlock fwd] Output shape: torch.Size([1, 4, 16, 32, 32])
 [BasePatchCNN fwd] Features after CNNBlock shape: torch.Size([1, 4, 16, 32, 32])
 [BasePatchCNN fwd] Output logits patch shape: torch.Size([1, 1, 16, 32, 32])
[tensor_to_patches] Input tensor shape: torch.Size([1, 2, 32, 64, 64]), Target patch dims: D=16, H=32, W=32
[tensor_to_patches] Num patches: D_n=2, H_n=2, W_n=2
[tensor_to_patches] Output patches shape: torch.Size([8, 2, 16, 32, 32])
[tensor_to_patches] Input tensor shape: torch.Size([1, 1, 32, 64, 64]), Target patch dims: D=16, H=32, W=32
[tensor_to_patches] Num patches: D_n=2, H_n=2, W_n=2
[tensor_to_patches] Output patches shape: torch.Size([8, 1, 16, 32, 32])
 [BasePatchCNN fwd] Input patch shape: torch.Size([8, 3, 16, 32, 32])
  [CNNBlock fwd] Input shape: torch.Size([8, 3, 16, 32, 32])
  [CNNBlock fwd] Output
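The `tensor_to_patches` log lines above show a `[1, 2, 32, 64, 64]` tensor becoming `[8, 2, 16, 32, 32]`, which means the patch grid is folded into the batch dimension. The sketch below reproduces that shape transformation with a reshape and a permute. It is a hypothetical re-implementation for illustration only; the actual function in `multistep_model.py` may be written differently or order the patches differently.

```python
import torch

def tensor_to_patches_sketch(x, pd, ph, pw):
    """Split (B, C, D, H, W) into non-overlapping patches of shape (C, pd, ph, pw).

    Hypothetical helper; not the implementation from multistep_model.py.
    """
    b, c, d, h, w = x.shape
    if d % pd or h % ph or w % pw:
        raise ValueError("volume dimensions must be multiples of the patch size")
    nd, nh, nw = d // pd, h // ph, w // pw
    # Expose the patch grid as separate axes: (B, C, nd, pd, nh, ph, nw, pw)
    x = x.reshape(b, c, nd, pd, nh, ph, nw, pw)
    # Move the grid axes next to the batch axis: (B, nd, nh, nw, C, pd, ph, pw)
    x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)
    # Fold the grid into the batch dimension: (B * nd * nh * nw, C, pd, ph, pw)
    return x.reshape(b * nd * nh * nw, c, pd, ph, pw)

volume = torch.randn(1, 2, 32, 64, 64)
patches = tensor_to_patches_sketch(volume, 16, 32, 32)
print(patches.shape)  # torch.Size([8, 2, 16, 32, 32])
```

With this ordering, patch 0 is the corner block `volume[0, :, :16, :32, :32]` and the last patch is the opposite corner.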

#### 3. Install Visualization Libraries
This cell installs the `torchviz` and `graphviz` Python packages with pip if they are not already present.

In [3]:
# Ensure necessary libraries are installed for visualization
print("\nInstalling/checking torchviz and graphviz...")
# %pip installs into the environment of the running kernel
%pip install torchviz graphviz


Installing/checking torchviz and graphviz...


#### 4. Model Architecture Visualization
This cell attempts to visualize the model architecture using `torchviz`.
The `graphviz` pip package is only a Python interface, so the Graphviz system binaries must also be installed (e.g., download them from graphviz.org and add them to your PATH).
If `dot.render` fails, the most common cause is that the Graphviz executables are not on your system PATH.
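Before running the visualization cell, it can help to confirm that the Graphviz `dot` executable is actually reachable. A minimal sketch using only the standard library follows; `graphviz_dot_path` is a hypothetical helper name.

```python
import shutil

def graphviz_dot_path():
    """Return the full path to the Graphviz `dot` executable, or None if it is not on PATH."""
    return shutil.which("dot")

dot_path = graphviz_dot_path()
if dot_path is None:
    print("Graphviz 'dot' was not found on PATH; rendering the graph will likely fail.")
else:
    print(f"Found Graphviz 'dot' at: {dot_path}")
```

If the check reports that `dot` is missing right after installing Graphviz, the kernel may need a restart to pick up the updated PATH.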

In [4]:
if 'model' in globals():
    print("\nAttempting model visualization...")
    try:
        from torchviz import make_dot
        
        if not hasattr(model, 'verbose'):
            print("[ERROR] The 'model' object is missing the 'verbose' attribute.")
            print("This indicates an issue with model instantiation. Ensure multistep_model.py is correct and the kernel was restarted after changes.")
        else:
            # Create dummy inputs again for tracing
            vis_dummy_image = torch.randn(1, img_channels, D_orig, H_orig, W_orig, requires_grad=True)
            vis_dummy_initial_mask = torch.randn(1, init_mask_channels, D_orig, H_orig, W_orig, requires_grad=True)
            
            # Temporarily disable the model's internal prints for a cleaner trace;
            # try/finally restores the flag even if the forward pass raises
            original_verbose_state = model.verbose
            model.verbose = False
            try:
                _, _, vis_output_final = model(vis_dummy_image, vis_dummy_initial_mask)
            finally:
                model.verbose = original_verbose_state  # Restore
            
            params = {name: p for name, p in model.named_parameters() if p.requires_grad}
            
            print("Generating model graph... (this may take a moment)")
            dot = make_dot(vis_output_final, params=params)
            
            file_path = "multistep_segmentation_model_architecture"
            dot.render(file_path, format="png")
            print(f"[SUCCESS] Model architecture graph saved to '{file_path}.png'.")
            print("You can open this file to view the graph.")
            
    except ImportError:
        print("[ERROR] torchviz is not installed. Please run the pip install cell first.")
    except NameError as e:
        print(f"[ERROR] A variable might not be defined (e.g., model, img_channels): {e}. Ensure previous cells ran correctly.")
    except AttributeError as e:
        print(f"[ERROR] AttributeError: {e}. This often means the model object is not what is expected or is missing an attribute.")
    except Exception as e:
        print(f"[ERROR] An unexpected error occurred during visualization: {e}")
        print("This could be due to Graphviz not being installed or not found in the system PATH.")
else:
    print("Model not instantiated. Please run the model instantiation cell successfully first before attempting visualization.")


Attempting model visualization...
 [BasePatchCNN fwd] Input patch shape: torch.Size([1, 3, 16, 32, 32])
  [CNNBlock fwd] Input shape: torch.Size([1, 3, 16, 32, 32])
  [CNNBlock fwd] Output shape: torch.Size([1, 4, 16, 32, 32])
 [BasePatchCNN fwd] Features after CNNBlock shape: torch.Size([1, 4, 16, 32, 32])
 [BasePatchCNN fwd] Output logits patch shape: torch.Size([1, 1, 16, 32, 32])
[tensor_to_patches] Input tensor shape: torch.Size([1, 2, 32, 64, 64]), Target patch dims: D=16, H=32, W=32
[tensor_to_patches] Num patches: D_n=2, H_n=2, W_n=2
[tensor_to_patches] Output patches shape: torch.Size([8, 2, 16, 32, 32])
[tensor_to_patches] Input tensor shape: torch.Size([1, 1, 32, 64, 64]), Target patch dims: D=16, H=32, W=32
[tensor_to_patches] Num patches: D_n=2, H_n=2, W_n=2
[tensor_to_patches] Output patches shape: torch.Size([8, 1, 16, 32, 32])
 [BasePatchCNN fwd] Input patch shape: torch.Size([8, 3, 16, 32, 32])
  [CNNBlock fwd] Input shape: torch.Size([8, 3, 16, 32, 32])
  [CNNBlock f