In [None]:
from losses import *
import layer_util
import sys
sys.path.append('..')
from multiflowseg_utils import *

In [None]:
# Run identifier: names the checkpoint (models/<name>.h5) and everything written under results/.
model_name: str = 'FLOW-3'

In [None]:
# Volume geometry (H = W = image_size, T = frames); also selects the preprocessed data folder.
image_size = 128
frames = 32
data_path = f'../data/clean_{image_size}_{frames}'

# Fixed train / val / test split of patient IDs.
with open('patients.json', 'r') as split_file:
    patient_split = json.load(split_file)

train_patients = patient_split['train']
val_patients = patient_split['val']
test_patients = patient_split['test']

# All patients, in split order.
patients = [*train_patients, *val_patients, *test_patients]
len(train_patients), len(val_patients), len(test_patients)

In [None]:
def get_volumentation(image_size, frames, vessel):
    """Build a freshly randomised volumentations pipeline for one training sample.

    The pipeline applies brightness/contrast, a flip along axis 1 and an in-plane
    rotation. It then applies either a pad (to image_size + 0..20) or a centre crop
    (to image_size - 25..50), chosen with equal probability. It finally resizes back
    to (image_size, image_size, frames).

    Args:
        image_size: target in-plane size.
        frames: number of time frames (the third axis of every shape).
        vessel: accepted for call-site compatibility; not used.

    Returns:
        A ``V.Compose`` pipeline.
    """
    base_transforms = [
        V.RandomBrightnessContrast(p=0.5),
        V.Flip(axis=1, p=0.4),
        V.Rotate((0, 0), (0, 0), (-45, 45), p=0.5),
    ]

    # Both amounts are drawn up front; only one of them is used below.
    crop_amount = random.randint(25, 50)
    pad_amount = random.randint(0, 20)

    if random.random() >= 0.5:
        cropped_side = image_size - crop_amount
        size_change = V.CenterCrop(shape=(cropped_side, cropped_side, frames), p=0.3)
    else:
        padded_side = image_size + pad_amount
        size_change = V.PadIfNeeded(shape=(padded_side, padded_side, frames), p=0.3)

    final_resize = V.Resize(shape=(image_size, image_size, frames), p=1.0)
    return V.Compose(base_transforms + [size_change, final_resize], p=1.0)


In [None]:
# Per-case acquisition metadata read by CustomDataGen:
#   venc.csv              -> columns patient, vessel, venc
#   seriesdescription.csv -> series description text, indexed by (patient, vessel)
venc_df = pd.read_csv('venc.csv')
series_description_df = (
    pd.read_csv('seriesdescription.csv')
    .set_index(['patient', 'vessel'])
)

class CustomDataGen:
    """Generate preprocessed samples, one per (patient, vessel) case.

    Each case is read from ``{data_path}/{patient}_{vessel}.npy``. It is yielded as
    ``(inputs, y)``, where ``inputs`` is the dict of model inputs (``image_input``,
    ``cgm_input``, ``one_hot_input``, ``mask_input``) and ``y`` is the one-hot
    segmentation target. In ``y``, channel 0 is background and channel
    ``vessels_dict[vessel]`` is the vessel.

    Args:
        patients: patient IDs to iterate over.
        cohort: ``'train'``, ``'val'`` or ``'test'``.
            - ``'train'`` enables random augmentation and shuffles the case order
              each time the generator is (re)started.
            - ``'test'`` yields only ``vessel`` and derives ``one_hot_input`` from
              the series description.
        vessel: vessel key to load when ``cohort == 'test'``.
    """

    def __init__(self, patients, cohort, vessel=''):
        self.patients = patients
        self.cohort = cohort
        self.vessel = vessel

    def _cases(self):
        """Return the (patient, vessel) pairs visited in one pass over the cohort."""
        if self.cohort == 'test':
            vessel_keys = [self.vessel]
        else:
            # The first vessels_dict key is skipped (presumably the background entry;
            # channel 0 of the target is filled with background below).
            vessel_keys = list(vessels_dict.keys())[1:]
        cases = [(patient, vessel) for patient in self.patients for vessel in vessel_keys]
        if self.cohort == 'train':
            # The tf.data shuffle buffer used for training holds only
            # len(train_patients) // BATCH_SIZE elements, far fewer than the number of
            # cases. Shuffle the source order so batches are not built from
            # consecutive cases of the same patient.
            random.shuffle(cases)
        return cases

    @staticmethod
    def _load_case(patient, vessel):
        """Load and normalise one case.

        Returns:
            (mag_image, phase_image, mask):
                - mag_image: magnitude min-max scaled to [0, 1].
                - phase_image: float32 phase divided by its maximum (left unscaled if
                  the maximum is 0).
                - mask: the mask binarised to uint8 {0, 1}.
        """
        # NOTE(review): allow_pickle=True can execute code embedded in the file; only
        # load trusted, locally produced data.
        mag_image, phase_image, mask = np.load(
            f'{data_path}/{patient}_{vessel}.npy', allow_pickle=True
        )

        # Magnitude: zero out near-zero noise, then scale by the value range. Dividing by
        # (max - min) rather than max makes this a true min-max scaling, and the guard
        # avoids NaNs on a constant / all-zero volume.
        mag_image[mag_image < 1e-10] = 0
        mag_min, mag_max = np.min(mag_image), np.max(mag_image)
        mag_range = mag_max - mag_min
        if mag_range > 0:
            mag_image = (mag_image - mag_min) / mag_range
        else:
            mag_image = np.zeros_like(mag_image, dtype=float)
        mag_image = np.clip(mag_image, 0, 1)

        # Phase: scale by its maximum, skipping the division when it would be 0/0.
        phase_image = phase_image.astype('float32')
        phase_max = np.max(phase_image)
        if phase_max != 0:
            phase_image = phase_image / phase_max

        mask = (mask > 0.5).astype('uint8')
        return mag_image, phase_image, mask

    @staticmethod
    def _lookup_venc(patient, vessel):
        """Return the VENC recorded in venc_df for (patient, vessel)."""
        venc_rows = venc_df.loc[
            (venc_df['patient'] == patient) & (venc_df['vessel'] == vessel), 'venc'
        ]
        if venc_rows.empty:
            raise KeyError(f'No VENC entry for patient={patient!r}, vessel={vessel!r}')
        return venc_rows.values[0]

    @staticmethod
    def _augment(mag_image, phase_image, mask, vessel):
        """Apply one random volumentations pipeline to magnitude, mask and phase together."""
        # Mask and phase travel together in the `mask` slot so they are augmented
        # alongside the magnitude image.
        mask_phase = np.stack([mask, phase_image], -1)
        aug = get_volumentation(image_size, frames, vessel)
        aug_data = aug(image=mag_image, mask=mask_phase)
        mag_image, mask_phase = aug_data['image'], aug_data['mask']
        # The mask channel is float after stacking with the phase; re-binarise so the
        # one-hot target stays exactly 0/1 even if the transform interpolated it.
        mask = (mask_phase[..., 0] > 0.5).astype('uint8')
        phase_image = mask_phase[..., 1]
        return mag_image, phase_image, mask

    @staticmethod
    def _description_vessel_index(patient, vessel):
        """Map the series description of (patient, vessel) to a vessels_dict index (0 if none)."""
        description = series_description_df.loc[patient, vessel].seriesdescription
        # NOTE(review): replace('x', '') strips every lowercase 'x', including inside
        # words; confirm this is intended.
        tokens = (
            description.replace('_', ' ').replace('.', ' ').replace('x', '')
            .replace('  ', ' ').split(' ')
        )
        labels = is_token_a_substring_in_dictionary(data_dictionary, tokens)
        if len(labels) == 0:
            label = 0
        else:
            labels = pd.Series(labels)
            # 'other' wins over any vessel match; otherwise take the most frequent match.
            label = 'other' if (labels == 'other').any() else labels.value_counts().index[0]
        return vessels_dict[label] if label in vessels_dict else 0

    def data_generator(self):
        """Yield ``(inputs, y)`` for every case of the cohort (see the class docstring)."""
        num_vessels = len(vessels_dict)
        is_train = self.cohort == 'train'

        for patient, vessel in self._cases():
            vessel_index = vessels_dict[vessel]

            mag_image, phase_image, mask = self._load_case(patient, vessel)
            venc = self._lookup_venc(patient, vessel)

            if is_train:
                mag_image, phase_image, mask = self._augment(mag_image, phase_image, mask, vessel)
            angles = phase2angle(phase_image, venc)

            # Equalise the magnitude and build the complex image. Only its imaginary part
            # is fed to the network.
            mag_image = skimage.exposure.equalize_adapthist(mag_image)
            complex_image = create_complex_image(mag_image, angles)
            imaginary_image = complex_image[..., 1]

            # Random sign flip of the imaginary part (training only).
            if is_train and random.random() < 0.5:
                imaginary_image = -imaginary_image

            X = np.stack([normalise(mag_image), normalise(imaginary_image)], -1)

            # One-hot target: channel 0 = background, channel `vessel_index` = this vessel.
            y = np.zeros((image_size, image_size, frames, num_vessels), dtype='uint8')
            y[..., 0] = (mask == 0)
            y[..., vessel_index] = mask

            # True vessel class, used as the CGM classification target.
            cgm_input = tf.one_hot(vessel_index, num_vessels)

            if self.cohort == 'test':
                # Shape (num_vessels,) to match output_signature; no extra leading axis.
                one_hot_input = tf.one_hot(
                    self._description_vessel_index(patient, vessel), num_vessels
                )
            elif is_train and random.random() < 0.05:
                # Occasionally feed a random class to make the tunable description input robust.
                one_hot_input = tf.one_hot(random.randint(0, num_vessels - 1), num_vessels)
            else:
                one_hot_input = cgm_input

            yield {
                'image_input': X.astype('float32'),
                'cgm_input': cgm_input,
                'one_hot_input': one_hot_input,
                'mask_input': y.astype('uint8'),
            }, y

    def get_gen(self):
        """Return a fresh generator over this cohort's cases."""
        return self.data_generator()


In [None]:
# Input and output shapes
input_channel = 2  # normalised magnitude + imaginary part, stacked by CustomDataGen
out_channels = len(vessels_dict)

input_shape = [image_size, image_size, frames, input_channel]
output_shape = [image_size, image_size, frames, out_channels]

# Data generators. CustomDataGen.__init__ names its second parameter `cohort`;
# passing `mode=` raises TypeError before any data is read.
train_gen = CustomDataGen(train_patients, cohort='train').get_gen
val_gen = CustomDataGen(val_patients, cohort='val').get_gen

# Dataset output structure; must match what CustomDataGen.data_generator yields.
output_signature = (
    {
        'image_input': tf.TensorSpec(shape=input_shape, dtype=tf.float32),
        'cgm_input': tf.TensorSpec(shape=[out_channels], dtype=tf.uint8),
        'one_hot_input': tf.TensorSpec(shape=[out_channels], dtype=tf.uint8),
        'mask_input': tf.TensorSpec(shape=output_shape, dtype=tf.uint8),
    },
    tf.TensorSpec(shape=output_shape, dtype=tf.uint8),
)

# Create tf.data.Dataset objects
train_ds = tf.data.Dataset.from_generator(train_gen, output_signature=output_signature)
val_ds = tf.data.Dataset.from_generator(val_gen, output_signature=output_signature)

# Shuffle, batch, and prefetch
BATCH_SIZE = 8
# NOTE(review): the buffer holds len(train_patients) // BATCH_SIZE elements, while an
# epoch yields one element per (patient, vessel). This shuffle therefore only mixes
# nearby samples.
train_ds = (
    train_ds
    .shuffle(buffer_size=max(1, len(train_patients) // BATCH_SIZE), seed=42, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


In [None]:
# ==== Hyperparameters ====
rank = 3               # dimensionality of Conv/MaxPool/UpSampling layers (H, W, T)
n_outputs = 6          # output / class channels (channel 0 is background)
add_dropout = True     # not referenced by the layer builders below
dropout_rate = 0.3     # not referenced by the layer builders below
base_filters = 16      # filters at encoder level 1; doubles per level
kernel_size = 3
stack_num_down = 3     # conv stacks per encoder level
stack_num_up = 1       # conv stacks per decoder / full-scale branch
batch_norm = 1         # truthy -> BatchNormalization after every conv
CGM = True             # classification-guided module (gates outputs, adds cgm loss)
supervision = True     # deep supervision: one softmax output per decoder level
upsamp_type = 'UpSampling'

# Shared Conv kwargs. kernel_size comes from the hyperparameter above instead of being
# hard-coded, so changing `kernel_size` actually takes effect.
conv_config = dict(kernel_size=kernel_size, padding='same', kernel_initializer='he_normal')

# ==== Utility Blocks ====
def conv_block(inputs, filters, num_stacks):
    """Apply ``num_stacks`` rounds of Conv -> (BatchNorm) -> LeakyReLU.

    Uses the module-level ``rank`` (layer dimensionality), ``conv_config`` (Conv
    kwargs) and ``batch_norm`` (whether to insert BatchNormalization).
    """
    conv_layer = layer_util.get_nd_layer('Conv', rank)
    out = inputs
    for _ in range(num_stacks):
        out = conv_layer(filters, **conv_config)(out)
        out = BatchNormalization(axis=-1)(out) if batch_norm else out
        out = LeakyReLU()(out)
    return out

def encode(inputs, scale, num_stacks):
    """Encoder level ``scale`` (1-based): 2x in-plane max-pool (skipped at level 1) + conv_block.

    Level L uses ``base_filters * 2**(L-1)`` filters. The level whose 0-based index is
    4 (scale 5) uses one filter fewer, so that concatenating the 1-channel description
    embedding afterwards brings it back to ``base_filters * 16``.
    """
    level = scale - 1
    n_filters = base_filters * 2 ** level
    if level == 4:
        n_filters -= 1

    out = inputs
    if level:
        pooling = layer_util.get_nd_layer('MaxPool', rank)
        window = (2, 2) if rank == 2 else (2, 2, 1)  # never pool the time axis
        out = pooling(pool_size=window, name=f'encoding_{level}_maxpool')(out)
    return conv_block(out, n_filters, num_stacks)

def full_scale(inputs, to_layer, from_layer):
    """Resample features from level ``from_layer`` to level ``to_layer``, then conv_block.

    Larger level indices have lower in-plane resolution (see ``encode``). Going to a
    shallower level upsamples by ``2**(from_layer - to_layer)``. Going deeper max-pools
    by ``2**(to_layer - from_layer)``. The same level is passed through unchanged. The
    time axis is never resampled when ``rank == 3``. The result has ``base_filters``
    channels.
    """
    factor = 2 ** abs(from_layer - to_layer)
    window = (factor, factor) if rank == 2 else (factor, factor, 1)

    x = inputs
    if to_layer < from_layer:
        upsamp = layer_util.get_nd_layer(upsamp_type, rank)
        x = upsamp(size=window, interpolation='bilinear', name=f'fullscale_{from_layer}_{to_layer}')(x)
    elif to_layer > from_layer:
        maxpool = layer_util.get_nd_layer('MaxPool', rank)
        x = maxpool(pool_size=window, name=f'fullscale_maxpool_{from_layer}_{to_layer}')(x)

    return conv_block(x, base_filters, stack_num_up)

def aggregate(scale_list, name):
    """Concatenate resampled features along channels and fuse them with conv_block.

    The fused block has ``5 * base_filters`` filters, one ``base_filters`` share per
    level fed in by the decoder. ``name`` is accepted for call-site readability but is
    not used.
    """
    fused = concatenate(scale_list, axis=-1)
    return conv_block(fused, 5 * base_filters, stack_num_up)

def deep_sup(inputs, scale):
    """Deep-supervision head: an ``n_outputs``-channel conv (no activation) at full resolution.

    Features from level ``scale`` are upsampled in-plane by ``2**(scale - 1)``. Level 1
    is already full resolution and is returned right after the conv.
    """
    head = layer_util.get_nd_layer('Conv', rank)(
        n_outputs, activation=None, **conv_config, name=f'deepsup_conv_{scale}'
    )(inputs)
    if scale == 1:
        return head

    factor = 2 ** (scale - 1)
    window = (factor, factor) if rank == 2 else (factor, factor, 1)
    upsampler = layer_util.get_nd_layer(upsamp_type, rank)
    return upsampler(size=window, interpolation='bilinear', name=f'deepsup_upsamp_{scale}')(head)

# ==== Inputs ====
image_input = Input(shape=input_shape, name='image_input')
cgm_input = Input(shape=(n_outputs,), name='cgm_input')          # true vessel one-hot (CGM target)
one_hot_input = Input(shape=(n_outputs,), name='one_hot_input')  # tunable description one-hot
mask_input = Input(shape=output_shape, name='mask_input')        # one-hot target, used by add_loss

# ==== One-hot Feature Embedding ====
# The description one-hot is embedded as a 1-channel map at bottleneck resolution and
# concatenated to the deepest encoder features (XE5).
# Do NOT reassign the global `rank` here. It selects the dimensionality of every
# Conv/MaxPool/UpSampling layer built below (encode, full_scale, deep_sup, CGM) and
# must stay 3. The embedding's own 2-D exponent gets a separate name.
embed_rank = 2
# Encoder levels 2..5 each max-pool H and W by 2, so the bottleneck is image_size / 2**4.
bottleneck_hw = image_size // 2 ** 4
T1 = Dense((bottleneck_hw // 2) ** embed_rank)(one_hot_input)
T1 = Dense(bottleneck_hw ** embed_rank)(T1)
T1 = Reshape((bottleneck_hw, bottleneck_hw, 1, 1))(T1)
# Tile over time. `reps` is a default argument so its value is stored with the Lambda.
T1 = Lambda(lambda x, reps=frames: tf.tile(x, [1, 1, 1, reps, 1]))(T1)

# ==== Encoding Path ====
XE1 = encode(image_input, 1, stack_num_down)
XE2 = encode(XE1, 2, stack_num_down)
XE3 = encode(XE2, 3, stack_num_down)
XE4 = encode(XE3, 4, stack_num_down)
XE5 = encode(XE4, 5, stack_num_down)
XE5 = tf.concat([XE5, T1], axis=-1)  # bottleneck filters + 1 embedding channel

# ==== Classification Guided Module ====
if CGM:
    x_cgm = SpatialDropout3D(rate=0.5)(XE5)
    x_cgm = layer_util.get_nd_layer('Conv', rank)(
        n_outputs, kernel_size=(1, 1, 1), padding='same', strides=(1, 1, 1)
    )(x_cgm)
    x_cgm = GlobalMaxPooling3D()(x_cgm)

    if n_outputs == 1:
        x_cgm = tf.keras.activations.sigmoid(x_cgm)
        cgm_output = x_cgm
    else:
        x_cgm = tf.keras.activations.softmax(x_cgm)
        cgm_output = x_cgm  # class probabilities, supervised against cgm_input
        # Build a 0/1 gate that keeps background plus only the most probable vessel
        # channel. It is reshaped to broadcast over (H, W, T).
        vessel_probs = tf.gather(x_cgm, list(range(1, n_outputs)), axis=-1)
        max_vessel_indices = tf.argmax(vessel_probs, axis=-1, output_type=tf.int32)
        one_hot_mask = tf.one_hot(max_vessel_indices, depth=n_outputs - 1, axis=-1)
        bkg = tf.ones_like(one_hot_mask)[:, :1]
        x_cgm = tf.concat([bkg, one_hot_mask], axis=-1)
        x_cgm = Reshape((1, 1, 1, n_outputs))(x_cgm)

# ==== Decoding Path with Deep Supervision ====
def decoder_block(from_list, target_level):
    """Build decoder level ``target_level`` from full-scale skip connections.

    ``from_list`` is ordered from the deepest level down to level 1, e.g.
    ``[XE5, XD4, XE3, XE2, XE1]``. Element ``i`` therefore lives at level
    ``len(from_list) - i`` and is resampled from that level to ``target_level``.
    (Numbering with ``enumerate(..., start=1)`` would treat XE5 as level 1 and
    resample every input in the wrong direction.)
    """
    n_levels = len(from_list)
    resampled = [
        full_scale(x, target_level, n_levels - i)
        for i, x in enumerate(from_list)
    ]
    return aggregate(resampled, name=f'agg_XD{target_level}')

XD4 = decoder_block([XE5, XE4, XE3, XE2, XE1], 4)
XD3 = decoder_block([XE5, XD4, XE3, XE2, XE1], 3)
XD2 = decoder_block([XE5, XD4, XD3, XE2, XE1], 2)
XD1 = decoder_block([XE5, XD4, XD3, XD2, XE1], 1)

# Softmax heads, ordered deepest level first so that the last entry is output1.
if supervision:
    XD5_out = Activation('softmax', name='output5')(deep_sup(XE5, 5))
    XD4_out = Activation('softmax', name='output4')(deep_sup(XD4, 4))
    XD3_out = Activation('softmax', name='output3')(deep_sup(XD3, 3))
    XD2_out = Activation('softmax', name='output2')(deep_sup(XD2, 2))
    XD1_out = Activation('softmax', name='output1')(deep_sup(XD1, 1))
    seg_outputs = [XD5_out, XD4_out, XD3_out, XD2_out, XD1_out]
else:
    XD1_out = Activation('softmax', name='output1')(deep_sup(XD1, 1))
    seg_outputs = [XD1_out]

# Gate every segmentation head with the CGM mask. This keeps background and the
# predicted vessel channel and zeroes the rest. The list is rebuilt because
# `for x in [...]: x *= x_cgm` only rebinds the loop variable and leaves the heads
# unchanged.
if CGM:
    seg_outputs = [seg * x_cgm for seg in seg_outputs]

outputs = seg_outputs if supervision else seg_outputs[0]

# ==== Model & Loss ====
model = tf.keras.Model(inputs=[image_input, cgm_input, one_hot_input, mask_input], outputs=outputs)

# Losses are attached with add_loss against `mask_input`, so compile() gets loss=None.
# `outputs` is a list when deep supervision is on and a single tensor otherwise.
output_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]

focal_loss = 0
for i, output in enumerate(output_list):
    loss = focal_tversky_loss(mask_input, output)
    # output_list is ordered deepest-first, so the last entry is reported as output1.
    model.add_metric(loss, name=f'output{len(output_list) - i}_loss', aggregation='mean')
    focal_loss += loss

if CGM:
    cgm_loss = categorical_crossentropy(cgm_input, cgm_output)
    model.add_metric(cgm_loss, name='cgm_loss', aggregation='mean')
model.add_loss(focal_loss)
if CGM:
    model.add_loss(0.25 * cgm_loss)
    model.add_metric(focal_loss * cgm_loss, name='cgm_focal_loss', aggregation='mean')

# compile()'s `loss_weights` only scale losses passed via `loss=`. With loss=None it
# has no effect, so it is not passed. Every deep-supervision term above is weighted
# equally; apply per-output weights inside the focal_loss sum if they are wanted.
model.compile(
    optimizer=Adam(),
    loss=None,
)

model.summary()

In [None]:
def _collect_test_case(patient, vessel):
    """Run the test-cohort generator once for (patient, vessel) and stack its samples.

    Returns:
        (inputs, y_test):
            - inputs: maps every model input name to an array of shape
              (n_samples, *signature_shape). Each input is reshaped to its
              output_signature shape, so a stray leading singleton axis is dropped.
            - y_test: the stacked one-hot targets.
    """
    samples = list(CustomDataGen([patient], 'test', vessel).get_gen())
    inputs = {
        key: np.reshape(
            np.stack([np.asarray(x[key]) for x, _ in samples]),
            (len(samples), *spec.shape.as_list()),
        )
        for key, spec in output_signature[0].items()
    }
    y_test = np.stack([np.asarray(y) for _, y in samples])
    return inputs, y_test


def _save_prediction_gif(image, true_mask, pred_onehot, vessel, dice_val, gif_file):
    """Save an animated GIF of one case over its time axis (last axis of ``image``).

    - Left panel: magnitude with every non-empty predicted channel overlaid.
    - Right panel: magnitude with the ground-truth mask of ``vessel`` overlaid.
    """
    fig, axs = plt.subplots(1, 2, figsize=(9, 5))
    fig.suptitle(f'Dice = {dice_val:.2f}')
    vmin, vmax = image.min(), image.max()
    overlay_labels = list(colormaps.keys())[1:]
    anim_frames = []  # not `frames`: that global is the model's time-frame count

    for t in range(image.shape[-1]):
        artists = [
            axs[0].imshow(image[..., t], cmap='gray', vmin=vmin, vmax=vmax),
            axs[1].imshow(image[..., t], cmap='gray', vmin=vmin, vmax=vmax),
            axs[1].imshow(true_mask[..., t], alpha=true_mask[..., t] * 0.7, cmap=colormaps[vessel]),
            axs[0].text(0, -5, f'Time = {t}'),
        ]
        # NOTE(review): assumes colormaps keys follow the prediction channel order
        # (key 0 = background) — confirm against vessels_dict.
        for channel, label in enumerate(overlay_labels, start=1):
            if np.sum(pred_onehot[..., channel]) > 0:
                artists.append(axs[0].imshow(
                    pred_onehot[..., t, channel],
                    alpha=pred_onehot[..., t, channel] * 0.7,
                    cmap=colormaps[label],
                ))
        anim_frames.append(artists)

    # plt.get_cmap replaces plt.cm.get_cmap, which newer Matplotlib versions removed.
    legend_patches = [
        mpatches.Patch(color=plt.get_cmap(colormaps[label])(0.5), label=label)
        for label in colormaps.keys()
    ]
    fig.legend(handles=legend_patches, loc='lower center', ncol=5, fontsize='large', bbox_to_anchor=(0.5, 0))
    fig.tight_layout()
    plt.subplots_adjust(hspace=0.5, bottom=0.1)

    ani = animation.ArtistAnimation(fig, anim_frames)
    # fps is half the number of time frames (last axis), so one loop lasts 2 s.
    ani.save(str(gif_file), fps=image.shape[-1] / 2)
    plt.close(fig)


def evaluate():
    """Segment every (test patient, vessel) with the best checkpoint and save the results.

    For each case this writes:
        - the one-hot prediction to ``results/{model_name}/masks/{patient}_{vessel}.npy``;
        - an animated GIF to ``results/{model_name}/{vessel}/{patient}.gif``.
    It finally writes a per-case Dice table to ``results/segmentation_{model_name}.csv``.
    """
    best_model = tf.keras.models.load_model(f'models/{model_name}.h5', compile=False)
    mask_dir = Path(f'results/{model_name}/masks/')
    mask_dir.mkdir(parents=True, exist_ok=True)  # also creates results/ for the CSV
    results = []

    for patient in test_patients:
        for vessel in list(vessels_dict.keys())[1:]:
            vessel_index = vessels_dict[vessel]

            # The generator runs once per case; the same samples feed prediction and scoring.
            inputs, y_test = _collect_test_case(patient, vessel)

            y_pred = best_model.predict(inputs, batch_size=1, verbose=0)
            if isinstance(y_pred, list):
                y_pred = y_pred[-1]  # outputs end with output1 (full resolution)
            y_pred = get_one_hot(np.argmax(y_pred, axis=-1), out_channels).astype('uint8')

            image = inputs['image_input'][0, ..., 0]
            pred_mask = y_pred[0, ..., vessel_index]
            true_mask = y_test[0, ..., vessel_index]

            np.save(mask_dir / f'{patient}_{vessel}.npy', y_pred[0])

            dice_val = single_dice(true_mask, pred_mask)
            results.append({'patient': patient, 'vessel': vessel, 'dice': dice_val})

            gif_dir = Path(f'results/{model_name}/{vessel}')
            gif_dir.mkdir(parents=True, exist_ok=True)
            _save_prediction_gif(image, true_mask, y_pred[0], vessel, dice_val, gif_dir / f'{patient}.gif')

    # Save Dice results to CSV
    pd.DataFrame(results).to_csv(f'results/segmentation_{model_name}.csv', index=False)
    
    
# Custom Keras callback for periodic evaluation
class CustomCallback(tf.keras.callbacks.Callback):
    """Run ``evaluate()`` every ``save_every`` epochs and once more when training ends.

    Args:
        counter: initial epoch counter (incremented at the end of every epoch).
        save_every: evaluation period in epochs; a falsy value disables periodic runs.
    """

    def __init__(self, counter=0, save_every=10):
        super().__init__()
        self.save_every = save_every
        self.counter = counter
        # Counter value at the most recent evaluate() call (None = never evaluated).
        self._last_eval_counter = None

    def on_epoch_end(self, epoch, logs=None):
        self.counter += 1
        if self.save_every and self.counter % self.save_every == 0:
            evaluate()
            self._last_eval_counter = self.counter

    def on_train_end(self, logs=None):
        # Skip if the last epoch already evaluated (e.g. epochs=400, save_every=200).
        # evaluate() reloads the same checkpoint file, so a second run would only
        # repeat identical work.
        if self._last_eval_counter != self.counter:
            evaluate()


In [None]:
# Keep only the checkpoint with the lowest validation loss of the full-resolution head
# (output1); evaluate() reloads this file.
mc = ModelCheckpoint(
    filepath=f'models/{model_name}.h5',
    monitor='val_output1_loss',
    mode='min',
    save_best_only=True,
)

# Runs evaluate() every 200 epochs and again when training ends.
eval_every_epoch = CustomCallback(save_every=200)

training_callbacks = [mc, eval_every_epoch]
model.fit(train_ds, validation_data=val_ds, epochs=400, callbacks=training_callbacks)