# Packages and Imports

In [None]:
# Install necessary packages
!pip install jupyter-dash
!pip install dash
!pip install oct-vol
!pip install scikit-image

In [None]:
import numpy as np
import plotly.graph_objs as go
from jupyter_dash import JupyterDash
from dash import Dash, html, dcc
from dash.dependencies import Input, Output, State
import dash
from skimage import exposure
from OCTVol import OCTVol

# Load Model

In [None]:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet_focused(nn.Module):
    """Five-level U-Net with a segmentation head and a per-column boundary head.

    The encoder has five conv-BN-ReLU-dropout-conv-BN-ReLU stages (32 to 512
    channels), each followed by a 2x2 max-pool; a 1024-channel bottleneck
    follows the fifth pool.  The decoder mirrors the encoder with transposed
    convolutions and skip connections (no normalisation in the decoder).

    Attribute names and their creation order are part of the checkpoint
    format (``state_dict`` keys), so they must not be renamed.
    """

    def __init__(self, in_channels=2, out_channels=1):
        """Create all layers.

        Args:
            in_channels: Channels of the input tensor.  The notebook feeds two
                channels: the image patch and a reference-boundary mask.
            out_channels: Channels of the segmentation output.
        """
        super().__init__()

        # ------------- ENCODER -------------
        # Level 1
        self.enc_conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.enc_bn1_1 = nn.BatchNorm2d(32)
        self.dropout1 = nn.Dropout(0.4)
        self.enc_conv1_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.enc_bn1_2 = nn.BatchNorm2d(32)

        # Level 2
        self.enc_conv2_1 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.enc_bn2_1 = nn.BatchNorm2d(64)
        self.dropout2 = nn.Dropout(0.4)
        self.enc_conv2_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.enc_bn2_2 = nn.BatchNorm2d(64)

        # Level 3
        self.enc_conv3_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.enc_bn3_1 = nn.BatchNorm2d(128)
        self.dropout3 = nn.Dropout(0.4)
        self.enc_conv3_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.enc_bn3_2 = nn.BatchNorm2d(128)

        # Level 4
        self.enc_conv4_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.enc_bn4_1 = nn.BatchNorm2d(256)
        self.dropout4 = nn.Dropout(0.4)
        self.enc_conv4_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.enc_bn4_2 = nn.BatchNorm2d(256)

        # Level 5
        self.enc_conv5_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.enc_bn5_1 = nn.BatchNorm2d(512)
        self.dropout5 = nn.Dropout(0.4)
        self.enc_conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.enc_bn5_2 = nn.BatchNorm2d(512)

        # "Bottom" block after the fifth pool
        self.bottom_conv1 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.bottom_conv2 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)

        # One parameter-free pooling layer shared by all encoder levels.
        self.pool = nn.MaxPool2d(2)

        # ------------- DECODER -------------
        # Up block 5 -> merges skip from x5
        self.upconv5 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec5_conv1 = nn.Conv2d(512 + 512, 512, kernel_size=3, padding=1)
        self.dec5_conv2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        # Up block 4 -> merges skip from x4
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec4_conv1 = nn.Conv2d(256 + 256, 256, kernel_size=3, padding=1)
        self.dec4_conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # Up block 3 -> merges skip from x3
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3_conv1 = nn.Conv2d(128 + 128, 128, kernel_size=3, padding=1)
        self.dec3_conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Up block 2 -> merges skip from x2
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2_conv1 = nn.Conv2d(64 + 64, 64, kernel_size=3, padding=1)
        self.dec2_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Up block 1 -> merges skip from x1
        self.upconv1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1_conv1 = nn.Conv2d(32 + 32, 32, kernel_size=3, padding=1)
        self.dec1_conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)

        # Final output heads (both are 1x1 convolutions on the last decoder map)
        self.final_conv_seg = nn.Conv2d(32, out_channels, kernel_size=1)  # Segmentation output
        self.final_conv_boundary = nn.Conv2d(32, 1, kernel_size=1)       # Boundary regression output

    @staticmethod
    def _encode(x, conv_a, bn_a, dropout, conv_b, bn_b):
        """Apply one encoder stage: conv-BN-ReLU-dropout followed by conv-BN-ReLU."""
        x = dropout(F.relu(bn_a(conv_a(x))))
        return F.relu(bn_b(conv_b(x)))

    @staticmethod
    def _decode(x, skip, upconv, conv_a, conv_b):
        """Upsample ``x``, align it with ``skip``, concatenate and apply two conv-ReLU layers."""
        up = upconv(x)
        if up.shape[2:] != skip.shape[2:]:
            # Max-pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip map; zero-pad on the right/bottom to match it.
            diff_y = skip.shape[2] - up.shape[2]
            diff_x = skip.shape[3] - up.shape[3]
            up = F.pad(up, (0, diff_x, 0, diff_y))
        merged = torch.cat([up, skip], dim=1)
        return F.relu(conv_b(F.relu(conv_a(merged))))

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape (B, in_channels, H, W).  H and W must be at
                least 32 so that five 2x2 poolings leave a non-empty map; the
                notebook uses patches of shape (B, 2, 32, 192).

        Returns:
            Tuple ``(logits_seg, logits_boundary)``:
            ``logits_seg`` has shape (B, out_channels, H, W);
            ``logits_boundary`` has shape (B, 1, 1, W), the boundary head
            averaged over the height axis (one value per image column).
        """
        x1 = self._encode(x, self.enc_conv1_1, self.enc_bn1_1, self.dropout1,
                          self.enc_conv1_2, self.enc_bn1_2)
        x2 = self._encode(self.pool(x1), self.enc_conv2_1, self.enc_bn2_1, self.dropout2,
                          self.enc_conv2_2, self.enc_bn2_2)
        x3 = self._encode(self.pool(x2), self.enc_conv3_1, self.enc_bn3_1, self.dropout3,
                          self.enc_conv3_2, self.enc_bn3_2)
        x4 = self._encode(self.pool(x3), self.enc_conv4_1, self.enc_bn4_1, self.dropout4,
                          self.enc_conv4_2, self.enc_bn4_2)
        x5 = self._encode(self.pool(x4), self.enc_conv5_1, self.enc_bn5_1, self.dropout5,
                          self.enc_conv5_2, self.enc_bn5_2)

        # Bottleneck (no batch norm / dropout here).
        xB = F.relu(self.bottom_conv1(self.pool(x5)))
        xB = F.relu(self.bottom_conv2(xB))

        d5 = self._decode(xB, x5, self.upconv5, self.dec5_conv1, self.dec5_conv2)
        d4 = self._decode(d5, x4, self.upconv4, self.dec4_conv1, self.dec4_conv2)
        d3 = self._decode(d4, x3, self.upconv3, self.dec3_conv1, self.dec3_conv2)
        d2 = self._decode(d3, x2, self.upconv2, self.dec2_conv1, self.dec2_conv2)
        d1 = self._decode(d2, x1, self.upconv1, self.dec1_conv1, self.dec1_conv2)

        # FINAL OUTPUTS
        logits_seg = self.final_conv_seg(d1)            # (B, out_channels, H, W)
        logits_boundary = self.final_conv_boundary(d1)  # (B, 1, H, W)
        logits_boundary = torch.mean(logits_boundary, dim=2, keepdim=True)  # (B, 1, 1, W)
        return logits_seg, logits_boundary

# Load the model
# Prefer the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = UNet_focused().to(device)

model_dir = "/content/drive/MyDrive/Colab Notebooks/Models" ####### Update File path
weights_fn = "UNet_focused.pth"     ####### Update File path
weights_path = os.path.join(model_dir, weights_fn)
state_dict = torch.load(weights_path, map_location=device)

# Some checkpoints prefix every key with "module." (presumably saved from a
# wrapped model - verify against the training script).  Strip the first
# occurrence of that prefix so the keys match this un-wrapped model.
first_key = next(iter(state_dict))
if first_key.startswith("module."):
    from collections import OrderedDict
    state_dict = OrderedDict(
        (name.replace("module.", "", 1), tensor)
        for name, tensor in state_dict.items()
    )

# strict=True: fail loudly if any key is missing or unexpected.
model.load_state_dict(state_dict, strict=True)
model.eval()

print("✅  Weights loaded from", weights_fn, "and model set to eval mode.")



# Pipeline on original data

## Editor

In [None]:

import os, glob, pickle, re
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import exposure


from jupyter_dash import JupyterDash
import dash
from dash import dcc, html, Output, Input, State
import plotly.graph_objs as go


model.eval()

# Results of the editing session; the Dash callback fills these in when the
# user updates a boundary / presses "Finish Editing".
final_corrected_boundary = None
final_selected_boundary_key = None
corrected_x_range = None
edited_b_scan_index = 24  # The B-scan to edit


# Load the .vol file
# (a module-level `global` statement is a no-op, so none is needed here:
# assigning at cell level already creates the global `editor_volume`.)
vol_file_path = "/content/drive/MyDrive/Colab Notebooks/Zip Files/Data 1/ZendTo-3C8kBGYWAMNWKpqn/EYE00270_E_4485_cropped_segmented_corrected.vol" ####### Update File path
editor_volume = OCTVol(vol_file_path)
epsilon = 1e-6  # keeps log10 finite for zero-valued pixels
b_scan_data = editor_volume.b_scans[:, :, edited_b_scan_index]  # (size_z, size_x)

# Process the B-scan image: log scale -> min-max normalise -> adaptive
# histogram equalisation.
b_scan_data_log = 20 * np.log10(b_scan_data + epsilon)
b_scan_data_log_norm = (b_scan_data_log - np.min(b_scan_data_log)) / (np.max(b_scan_data_log) - np.min(b_scan_data_log))
b_scan_data_processed = exposure.equalize_adapthist(b_scan_data_log_norm, clip_limit=0.03)

# Get segmentation boundaries from the header.
num_boundaries = int(editor_volume.b_scan_header['num_seg'][edited_b_scan_index])
size_x = editor_volume.header['size_x']
size_z = editor_volume.header['size_z']

x = np.arange(size_x)
boundaries = {}
# Sentinel the file uses for "no boundary at this column" (float32 max);
# replaced by NaN so those columns are simply not drawn.
invalid_value = np.float32(3.4028235e+38)
for i_boundary in range(num_boundaries):
    boundary_key = f'boundary_{i_boundary + 1}'
    boundary_data = editor_volume.b_scan_header[boundary_key][edited_b_scan_index, :]
    boundary_data = np.where(boundary_data == invalid_value, np.nan, boundary_data)
    boundaries[boundary_key] = boundary_data

# Initialise dictionaries for modifications and history.
# Each array is copied: a plain dict.copy() would share the arrays with
# `boundaries`, so in-place edits would silently change the "original"
# data that the Reset button restores.
modified_boundaries = {key: boundary.copy() for key, boundary in boundaries.items()}
boundary_history = {key: [boundaries[key].copy()] for key in boundaries.keys()}



# Create B-scan image with boundaries.
bscan_trace = go.Heatmap(
    z=b_scan_data_processed,
    colorscale='cividis',
    showscale=False,
    name='B-scan'
)

def create_boundary_traces(selected_boundary=None):
    """Return one Plotly line trace per entry of ``modified_boundaries``.

    Args:
        selected_boundary: Key of the boundary to show.  When ``None`` every
            trace is visible; otherwise only the trace with that key is.

    Returns:
        List of ``go.Scatter`` traces, in the dictionary's iteration order.
    """
    show_all = selected_boundary is None
    return [
        go.Scatter(
            x=x,
            y=curve,
            mode='lines',
            name=name,
            line={'width': 2},
            visible=show_all or name == selected_boundary,
        )
        for name, curve in modified_boundaries.items()
    ]

# Figure: the B-scan heatmap trace first, then one line trace per boundary.
fig = go.Figure(data=[bscan_trace, *create_boundary_traces()])
fig.update_layout(
    title=f"Interactive B-scan Boundary Editor (B-scan #{edited_b_scan_index + 1})",
    xaxis={
        "title": "X (Pixels)",
        "scaleanchor": "y",
        "range": [0, 500],
        "constrain": "domain",
    },
    yaxis={
        "title": "Depth (Pixels)",
        "autorange": "reversed",
        "range": [0, 500],
        "constrain": "domain",
    },
    width=600,
    height=600,
    # Free-hand open paths are what the "Update Boundary" button consumes.
    dragmode="drawopenpath",
    hovermode='closest',
)

# interactive Dash app.
app = JupyterDash(__name__)
app.layout = html.Div([
    dcc.Graph(id='graph', figure=fig),
    # Action buttons; the ids are the callback inputs.
    html.Div([
        html.Button(label, id=button_id, n_clicks=0)
        for label, button_id in (
            ('Update Boundary', 'update-button'),
            ('Undo', 'undo-button'),
            ('Reset', 'reset-button'),
            ('Print Boundary Points', 'print-button'),
            ('Finish Editing', 'finish-button'),
        )
    ]),
    dcc.Dropdown(
        id='boundary-dropdown',
        options=[{'label': name, 'value': name} for name in boundaries],
        value=None,
        placeholder='Select a boundary to edit',
        clearable=False,
    ),
    # Status / message area written by the callback.
    html.Div(id='output'),
])

def _parse_drawn_path(path):
    """Return the x and y vertex arrays of an SVG path made of M/L commands."""
    vertices = re.findall(r'[ML]\s*([-\d.]+),\s*([-\d.]+)', path)
    xs = np.array([float(px) for px, _ in vertices])
    ys = np.array([float(py) for _, py in vertices])
    return xs, ys


def _set_boundary_trace(figure, boundary_key, y_values):
    """Replace the y data of the trace named ``boundary_key`` in ``figure``."""
    for trace in figure.data:
        if trace.name == boundary_key:
            trace.y = y_values


@app.callback(
    [Output('graph', 'figure'),
     Output('output', 'children')],
    [Input('update-button', 'n_clicks'),
     Input('undo-button', 'n_clicks'),
     Input('reset-button', 'n_clicks'),
     Input('print-button', 'n_clicks'),
     Input('finish-button', 'n_clicks'),
     Input('boundary-dropdown', 'value')],
    [State('graph', 'relayoutData'),
     State('graph', 'figure')],
    prevent_initial_call=True
)
def update_boundary(update_clicks, undo_clicks, reset_clicks, print_clicks, finish_clicks, selected_boundary, relayoutData, current_fig):
    """Handle every editor control and return ``(figure, status message)``.

    The triggering component decides the action:

    * ``boundary-dropdown`` - show only the selected boundary, clear shapes.
    * ``update-button`` - replace the selected boundary between the leftmost
      and rightmost vertex of the last drawn path with the (linearly
      interpolated) path, and record that x-range in ``corrected_x_range``.
    * ``undo-button`` / ``reset-button`` - step back one edit / restore the
      boundary loaded from the file.
    * ``print-button`` - output the boundary's (x, y) points.
    * ``finish-button`` - store the boundary and its key in
      ``final_corrected_boundary`` / ``final_selected_boundary_key``.

    The ``*_clicks`` arguments are only present because the buttons are
    callback inputs; their values are not used.
    """
    global final_corrected_boundary, final_selected_boundary_key, corrected_x_range
    ctx = dash.callback_context
    if not ctx.triggered:
        return current_fig, ""
    triggered = ctx.triggered[0]['prop_id'].split('.')[0]

    if selected_boundary is None:
        return current_fig, "Please select a boundary to edit."

    fig = go.Figure(current_fig)

    if triggered == 'print-button':
        boundary_points = modified_boundaries[selected_boundary]
        coordinates = list(zip(x, boundary_points))
        return fig, f"Boundary {selected_boundary} points:\n{coordinates}"

    elif triggered == 'boundary-dropdown':
        # Update visibility of boundary traces.
        for trace in fig.data:
            if trace.name.startswith('boundary_'):
                trace.visible = (trace.name == selected_boundary)
        fig.update_layout(dragmode="drawopenpath", shapes=[])
        return fig, f"Selected boundary: {selected_boundary}"

    elif triggered == 'update-button':
        shapes = relayoutData.get('shapes') if relayoutData else None
        if not shapes:
            return fig, "No shape drawn"

        # Use the last drawn shape.  A shape without a parsable path (no
        # 'path' entry, or no M/L vertices) is treated like no drawing
        # instead of raising inside the callback.
        x_coords, y_coords = _parse_drawn_path(shapes[-1].get('path', ''))
        if x_coords.size == 0:
            return fig, "No shape drawn"

        # Snap to pixel columns inside the image and order left-to-right.
        x_coords = np.clip(x_coords.astype(int), 0, size_x - 1)
        sorted_indices = np.argsort(x_coords)
        x_coords = x_coords[sorted_indices]
        y_coords = y_coords[sorted_indices]
        full_x = np.arange(x_coords[0], x_coords[-1] + 1)
        full_y = np.interp(full_x, x_coords, y_coords)

        # Edit a copy: the stored array may be shared with `boundaries`
        # (the originals used by Reset), so it must not be mutated in place.
        modified_boundary = modified_boundaries[selected_boundary].copy()
        modified_boundary[full_x] = full_y
        modified_boundaries[selected_boundary] = modified_boundary

        # Record the horizontal range where the correction was applied.
        corrected_x_range = (int(np.min(x_coords)), int(np.max(x_coords)))

        _set_boundary_trace(fig, selected_boundary, modified_boundary)
        fig.update_layout(shapes=[])
        boundary_history[selected_boundary].append(modified_boundary.copy())
        return fig, f"Boundary {selected_boundary} updated. Corrected x-range: {corrected_x_range}"

    elif triggered == 'undo-button':
        history = boundary_history[selected_boundary]
        if len(history) > 1:
            history.pop()
            modified_boundary = history[-1]
            # Store a copy so later edits never touch the history entry.
            modified_boundaries[selected_boundary] = modified_boundary.copy()
            _set_boundary_trace(fig, selected_boundary, modified_boundary)
            return fig, f"Undo last change on boundary {selected_boundary}."
        else:
            return fig, "Nothing to undo."

    elif triggered == 'reset-button':
        modified_boundary = boundaries[selected_boundary].copy()
        modified_boundaries[selected_boundary] = modified_boundary
        boundary_history[selected_boundary] = [modified_boundary.copy()]
        _set_boundary_trace(fig, selected_boundary, modified_boundary)
        return fig, f"Boundary {selected_boundary} reset to original."

    elif triggered == 'finish-button':
        # Save the final corrected boundary and key.
        final_corrected_boundary = modified_boundaries[selected_boundary].copy()
        final_selected_boundary_key = selected_boundary
        return fig, f"Finished editing boundary {selected_boundary}."

    else:
        return fig, ""

# Run the interactive editor.
# NOTE(review): mode='inline' presumably renders the app inside the notebook
# output cell - confirm against the installed jupyter-dash / dash version.
app.run(mode='inline')


## Propagation

In [None]:
# Run after editing
# This cell relies on globals produced by the editor cell; fail early with a
# clear message when they are missing.

# Both values are set by the editor callback when "Finish Editing" is pressed.
if final_corrected_boundary is None or final_selected_boundary_key is None:
    raise ValueError("No corrected boundary found. Finish editing first.")

# The loaded volume is created at the top of the editor cell.
if 'editor_volume' not in globals():
    raise ValueError("Editor volume not found. Did you run the editor?")

# Short alias used by the rest of this cell.
vol = editor_volume

def preprocess_bscan(bscan):
    """Return a contrast-enhanced copy of a raw B-scan.

    Steps: convert to a log scale (20 * log10), min-max normalise, then apply
    adaptive histogram equalisation with ``clip_limit=0.03``.

    Args:
        bscan: 2-D array of raw intensities.

    Returns:
        The equalised image produced by ``exposure.equalize_adapthist``.
    """
    eps = 1e-6  # keeps log10 finite for zero-valued pixels
    log_img = 20 * np.log10(bscan + eps)
    lowest, highest = np.min(log_img), np.max(log_img)
    normalised = (log_img - lowest) / (highest - lowest)
    return exposure.equalize_adapthist(normalised, clip_limit=0.03)

# -- Propagation Functions --
def _nan_median_or_none(values):
    """Return the NaN-ignoring median of ``values``, or None if nothing is valid."""
    values = np.asarray(values)
    if values.size == 0 or np.all(np.isnan(values)):
        return None
    return np.nanmedian(values)


def crop_patch(image, ref_boundary, crop_height=32, crop_width=192):
    """Cut a ``crop_height`` x ``crop_width`` patch that follows a boundary.

    Horizontally the window is centred on the module-level
    ``corrected_x_range`` (the edited columns) when it is set, otherwise on
    the image centre.  Vertically it is centred on the median boundary depth,
    taken from the edited columns when they overlap the window and contain
    valid values, else from the whole window, else (no valid value at all)
    from the image centre.  The window is clamped to lie inside the image.

    Args:
        image: 2-D array (rows = depth, columns = x).
        ref_boundary: 1-D array with one boundary row per image column; NaN
            marks columns without a boundary.
        crop_height: Patch height in pixels.
        crop_width: Patch width in pixels.

    Returns:
        Tuple ``(patch_img, patch_ref_boundary, top, start_w)`` where
        ``patch_ref_boundary`` is the cropped boundary in patch row
        coordinates and ``top`` / ``start_w`` are the patch's row / column
        offset in ``image``.
    """
    H, W = image.shape
    edited_range = corrected_x_range  # module-level; only read here

    if edited_range is not None:
        corrected_center = (edited_range[0] + edited_range[1]) // 2
        start_w = corrected_center - crop_width // 2
    else:
        start_w = (W - crop_width) // 2
    # Keep the window inside the image.  The outer max() also covers images
    # narrower than the window, where W - crop_width would be negative and
    # would otherwise be interpreted as an index from the end.
    start_w = max(0, min(start_w, W - crop_width))
    end_w = start_w + crop_width
    img_crop = image[:, start_w:end_w]
    ref_boundary_crop = ref_boundary[start_w:end_w]

    # Vertical centre: edited columns first, then the whole window.
    center = None
    if edited_range is not None:
        corrected_start = max(start_w, edited_range[0])
        corrected_end = min(end_w, edited_range[1] + 1)
        if corrected_end > corrected_start:
            center = _nan_median_or_none(ref_boundary[corrected_start:corrected_end])
    if center is None:
        center = _nan_median_or_none(ref_boundary_crop)
    if center is None:
        # No valid boundary value anywhere: int(round(nan)) would raise, so
        # fall back to the vertical centre of the image.
        center = H / 2

    top = int(round(center - crop_height / 2))
    top = max(0, min(top, H - crop_height))

    patch_img = img_crop[top:top+crop_height, :]
    patch_ref_boundary = ref_boundary_crop - top
    return patch_img, patch_ref_boundary, top, start_w

def boundary_to_mask(boundary, patch_shape=(32,192)):
    """Rasterise a boundary curve into a binary mask.

    Args:
        boundary: Sequence with one row position per column; NaN entries are
            skipped.  Entries beyond the mask width are ignored.
        patch_shape: ``(rows, columns)`` of the mask.

    Returns:
        float32 array of shape ``patch_shape`` with 1.0 at the (rounded,
        row-clamped) boundary position of each valid column and 0.0 elsewhere.
    """
    n_rows, n_cols = patch_shape[0], patch_shape[1]
    mask = np.zeros(patch_shape, dtype=np.float32)
    bottom_row = n_rows - 1
    for col, depth in enumerate(boundary[:n_cols]):
        if np.isnan(depth):
            continue
        # Round to the nearest row, then clamp into the mask.
        row = min(max(int(round(depth)), 0), bottom_row)
        mask[row, col] = 1.0
    return mask

# -- Start propagation --
def _predict_and_plot_patch(patch_img, patch_ref_boundary, titles):
    """Predict the boundary for one 32x192 patch and plot the three panels.

    Args:
        patch_img: (32, 192) image patch.
        patch_ref_boundary: Reference boundary in patch row coordinates; it is
            rasterised into the second input channel.
        titles: Three titles (input, reference mask, prediction); ``None``
            leaves that panel untitled.

    Returns:
        The boundary head output moved to the CPU as a NumPy array (one value
        per patch column), which the caller treats as patch row coordinates.
    """
    ref_mask = boundary_to_mask(patch_ref_boundary, (32, 192))
    input_tensor = torch.tensor(np.stack([patch_img, ref_mask], axis=0),
                                dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        _, logits_boundary = model(input_tensor)
    pred = logits_boundary.squeeze().cpu().numpy()

    plt.figure(figsize=(15, 5))
    panels = ((patch_img, titles[0]), (ref_mask, titles[1]), (patch_img, titles[2]))
    for panel_idx, (panel_img, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 3, panel_idx)
        plt.imshow(panel_img, cmap='gray')
        if panel_idx == 3:
            # The line is labelled so that plt.legend() has an entry to show
            # (it warns and draws nothing when no artist has a label).
            plt.plot(np.arange(192), pred, 'b-', linewidth=2, label="Predicted boundary")
            plt.legend()
        if panel_title is not None:
            plt.title(panel_title)
        plt.axis('off')
    plt.show()
    return pred


if vol.b_scans.shape[2] <= edited_b_scan_index:
    raise ValueError("Edited index out of bounds.")

# Patch geometry is fixed by the edited slice: its crop gives the column
# offset used for every propagated slice and the row offset of the first one.
img0 = preprocess_bscan(vol.b_scans[:, :, edited_b_scan_index])
gt_boundary0 = final_corrected_boundary.copy()
patch_img0, patch_gt_boundary0, top0, start_w0 = crop_patch(img0, gt_boundary0)
top_list = [top0]
start_w_list = [start_w0]
gt_boundaries_patch = [patch_gt_boundary0]

prev_full_boundary = gt_boundary0.copy()

propagation_start = edited_b_scan_index + 1
num_slices = min(15, vol.b_scans.shape[2] - propagation_start)
print(f"Propagating on {num_slices} slices starting from B-scan {propagation_start + 1}")

pred_boundaries_patch = []
pred_boundaries_orig = []

if num_slices < 1:
    # The edited B-scan is the last one in the volume; indexing
    # propagation_start would raise IndexError, so leave the result lists empty.
    print("No B-scans after the edited one; nothing to propagate.")
else:
    # First propagated slice: the reference is the user-corrected boundary and
    # the patch keeps the edited slice's row offset.
    img1 = preprocess_bscan(vol.b_scans[:, :, propagation_start])
    gt_boundary1 = np.array(vol.b_scan_header[final_selected_boundary_key][propagation_start, :])
    current_top = top0
    current_start_w = start_w0
    patch_img = img1[current_top:current_top+32, current_start_w:current_start_w+192]
    patch_ref_boundary = prev_full_boundary[current_start_w:current_start_w+192] - current_top
    top_list.append(current_top)
    start_w_list.append(current_start_w)

    pred_boundary_patch = _predict_and_plot_patch(
        patch_img,
        patch_ref_boundary,
        ("Slice Propagation: Input Patch", "Reference Boundary Mask", None),
    )

    pred_boundaries_patch.append(pred_boundary_patch)
    # "+ current_top" converts patch rows back to full-image rows.
    pred_boundaries_orig.append(pred_boundary_patch + current_top)
    gt_boundary_patch = gt_boundary1[current_start_w:current_start_w+192] - current_top
    gt_boundaries_patch.append(gt_boundary_patch)
    prev_pred_boundary_orig = pred_boundaries_orig[-1].copy()

# Next slices: each one uses the previous slice's prediction as reference and
# re-centres the patch vertically on that prediction's mean depth.
# (The range is empty when num_slices <= 1.)
for i in range(propagation_start + 1, propagation_start + num_slices):
    img_i = preprocess_bscan(vol.b_scans[:, :, i])
    gt_boundary_i = np.array(vol.b_scan_header[final_selected_boundary_key][i, :])
    H, W = img_i.shape
    current_start_w = start_w0
    mean_pred = np.nanmean(prev_pred_boundary_orig)
    current_top = int(round(mean_pred - 32/2))
    if current_top < 0:
        current_top = 0
    if current_top + 32 > H:
        current_top = H - 32

    patch_img = img_i[current_top:current_top+32, current_start_w:current_start_w+192]
    patch_ref_boundary = prev_pred_boundary_orig - current_top
    top_list.append(current_top)
    start_w_list.append(current_start_w)

    pred_boundary_patch = _predict_and_plot_patch(
        patch_img,
        patch_ref_boundary,
        (f"Slice {i+1}: Input Patch", f"Slice {i+1}: Ref Boundary Mask", f"Slice {i+1}: Prediction"),
    )

    pred_boundary_orig = pred_boundary_patch + current_top
    pred_boundaries_patch.append(pred_boundary_patch)
    pred_boundaries_orig.append(pred_boundary_orig)
    gt_boundary_patch = gt_boundary_i[current_start_w:current_start_w+192] - current_top
    gt_boundaries_patch.append(gt_boundary_patch)
    prev_pred_boundary_orig = pred_boundary_orig.copy()

print("Done cropping & propagating.")


## Overwrite with corrected boundaries

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def join_boundaries(orig_boundary, corrected_boundary, corrected_x_range, blend_width=10):
    """Merge a corrected boundary segment into the original boundary.

    The corrected segment replaces the original over its columns.  On each
    side, ``blend_width`` neighbouring columns are blended linearly between
    the segment's end value (next to the segment) and the original boundary
    (far from the segment) to avoid a step.  A side is only blended when the
    full ``blend_width`` columns fit inside the boundary.

    Args:
        orig_boundary: 1-D NumPy array, one value per image column.  Not modified.
        corrected_boundary: Either a full-length boundary (same length as
            ``orig_boundary``), from which the columns in
            ``corrected_x_range`` are taken, or a shorter patch that is
            inserted starting at ``corrected_x_range[0]``.
        corrected_x_range: ``(x_start, x_end)``, both inclusive.  For a
            shorter patch only ``x_start`` is used.
        blend_width: Number of columns blended on each side.

    Returns:
        New array with the merged boundary.
    """
    new_boundary = orig_boundary.copy()

    if len(corrected_boundary) == len(orig_boundary):
        x_start, x_end = map(int, corrected_x_range)
        corrected_segment = corrected_boundary[x_start:x_end+1]
    else:
        x_start = int(corrected_x_range[0])
        corrected_segment = corrected_boundary
        x_end = x_start + len(corrected_segment) - 1

    if len(corrected_segment) == 0:
        # Nothing to insert (and no end values to blend towards).
        return new_boundary

    new_boundary[x_start:x_end+1] = corrected_segment

    if blend_width > 0:
        orig = np.asarray(orig_boundary)
        # Weight of the corrected value, rising towards the segment:
        # 0, 1/bw, ..., (bw-1)/bw.
        ramp = np.arange(blend_width) / blend_width

        # Blend left edge
        if x_start - blend_width >= 0:
            left = slice(x_start - blend_width, x_start)
            new_boundary[left] = (1 - ramp) * orig[left] + ramp * corrected_segment[0]

        # Blend right edge: mirror image of the left side, i.e. the corrected
        # value dominates next to the segment and fades out with distance.
        # (Applying the weights the other way round leaves a step at x_end+1
        # and another one at the far end of the blend zone.)
        if x_end + blend_width < len(orig_boundary):
            right = slice(x_end + 1, x_end + blend_width + 1)
            fade = ramp[::-1]
            new_boundary[right] = fade * corrected_segment[-1] + (1 - fade) * orig[right]

    return new_boundary

# Merge corrected boundaries into volume
def _plot_boundary_comparison(proc_img, slice_number, overlay_curves, orig_boundary):
    """Show a slice with the given curves (left) next to the original only (right).

    Args:
        proc_img: Pre-processed B-scan shown in both panels.
        slice_number: 1-based slice number used in the titles.
        overlay_curves: Sequence of ``(boundary, format, label)`` drawn in
            order on the left panel.
        orig_boundary: Original boundary drawn on the right panel.
    """
    plt.figure(figsize=(12,6))
    plt.subplot(1,2,1)
    plt.imshow(proc_img, cmap='gray')
    for curve, fmt, label in overlay_curves:
        plt.plot(np.arange(len(curve)), curve, fmt, linewidth=2, label=label)
    plt.title(f"Slice {slice_number} w/ Updated Boundary")
    # NOTE(review): imshow already puts row 0 at the top, so this flips the
    # panel upside down (image and curves together) - confirm this is intended.
    plt.gca().invert_yaxis()
    plt.legend()

    plt.subplot(1,2,2)
    plt.imshow(proc_img, cmap='gray')
    # Labelled so that the legend below has an entry to show.
    plt.plot(np.arange(len(orig_boundary)), orig_boundary, 'g-', linewidth=2, label="Original")
    plt.title(f"Slice {slice_number} Original")
    plt.gca().invert_yaxis()
    plt.legend()
    plt.show()


vol = editor_volume

# One merged boundary per slice: the edited slice first, then the propagated ones.
final_boundaries = []

# Edited slice
orig_boundary_0 = np.array(vol.b_scan_header[final_selected_boundary_key][edited_b_scan_index, :])
final_boundary_0 = join_boundaries(orig_boundary_0, final_corrected_boundary, corrected_x_range)
final_boundaries.append(final_boundary_0)

proc_img0 = preprocess_bscan(vol.b_scans[:, :, edited_b_scan_index])

_plot_boundary_comparison(
    proc_img0,
    edited_b_scan_index + 1,
    [(final_boundary_0, 'r-', "Updated"), (orig_boundary_0, 'g-', "Original")],
    orig_boundary_0,
)

# Propagated slices
propagation_start = edited_b_scan_index + 1
num_prop = len(pred_boundaries_orig)

for j in range(num_prop):
    slice_idx = propagation_start + j
    orig_boundary_i = np.array(vol.b_scan_header[final_selected_boundary_key][slice_idx, :])
    corrected_patch = pred_boundaries_orig[j]
    # Inclusive end column, matching how join_boundaries reads the range.
    patch_corrected_x_range = (start_w0, start_w0 + len(corrected_patch) - 1)

    final_boundary_i = join_boundaries(orig_boundary_i, corrected_patch, patch_corrected_x_range)
    final_boundaries.append(final_boundary_i)

    proc_img = preprocess_bscan(vol.b_scans[:, :, slice_idx])

    _plot_boundary_comparison(
        proc_img,
        slice_idx + 1,
        [(orig_boundary_i, 'g-', "Original"), (final_boundary_i, 'r-', "Updated")],
        orig_boundary_i,
    )

print("Merged + plotted corrected boundaries for all slices.")


# Pipeline on simulated errors

## Editor

In [None]:


# Imports
import os, glob, pickle, re
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import exposure

# Dashboard tools
from jupyter_dash import JupyterDash
import dash
from dash import dcc, html, Output, Input, State
import plotly.graph_objs as go


model.eval()

# Globals for storing editing info; the Dash callback fills these in.
final_corrected_boundary = None
final_selected_boundary_key = None
corrected_x_range = None
edited_b_scan_index = 24  # index of the B-scan to edit

# Load the .vol file
# NOTE(review): a module-level `global` statement has no effect; the
# assignment below already creates the global.
global editor_volume
vol_file_path = "/content/drive/MyDrive/Colab Notebooks/Zip Files/Data 1/ZendTo-3C8kBGYWAMNWKpqn/EYE00270_E_4485_cropped_segmented_corrected.vol" ####### Update File path
editor_volume = OCTVol(vol_file_path)

epsilon = 1e-6  # keeps log10 finite for zero-valued pixels
# NOTE(review): repeats the assignment made a few lines above.
edited_b_scan_index = 24
b_scan_data = editor_volume.b_scans[:, :, edited_b_scan_index]

# Basic processing: log scale -> min-max normalise -> adaptive histogram
# equalisation.
b_scan_data_log = 20 * np.log10(b_scan_data + epsilon)
b_scan_data_log_norm = (b_scan_data_log - np.min(b_scan_data_log)) / (np.max(b_scan_data_log) - np.min(b_scan_data_log))
b_scan_data_processed = exposure.equalize_adapthist(b_scan_data_log_norm, clip_limit=0.03)

# Load segment boundaries
num_boundaries = int(editor_volume.b_scan_header['num_seg'][edited_b_scan_index])
size_x = editor_volume.header['size_x']
size_z = editor_volume.header['size_z']

x = np.arange(size_x)
boundaries = {}
for i_boundary in range(num_boundaries):
    key = f'boundary_{i_boundary + 1}'
    boundary_data = editor_volume.b_scan_header[key][edited_b_scan_index, :]
    # Entries equal to this value (float32 max) are replaced by NaN below,
    # i.e. treated as "no boundary at this column".
    invalid_value = np.float32(3.4028235e+38)
    boundary_data = np.where(boundary_data == invalid_value, np.nan, boundary_data)
    boundaries[key] = boundary_data

# Add fake noise for testing
def simulate_noisy_boundary(boundary, region=(50,200), seg_min=5, seg_max=20, shift_min=-10, shift_max=10):
    """Return a copy of ``boundary`` with piecewise-constant random offsets.

    The columns ``region[0]`` .. ``region[1]-1`` are split, left to right,
    into consecutive segments of random length (``seg_min`` .. ``seg_max``,
    capped by the columns left in the region); every segment is shifted by
    one random integer in ``shift_min`` .. ``shift_max`` (inclusive).
    Splitting stops once fewer than ``seg_min`` columns remain, so a short
    tail of the region can stay untouched.  Random numbers come from NumPy's
    global generator (``np.random``).

    Args:
        boundary: 1-D array; not modified.
        region: ``(start, end)`` columns to perturb, end exclusive.
        seg_min, seg_max: Minimum / maximum segment length.
        shift_min, shift_max: Minimum / maximum offset.

    Returns:
        The perturbed copy.
    """
    noisy = boundary.copy()
    cursor, stop = region
    while cursor < stop and stop - cursor >= seg_min:
        longest = min(seg_max, stop - cursor)
        # randint's upper bound is exclusive, hence the "+ 1".
        length = np.random.randint(seg_min, longest + 1)
        offset = np.random.randint(shift_min, shift_max + 1)
        noisy[cursor:cursor + length] += offset
        cursor += length
    return noisy

# Apply simulated error
# Only 'boundary_1' is perturbed; `boundaries` keeps the values loaded from
# the file.
simulated_boundary = simulate_noisy_boundary(boundaries['boundary_1'],
                                             region=(50,200),
                                             seg_min=5,
                                             seg_max=20,
                                             shift_min=-10,
                                             shift_max=10)
# Copy every array: a plain dict.copy() would share the arrays with
# `boundaries`, so in-place edits made by the editor would also change the
# "original" data that the Reset button restores.
modified_boundaries = {name: curve.copy() for name, curve in boundaries.items()}
modified_boundaries['boundary_1'] = simulated_boundary.copy()

# Boundary history tracking: the first entry is the state Undo can go back
# to, which for 'boundary_1' is the simulated (noisy) version.
boundary_history = {key: [boundaries[key].copy()] for key in boundaries.keys()}
boundary_history['boundary_1'] = [simulated_boundary.copy()]

# Create bscan heatmap
bscan_trace = go.Heatmap(
    z=b_scan_data_processed,
    colorscale='cividis',
    showscale=False,
    name='B-scan'
)

# Overlay the boundaries
def create_boundary_traces(selected_boundary=None):
    """Return one Plotly line trace per entry of ``modified_boundaries``.

    Args:
        selected_boundary: Key of the boundary to show.  When ``None`` every
            trace is visible; otherwise only the trace with that key is.

    Returns:
        List of ``go.Scatter`` traces, in the dictionary's iteration order.
    """
    show_all = selected_boundary is None
    return [
        go.Scatter(
            x=x,
            y=curve,
            mode='lines',
            name=name,
            line={'width': 2},
            visible=show_all or name == selected_boundary,
        )
        for name, curve in modified_boundaries.items()
    ]

# Figure: the B-scan heatmap trace first, then one line trace per boundary.
fig = go.Figure(data=[bscan_trace, *create_boundary_traces()])
fig.update_layout(
    title=f"Interactive B-scan Boundary Editor (B-scan #{edited_b_scan_index + 1})",
    xaxis={
        "title": "X (Pixels)",
        "scaleanchor": "y",
        "range": [0, 500],
        "constrain": "domain",
    },
    yaxis={
        "title": "Depth (Pixels)",
        "autorange": "reversed",
        "range": [0, 500],
        "constrain": "domain",
    },
    width=600,
    height=600,
    # Free-hand open paths are what the "Update Boundary" button consumes.
    dragmode="drawopenpath",
    hovermode='closest',
)

# Dash app layout
app = JupyterDash(__name__)
app.layout = html.Div([
    dcc.Graph(id='graph', figure=fig),
    # Action buttons; the ids are the callback inputs.
    html.Div([
        html.Button(label, id=button_id, n_clicks=0)
        for label, button_id in (
            ('Update Boundary', 'update-button'),
            ('Undo', 'undo-button'),
            ('Reset', 'reset-button'),
            ('Print Boundary Points', 'print-button'),
            ('Finish Editing', 'finish-button'),
        )
    ]),
    dcc.Dropdown(
        id='boundary-dropdown',
        options=[{'label': name, 'value': name} for name in boundaries],
        value=None,
        placeholder='Select a boundary to edit',
        clearable=False,
    ),
    # Status / message area written by the callback.
    html.Div(id='output'),
])

@app.callback(
    [Output('graph', 'figure'),
     Output('output', 'children')],
    [Input('update-button', 'n_clicks'),
     Input('undo-button', 'n_clicks'),
     Input('reset-button', 'n_clicks'),
     Input('print-button', 'n_clicks'),
     Input('finish-button', 'n_clicks'),
     Input('boundary-dropdown', 'value')],
    [State('graph', 'relayoutData'),
     State('graph', 'figure')],
    prevent_initial_call=True
)
def update_boundary(update_clicks, undo_clicks, reset_clicks, print_clicks, finish_clicks, selected_boundary, relayoutData, current_fig):
    """Handle every editor control and return ``(figure, status_message)``.

    The component that fired is read from ``dash.callback_context``; the
    ``*_clicks`` counters themselves are not inspected.

    Actions (all require a boundary to be selected in the dropdown):
      * print-button      - list the (x, y) points of the selected boundary.
      * boundary-dropdown - show only the selected boundary trace and clear
                            any drawn shapes.
      * update-button     - take the most recently drawn path from
                            ``relayoutData['shapes']``, interpolate it onto
                            integer columns and overwrite that span of the
                            selected boundary; records ``corrected_x_range``
                            and pushes a history snapshot.
      * undo-button       - drop the latest history snapshot and restore the
                            previous one.
      * reset-button      - restore the boundary from ``boundaries`` and
                            restart its history.
      * finish-button     - store a copy of the edited boundary and its key in
                            ``final_corrected_boundary`` /
                            ``final_selected_boundary_key``.

    Reads the notebook globals ``x``, ``size_x``, ``boundaries``,
    ``modified_boundaries`` and ``boundary_history``.

    NOTE(review): undo and reset do not touch ``corrected_x_range``, so it can
    describe an edit that is no longer applied - confirm this is intended.
    """
    global final_corrected_boundary, final_selected_boundary_key, corrected_x_range
    ctx = dash.callback_context
    if not ctx.triggered:
        return current_fig, ""
    triggered = ctx.triggered[0]['prop_id'].split('.')[0]

    if selected_boundary is None:
        return current_fig, "Please select a boundary to edit."

    fig = go.Figure(current_fig)

    if triggered == 'print-button':
        boundary_points = modified_boundaries[selected_boundary]
        coordinates = list(zip(x, boundary_points))
        return fig, f"Boundary {selected_boundary} points:\n{coordinates}"

    elif triggered == 'boundary-dropdown':
        # Toggle visibility. A trace without a name has ``name is None``, so
        # guard before calling ``startswith`` on it.
        for trace in fig.data:
            if trace.name and trace.name.startswith('boundary_'):
                trace.visible = (trace.name == selected_boundary)
        fig.update_layout(dragmode="drawopenpath", shapes=[])
        return fig, f"Selected boundary: {selected_boundary}"

    elif triggered == 'update-button':
        shapes = relayoutData.get('shapes') if relayoutData else None
        if not shapes:
            return fig, "No shape drawn"

        # Only the most recently drawn shape is used. Pull the vertices of
        # its SVG path ("M x,y L x,y ...").
        path = shapes[-1].get('path') or ''
        matches = re.findall(r'[ML]\s*([-\d.]+),\s*([-\d.]+)', path)
        if not matches:
            # Without vertices the indexing below would raise IndexError.
            return fig, "No shape drawn"

        x_coords = np.array([float(px) for px, _ in matches]).astype(int)
        y_coords = np.array([float(py) for _, py in matches])
        x_coords = np.clip(x_coords, 0, size_x - 1)

        # np.interp needs ascending sample positions.
        sorted_indices = np.argsort(x_coords)
        x_coords = x_coords[sorted_indices]
        y_coords = y_coords[sorted_indices]

        # Fill every integer column spanned by the stroke.
        full_x = np.arange(x_coords[0], x_coords[-1] + 1)
        full_y = np.interp(full_x, x_coords, y_coords)
        modified_boundary = modified_boundaries[selected_boundary]
        modified_boundary[full_x] = full_y

        corrected_x_range = (int(np.min(x_coords)), int(np.max(x_coords)))

        for trace in fig.data:
            if trace.name == selected_boundary:
                trace.y = modified_boundary
        fig.update_layout(shapes=[])
        boundary_history[selected_boundary].append(modified_boundary.copy())
        return fig, f"Boundary {selected_boundary} updated. Corrected x-range: {corrected_x_range}"

    elif triggered == 'undo-button':
        history = boundary_history[selected_boundary]
        if len(history) > 1:
            history.pop()
            modified_boundary = history[-1]
            # Store a copy so later in-place edits cannot alter the snapshot.
            modified_boundaries[selected_boundary] = modified_boundary.copy()
            for trace in fig.data:
                if trace.name == selected_boundary:
                    trace.y = modified_boundary
            return fig, f"Undo last change on boundary {selected_boundary}."
        else:
            return fig, "Nothing to undo."

    elif triggered == 'reset-button':
        modified_boundary = boundaries[selected_boundary].copy()
        modified_boundaries[selected_boundary] = modified_boundary
        boundary_history[selected_boundary] = [modified_boundary.copy()]
        for trace in fig.data:
            if trace.name == selected_boundary:
                trace.y = modified_boundary
        return fig, f"Boundary {selected_boundary} reset to original."

    elif triggered == 'finish-button':
        final_corrected_boundary = modified_boundaries[selected_boundary].copy()
        final_selected_boundary_key = selected_boundary
        return fig, f"Finished editing boundary {selected_boundary}."

    else:
        return fig, ""

# Launch the editor
# NOTE(review): `mode='inline'` is presumably the JupyterDash option that
# renders the app inside the notebook output cell - confirm that the installed
# jupyter-dash / dash version accepts `mode` on `run` (older releases expose it
# on `run_server`).
app.run(mode='inline')


In [None]:
# Simulate & overwrite boundary_1 for B-scan indices 24 to 39 (0-based; printed as B-scans 25 to 40)

def simulate_noisy_boundary(boundary, region=(50,200), seg_min=5, seg_max=20, shift_min=-10, shift_max=10):
    """
    Return a copy of ``boundary`` with piecewise-constant random offsets.

    Walks ``region`` (start inclusive, end exclusive) left to right. Each step
    draws a segment length in ``[seg_min, min(seg_max, columns left)]`` and an
    integer offset in ``[shift_min, shift_max]`` from NumPy's global RNG, then
    adds the offset to that segment. A tail shorter than ``seg_min`` is left
    untouched. The input array is not modified.
    """
    noisy = boundary.copy()
    cursor, stop = region
    # Continue only while a full minimum-length segment still fits.
    while cursor < stop and stop - cursor >= seg_min:
        longest = min(seg_max, stop - cursor)
        # Draw order (length, then offset) is kept so seeded runs reproduce.
        step = np.random.randint(seg_min, longest + 1)
        offset = np.random.randint(shift_min, shift_max + 1)
        noisy[cursor:cursor + step] += offset
        cursor += step
    return noisy

# Apply noisy boundaries to selected b-scans.
# Each selected row of 'boundary_1' is read, perturbed, and written back in
# place, so the volume header is modified by this cell.
noise_settings = {
    'region': (50, 200),
    'seg_min': 5,
    'seg_max': 20,
    'shift_min': -10,
    'shift_max': 10,
}
for idx in range(24, 40):
    orig_boundary = np.array(editor_volume.b_scan_header['boundary_1'][idx, :])
    simulated_boundary = simulate_noisy_boundary(orig_boundary, **noise_settings)
    editor_volume.b_scan_header['boundary_1'][idx, :] = simulated_boundary
    print(f"B-scan {idx+1}: boundary_1 updated with noise")

"""
#plot a few results to verify
import matplotlib.pyplot as plt
for idx in [24, 29, 34, 39]:
    orig = np.array(editor_volume.b_scan_header['boundary_1'][idx, :])
    proc_img = preprocess_bscan(editor_volume.b_scans[:, :, idx])
    plt.figure(figsize=(8,6))
    plt.imshow(proc_img, cmap='gray')
    plt.plot(np.arange(len(orig)), orig, 'r-', linewidth=2, label="Simulated 'boundary_1'")
    plt.title(f"B-scan {idx+1} with Simulated boundary_1")
    plt.xlabel("X (pixels)")
    plt.ylabel("Depth (pixels)")
    plt.gca().invert_yaxis()
    plt.legend()
    plt.show()
"""




## Propagation

In [None]:
# Run after editing: the editor callback must have stored a finished boundary
# and the editor volume must exist before propagation can start.

if not (final_corrected_boundary is not None and final_selected_boundary_key is not None):
    raise ValueError("No corrected boundary found. Finish editing first.")

if globals().get('editor_volume', None) is None and 'editor_volume' not in globals():
    raise ValueError("Editor volume not found. Did you run the editor?")

vol = editor_volume

def preprocess_bscan(bscan):
    """Log-compress, min-max normalise and contrast-equalise one B-scan.

    Steps: ``20 * log10(bscan + 1e-6)``, rescale to [0, 1] by the image's own
    min/max, then ``exposure.equalize_adapthist`` with ``clip_limit=0.03``.

    A constant image has zero dynamic range; dividing by it would fill the
    result with NaN, so an all-zero image of the same shape is returned
    instead (there is no contrast to equalise).
    """
    epsilon = 1e-6  # keeps log10 finite for zero-valued pixels
    bscan_log = 20 * np.log10(bscan + epsilon)
    lowest = np.min(bscan_log)
    span = np.max(bscan_log) - lowest
    if span == 0:
        return np.zeros_like(bscan_log)
    bscan_norm = (bscan_log - lowest) / span
    bscan_eq = exposure.equalize_adapthist(bscan_norm, clip_limit=0.03)
    return bscan_eq

# -- Propagation Functions --
def crop_patch(image, ref_boundary, crop_height=32, crop_width=192):
    """Cut a ``crop_height`` x ``crop_width`` window around the boundary.

    Horizontal placement: centred on the global ``corrected_x_range`` when it
    is set (shifted left if it would run past the right edge), otherwise
    centred in the image.
    Vertical placement: centred on the NaN-ignoring median of the boundary,
    taken over the corrected columns that fall inside the window when there
    are any, otherwise over the whole window; then clamped to the image.

    Returns ``(patch_img, patch_ref_boundary, top, start_w)`` where
    ``patch_ref_boundary`` is the cropped boundary expressed relative to the
    patch's top row.
    """
    n_rows, n_cols = image.shape
    edited = corrected_x_range

    if edited is None:
        start_w = (n_cols - crop_width) // 2
    else:
        midpoint = (edited[0] + edited[1]) // 2
        start_w = min(max(0, midpoint - crop_width // 2), n_cols - crop_width)
    end_w = start_w + crop_width

    columns = slice(start_w, end_w)
    img_crop = image[:, columns]
    ref_boundary_crop = ref_boundary[columns]

    # Prefer the corrected columns for the depth estimate; fall back to the
    # full window when the overlap is empty or no correction range exists.
    depth_samples = ref_boundary_crop
    if edited is not None:
        overlap_lo = max(start_w, edited[0])
        overlap_hi = min(end_w, edited[1] + 1)
        if overlap_hi > overlap_lo:
            depth_samples = ref_boundary[overlap_lo:overlap_hi]
    center = np.nanmedian(depth_samples)

    top = int(round(center - crop_height / 2))
    top = min(max(top, 0), n_rows - crop_height)

    return img_crop[top:top + crop_height, :], ref_boundary_crop - top, top, start_w

def boundary_to_mask(boundary, patch_shape=(32,192)):
    """Rasterise a per-column boundary into a one-pixel-per-column mask.

    For each of the first ``patch_shape[1]`` columns, the boundary depth is
    rounded to the nearest row (ties to even), clamped to the patch height,
    and that pixel is set to 1.0. Columns whose depth is NaN stay empty.
    Infinite depths are clamped to the first/last row like any other
    out-of-range value.

    Returns a ``float32`` array of shape ``patch_shape``.
    """
    mask = np.zeros(patch_shape, dtype=np.float32)
    n_rows, n_cols = patch_shape[0], patch_shape[1]

    depths = np.asarray(boundary[:n_cols], dtype=float)
    cols = np.flatnonzero(~np.isnan(depths))
    # Clamp before the integer cast so +/-inf cannot overflow.
    rows = np.clip(np.rint(depths[cols]), 0, n_rows - 1).astype(int)
    mask[rows, cols] = 1.0
    return mask

# -- Start propagation --
# The edited slice seeds the propagation: its corrected boundary fixes the
# crop window (top0, start_w0) and serves as the first reference boundary.
if vol.b_scans.shape[2] <= edited_b_scan_index:
    raise ValueError("Edited index out of bounds.")

img0 = preprocess_bscan(vol.b_scans[:, :, edited_b_scan_index])
gt_boundary0 = final_corrected_boundary.copy()
patch_img0, patch_gt_boundary0, top0, start_w0 = crop_patch(img0, gt_boundary0)
top_list = [top0]
start_w_list = [start_w0]
gt_boundaries_patch = [patch_gt_boundary0]

prev_full_boundary = gt_boundary0.copy()

# Propagate to at most 15 slices after the edited one.
propagation_start = edited_b_scan_index + 1
num_slices = min(15, vol.b_scans.shape[2] - propagation_start)
if num_slices <= 0:
    # The edited slice is the last one; indexing `propagation_start` below
    # would otherwise fail with an opaque IndexError.
    raise ValueError("No B-scans after the edited slice to propagate to.")
print(f"Propagating on {num_slices} slices starting from B-scan {propagation_start + 1}")

# First propagated slice: reuse the crop window of the edited slice and feed
# the corrected boundary of the edited slice as the reference channel.
img1 = preprocess_bscan(vol.b_scans[:, :, propagation_start])
gt_boundary1 = np.array(vol.b_scan_header[final_selected_boundary_key][propagation_start, :])
current_top = top0
current_start_w = start_w0
# 32 x 192 matches the default crop size of crop_patch.
patch_img = img1[current_top:current_top+32, current_start_w:current_start_w+192]
# Reference boundary expressed relative to the patch's top row.
patch_ref_boundary = prev_full_boundary[current_start_w:current_start_w+192] - current_top
top_list.append(current_top)
start_w_list.append(current_start_w)

plt.figure(figsize=(15,5))
plt.subplot(1,3,1)
plt.imshow(patch_img, cmap='gray')
plt.title("Slice Propagation: Input Patch")
plt.axis('off')

plt.subplot(1,3,2)
ref_mask = boundary_to_mask(patch_ref_boundary, (32,192))
plt.imshow(ref_mask, cmap='gray')
plt.title("Reference Boundary Mask")
plt.axis('off')

# Two-channel input: image patch + reference mask, with a batch axis added.
input_tensor = torch.tensor(np.stack([patch_img, ref_mask], axis=0),
                            dtype=torch.float32).unsqueeze(0).to(device)
with torch.no_grad():
    # Only the second model output is used.
    # NOTE(review): it is plotted against np.arange(192) below, so it is
    # presumably one depth value per patch column - confirm against the model.
    _, logits_boundary = model(input_tensor)
    pred_boundary_patch = logits_boundary.squeeze().cpu().numpy()

plt.subplot(1,3,3)
plt.imshow(patch_img, cmap='gray')
plt.plot(np.arange(192), pred_boundary_patch, 'b-', linewidth=2)
plt.axis('off')
# NOTE(review): no artist in this subplot has a label, so this legend is empty.
plt.legend()
plt.show()

pred_boundaries_patch = [pred_boundary_patch]
# Shift the prediction back to full-image row coordinates.
pred_boundaries_orig = [pred_boundary_patch + current_top]
gt_boundary_patch = gt_boundary1[current_start_w:current_start_w+192] - current_top
gt_boundaries_patch.append(gt_boundary_patch)
prev_pred_boundary_orig = pred_boundaries_orig[-1].copy()

# Next slices: each slice uses the previous slice's prediction as its
# reference boundary. The horizontal window stays fixed at start_w0; the
# vertical window is re-centred on the mean of the previous prediction.
for i in range(propagation_start + 1, propagation_start + num_slices):
    img_i = preprocess_bscan(vol.b_scans[:, :, i])
    gt_boundary_i = np.array(vol.b_scan_header[final_selected_boundary_key][i, :])
    H, W = img_i.shape
    current_start_w = start_w0
    mean_pred = np.nanmean(prev_pred_boundary_orig)
    current_top = int(round(mean_pred - 32/2))
    # Keep the 32-row window inside the image.
    if current_top < 0:
        current_top = 0
    if current_top + 32 > H:
        current_top = H - 32

    patch_img = img_i[current_top:current_top+32, current_start_w:current_start_w+192]
    # prev_pred_boundary_orig is in full-image rows and already patch-width
    # long, so only the vertical offset is removed here.
    patch_ref_boundary = prev_pred_boundary_orig - current_top
    top_list.append(current_top)
    start_w_list.append(current_start_w)

    plt.figure(figsize=(15,5))
    plt.subplot(1,3,1)
    plt.imshow(patch_img, cmap='gray')
    plt.title(f"Slice {i+1}: Input Patch")
    plt.axis('off')

    plt.subplot(1,3,2)
    ref_mask = boundary_to_mask(patch_ref_boundary, (32,192))
    plt.imshow(ref_mask, cmap='gray')
    plt.title(f"Slice {i+1}: Ref Boundary Mask")
    plt.axis('off')

    # Two-channel input: image patch + reference mask, with a batch axis added.
    input_tensor = torch.tensor(np.stack([patch_img, ref_mask], axis=0),
                                dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        # Only the second model output is used.
        # NOTE(review): presumably one depth value per patch column - confirm.
        _, logits_boundary = model(input_tensor)
        pred_boundary_patch = logits_boundary.squeeze().cpu().numpy()

    plt.subplot(1,3,3)
    plt.imshow(patch_img, cmap='gray')
    plt.plot(np.arange(192), pred_boundary_patch, 'b-', linewidth=2)
    plt.title(f"Slice {i+1}: Prediction")
    plt.axis('off')
    # NOTE(review): no artist in this subplot has a label, so this legend is empty.
    plt.legend()
    plt.show()

    # Shift the prediction back to full-image rows and carry it forward as
    # the reference for the next slice.
    pred_boundary_orig = pred_boundary_patch + current_top
    pred_boundaries_patch.append(pred_boundary_patch)
    pred_boundaries_orig.append(pred_boundary_orig)
    gt_boundary_patch = gt_boundary_i[current_start_w:current_start_w+192] - current_top
    gt_boundaries_patch.append(gt_boundary_patch)
    prev_pred_boundary_orig = pred_boundary_orig.copy()

print("Done cropping & propagating.")


## Overwriting

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def join_boundaries(orig_boundary, corrected_boundary, corrected_x_range, blend_width=10):
    """
    Merge a corrected boundary segment into the original boundary.

    Parameters
    ----------
    orig_boundary : array
        Full-width boundary to be patched; it is copied, not modified.
    corrected_boundary : array
        Either a full-width boundary (same length as ``orig_boundary``), from
        which columns ``corrected_x_range[0]..corrected_x_range[1]`` inclusive
        are taken, or a shorter patch that is placed starting at
        ``corrected_x_range[0]`` (the range's second value is then ignored).
    corrected_x_range : (int, int)
        Column range of the correction, as described above.
    blend_width : int
        Number of columns on each side over which the result ramps linearly
        from the original boundary to the value at the edge of the corrected
        segment. A side is left unblended when the ramp would not fit inside
        the boundary.

    Returns
    -------
    array
        Copy of ``orig_boundary`` with the corrected segment inserted and
        both edges blended.
    """
    new_boundary = orig_boundary.copy()

    if len(corrected_boundary) == len(orig_boundary):
        x_start, x_end = map(int, corrected_x_range)
        corrected_segment = corrected_boundary[x_start:x_end+1]
    else:
        x_start = int(corrected_x_range[0])
        corrected_segment = corrected_boundary
        x_end = x_start + len(corrected_segment) - 1

    new_boundary[x_start:x_end+1] = corrected_segment

    # Blend left edge: weight of the corrected value grows from 0 at the far
    # end of the ramp towards 1 next to the corrected segment.
    if x_start - blend_width >= 0:
        for x in range(x_start - blend_width, x_start):
            w = (x - (x_start - blend_width)) / blend_width
            new_boundary[x] = (1 - w) * orig_boundary[x] + w * corrected_segment[0]

    # Blend right edge: mirror image of the left ramp, so the corrected value
    # dominates next to the segment and fades to the original further away.
    if x_end + blend_width < len(orig_boundary):
        for x in range(x_end + 1, x_end + blend_width + 1):
            w = (x_end + blend_width - x) / blend_width
            new_boundary[x] = (1 - w) * orig_boundary[x] + w * corrected_segment[-1]

    return new_boundary

# Merge corrected boundaries.
# The merged boundaries are collected in `final_boundaries`; nothing in this
# block writes them back into the volume header.
vol = editor_volume

final_boundaries = []

# Edited slice first
orig_boundary_0 = np.array(vol.b_scan_header[final_selected_boundary_key][edited_b_scan_index, :])
# NOTE(review): join_boundaries unpacks corrected_x_range, so this presumably
# requires at least one "Update Boundary" to have been applied in the editor
# before "Finish Editing" - verify how corrected_x_range is initialised.
final_boundary_0 = join_boundaries(orig_boundary_0, final_corrected_boundary, corrected_x_range)
final_boundaries.append(final_boundary_0)

proc_img0 = preprocess_bscan(vol.b_scans[:, :, edited_b_scan_index])

# Left: updated (red) over original (green). Right: original only.
plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
plt.imshow(proc_img0, cmap='gray')
plt.plot(np.arange(len(final_boundary_0)), final_boundary_0, 'r-', linewidth=2, label="Updated")
plt.plot(np.arange(len(orig_boundary_0)), orig_boundary_0, 'g-', linewidth=2, label="Original")
plt.title(f"Slice {edited_b_scan_index+1} w/ Updated Boundary")
#plt.gca().invert_yaxis()
plt.legend()

plt.subplot(1,2,2)
plt.imshow(proc_img0, cmap='gray')
plt.plot(np.arange(len(orig_boundary_0)), orig_boundary_0, 'g-', linewidth=2)
plt.title(f"Slice {edited_b_scan_index+1} Original")
#plt.gca().invert_yaxis()
# NOTE(review): the line above has no label, so this legend is empty.
plt.legend()
plt.show()

# Propagated slices: merge each predicted patch boundary back into the
# full-width boundary of its slice and plot it against the original.
propagation_start = edited_b_scan_index + 1
num_prop = len(pred_boundaries_orig)

for j in range(num_prop):
    slice_idx = propagation_start + j
    orig_boundary_i = np.array(vol.b_scan_header[final_selected_boundary_key][slice_idx, :])
    # Predictions are already in full-image row coordinates but only
    # patch-width long; they start at column start_w0.
    corrected_patch = pred_boundaries_orig[j]
    # NOTE(review): the second value here is an exclusive end, whereas
    # join_boundaries treats the range as inclusive. join_boundaries ignores it
    # when the patch is shorter than the full boundary - confirm that is
    # always the case.
    patch_corrected_x_range = (start_w0, start_w0 + len(corrected_patch))

    final_boundary_i = join_boundaries(orig_boundary_i, corrected_patch, patch_corrected_x_range)
    final_boundaries.append(final_boundary_i)

    proc_img = preprocess_bscan(vol.b_scans[:, :, slice_idx])

    # Left: original (green) with updated (red) on top. Right: original only.
    plt.figure(figsize=(12,6))
    plt.subplot(1,2,1)
    plt.imshow(proc_img, cmap='gray')
    plt.plot(np.arange(len(orig_boundary_i)), orig_boundary_i, 'g-', linewidth=2, label="Original")
    plt.plot(np.arange(len(final_boundary_i)), final_boundary_i, 'r-', linewidth=2, label="Updated")
    plt.title(f"Slice {slice_idx+1} w/ Updated Boundary")
    #plt.gca().invert_yaxis()
    plt.legend()

    plt.subplot(1,2,2)
    plt.imshow(proc_img, cmap='gray')
    plt.plot(np.arange(len(orig_boundary_i)), orig_boundary_i, 'g-', linewidth=2)
    plt.title(f"Slice {slice_idx+1} Original")
    #plt.gca().invert_yaxis()
    # NOTE(review): the line above has no label, so this legend is empty.
    plt.legend()
    plt.show()

print("Merged + plotted corrected boundaries for all slices.")
