
# **Environment Setup and Imports**

In [84]:
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from collections import Counter
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from IPython.display import display  # built into notebooks; explicit import for script use

# **Data Preprocessing**


## *Data Loading*

In [85]:
# LOAD both pickles
with open('rnn_features.pkl','rb') as f:
    rnn = pickle.load(f)
with open('cnn_features.pkl','rb') as f:
    cnn = pickle.load(f)

df_rnn = pd.DataFrame(rnn['features'], index=rnn['ids'])
df_rnn['label'] = rnn['labels']
df_cnn = pd.DataFrame(cnn['features'], index=cnn['ids'])
df_cnn['label'] = cnn['labels']

# ALIGN on common IDs
common = df_rnn.index.intersection(df_cnn.index)
print(f"Found {len(common)} common patient IDs.")
if len(common)==0:
    raise RuntimeError("No overlapping IDs!")

df_rnn = df_rnn.loc[common]
df_cnn = df_cnn.loc[common]

# FIND mismatches
mismatch = df_rnn['label'] != df_cnn['label']
if mismatch.any():
    mids = df_rnn.index[mismatch].tolist()
    print(f"⚠️  Found {len(mids)} label mismatches:", mids)
    display(pd.DataFrame({
      'rnn_label': df_rnn.loc[mids,'label'],
      'cnn_label': df_cnn.loc[mids,'label'],
    }))
    # override RNN → CNN
    print("→ overriding RNN labels to match CNN for those IDs")
    df_rnn.loc[mismatch,'label'] = df_cnn.loc[mismatch,'label']

assert (df_rnn['label']==df_cnn['label']).all()

# BUILD fusion matrices
X_rnn = df_rnn.drop(columns='label').values
X_cnn = df_cnn.drop(columns='label').values
y     = df_cnn['label'].values
print("Fused data shapes:", X_rnn.shape, X_cnn.shape, y.shape)

Found 100 common patient IDs.
⚠️  Found 15 label mismatches: ['Breast_MRI_018', 'Breast_MRI_054', 'Breast_MRI_124', 'Breast_MRI_141', 'Breast_MRI_246', 'Breast_MRI_280', 'Breast_MRI_290', 'Breast_MRI_339', 'Breast_MRI_466', 'Breast_MRI_595', 'Breast_MRI_723', 'Breast_MRI_726', 'Breast_MRI_797', 'Breast_MRI_805', 'Breast_MRI_832']


| ID | rnn_label | cnn_label |
|---|---|---|
| Breast_MRI_018 | 0.0 | 1.0 |
| Breast_MRI_054 | 0.0 | 1.0 |
| Breast_MRI_124 | 1.0 | 0.0 |
| Breast_MRI_141 | 0.0 | 1.0 |
| Breast_MRI_246 | 1.0 | 0.0 |
| Breast_MRI_280 | 0.0 | 1.0 |
| Breast_MRI_290 | 1.0 | 0.0 |
| Breast_MRI_339 | 1.0 | 0.0 |
| Breast_MRI_466 | 1.0 | 0.0 |
| Breast_MRI_595 | 1.0 | 0.0 |


→ overriding RNN labels to match CNN for those IDs
Fused data shapes: (100, 32) (100, 128) (100,)
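The loading cell assumes that each pickle holds a dict with `features` (an n × d array), `ids`, and `labels`. The toy sketch below uses made-up data to run the same align-and-compare steps. It also shows why the `.loc[common]` reindexing matters: pandas only compares two Series with `!=` when they carry identical index labels.

```python
import numpy as np
import pandas as pd

def to_frame(blob):
    # Assumed pickle layout: dict with 'features' (n x d), 'ids' (n), 'labels' (n)
    df = pd.DataFrame(blob["features"], index=blob["ids"])
    df["label"] = blob["labels"]
    return df

# Hypothetical toy blobs standing in for rnn_features.pkl / cnn_features.pkl
rnn_blob = {"features": np.zeros((3, 2)), "ids": ["A", "B", "C"], "labels": [0, 1, 0]}
cnn_blob = {"features": np.ones((3, 4)), "ids": ["B", "C", "D"], "labels": [0, 0, 1]}

df_r, df_c = to_frame(rnn_blob), to_frame(cnn_blob)

# Keep only patients present in both modalities, in the same order
common_ids = df_r.index.intersection(df_c.index)
df_r, df_c = df_r.loc[common_ids], df_c.loc[common_ids]

# Same index on both sides, so the element-wise comparison is allowed
mismatched = df_r.index[df_r["label"] != df_c["label"]].tolist()
print(list(common_ids), mismatched)
```

Here patient `B` is labelled 1 by one source and 0 by the other, which is the same kind of disagreement flagged for 15 of the 100 real patients above.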


## *Feature Merging*
The features can be merged in two ways, single-input or dual-input. Run only one of them.

Plan:

*   single-input fusion (one concatenated feature vector) as the baseline
*   dual-input fusion (one input per modality) for fine-tuning

### Single-input

In [86]:
# MERGE features (single‑input fusion)
X = np.concatenate([X_cnn, X_rnn], axis=1)
print("Fused feature shape:", X.shape)
print("Post‑fusion label distribution:", Counter(y))

Fused feature shape: (100, 160)
Post‑fusion label distribution: Counter({np.float32(0.0): 95, np.float32(1.0): 5})
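### Dual-input (sketch)

The plan lists dual-input fusion, but this notebook does not implement it. The following is a minimal sketch using the Keras functional API. It is not the author's model. The branch widths (64 and 16), the 32-unit fusion layer, and the input names are hypothetical choices. Only the input widths (128 CNN, 32 RNN) come from the shapes printed above.

```python
import numpy as np
import tensorflow as tf

def build_dual_input_model(cnn_dim=128, rnn_dim=32):
    # One input per modality; widths match X_cnn / X_rnn printed above
    cnn_in = tf.keras.Input(shape=(cnn_dim,), name="cnn_features")
    rnn_in = tf.keras.Input(shape=(rnn_dim,), name="rnn_features")
    # Hypothetical per-branch projections before fusion
    cnn_h = tf.keras.layers.Dense(64, activation="relu")(cnn_in)
    rnn_h = tf.keras.layers.Dense(16, activation="relu")(rnn_in)
    # Fuse the two branches by concatenation, then classify
    fused = tf.keras.layers.Concatenate()([cnn_h, rnn_h])
    fused = tf.keras.layers.Dropout(0.3)(fused)
    fused = tf.keras.layers.Dense(32, activation="relu")(fused)
    out = tf.keras.layers.Dense(1, activation="sigmoid")(fused)
    model = tf.keras.Model(inputs=[cnn_in, rnn_in], outputs=out)
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=[tf.keras.metrics.AUC(name="auc")])
    return model

dual_model = build_dual_input_model()

# Random stand-in data just to exercise the two-input forward pass
rng = np.random.default_rng(0)
x_cnn_demo = rng.normal(size=(8, 128)).astype("float32")
x_rnn_demo = rng.normal(size=(8, 32)).astype("float32")
probs = dual_model.predict([x_cnn_demo, x_rnn_demo], verbose=0)
print(probs.shape)
```

The splits below are made on the concatenated `X`, which stacks `X_cnn` first and `X_rnn` second. Because of that order, `X_train[:, :128]` and `X_train[:, 128:]` would recover the two per-modality inputs for `dual_model.fit([...], y_train, ...)`.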


## *Data Splitting*

In [87]:
# FIRST SPLIT (train vs temp)
strat1 = y if len(set(y))>1 else None
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y,
    test_size=0.30,
    random_state=42,
    stratify=strat1
)

# SECOND SPLIT (val vs test) with pre-check
counts_temp = Counter(y_temp)
# if any class has <2 samples, drop stratify
if any(c<2 for c in counts_temp.values()):
    print("Too few examples in one class for stratified split, splitting WITHOUT stratify")
    strat2 = None
else:
    strat2 = y_temp

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp,
    test_size=0.50,
    random_state=42,
    stratify=strat2
)

print("\nSplit sizes:")
print(" Train:", X_train.shape, Counter(y_train))
print(" Val:  ", X_val.shape,   Counter(y_val))
print(" Test: ", X_test.shape,  Counter(y_test))

Too few examples in one class for stratified split, splitting WITHOUT stratify

Split sizes:
 Train: (70, 160) Counter({np.float32(0.0): 66, np.float32(1.0): 4})
 Val:   (15, 160) Counter({np.float32(0.0): 15})
 Test:  (15, 160) Counter({np.float32(0.0): 14, np.float32(1.0): 1})
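The split output follows from simple arithmetic. With only 5 positives, a 30% hold-out leaves 1 positive in the temporary set, and train keeps the other 4. A single positive can only land in one of validation or test, so the other set has none. Here that set is validation. One commonly used alternative is stratified k-fold cross-validation. It is swapped in here for illustration and is not what the notebook does. With `n_splits=5`, every positive is held out exactly once. The sketch below uses toy labels with the same 95/5 counts.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy labels mirroring the printed distribution: 95 negatives, 5 positives
y_demo = np.array([0] * 95 + [1] * 5)
X_demo = np.zeros((len(y_demo), 1))  # features do not affect how folds are assigned

# n_splits cannot usefully exceed the smallest class count (5 here)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

positives_per_fold = []
fold_sizes = []
for train_idx, test_idx in skf.split(X_demo, y_demo):
    positives_per_fold.append(int(y_demo[test_idx].sum()))
    fold_sizes.append(len(test_idx))
print(positives_per_fold, fold_sizes)
```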


# **Fusion Model**

## *Model Building*

In [88]:
# BUILD model
model = Sequential([
    tf.keras.Input(shape=(X.shape[1],)),
    Dense(128, activation='relu'),
    Dropout(0.3),
    Dense(64,  activation='relu'),
    Dropout(0.3),
    Dense(1,   activation='sigmoid'),
])
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)
model.summary()
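The `model.summary()` output is not shown. The trainable parameter count can still be worked out by hand. Each `Dense` layer has `n_in * n_out` weights plus `n_out` biases, and `Dropout` adds none. The 160-wide input comes from the fused feature shape printed above. The result, about 29k parameters fitted on 70 training rows, suggests a high risk of overfitting.

```python
def dense_params(n_in, n_out):
    # weight matrix (n_in x n_out) plus one bias per output unit
    return n_in * n_out + n_out

# Input width, then the units of the three Dense layers (Dropout has no weights)
layer_sizes = [160, 128, 64, 1]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)
```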

## *Model Training*

In [89]:
# TRAIN
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=50,
    batch_size=32,
    verbose=2
)

Epoch 1/50
3/3 - 2s - 658ms/step - accuracy: 0.3714 - auc: 0.5682 - loss: 0.7292 - val_accuracy: 0.9333 - val_auc: 0.0000e+00 - val_loss: 0.5633
Epoch 2/50
3/3 - 0s - 107ms/step - accuracy: 0.8286 - auc: 0.6761 - loss: 0.5436 - val_accuracy: 1.0000 - val_auc: 0.0000e+00 - val_loss: 0.4144
Epoch 3/50
3/3 - 0s - 44ms/step - accuracy: 0.9286 - auc: 0.5189 - loss: 0.4751 - val_accuracy: 1.0000 - val_auc: 0.0000e+00 - val_loss: 0.3051
Epoch 4/50
3/3 - 0s - 33ms/step - accuracy: 0.9429 - auc: 0.6402 - loss: 0.3672 - val_accuracy: 1.0000 - val_auc: 0.0000e+00 - val_loss: 0.2280
Epoch 5/50
3/3 - 0s - 46ms/step - accuracy: 0.9429 - auc: 0.7917 - loss: 0.2892 - val_accuracy: 1.0000 - val_auc: 0.0000e+00 - val_loss: 0.1735
Epoch 6/50
3/3 - 0s - 33ms/step - accuracy: 0.9429 - auc: 0.6288 - loss: 0.2685 - val_accuracy: 1.0000 - val_auc: 0.0000e+00 - val_loss: 0.1308
Epoch 7/50
3/3 - 0s - 34ms/step - accuracy: 0.9429 - auc: 0.9072 - loss: 0.2195 - val_accuracy: 1.0000 - val_auc: 0.0000e+00 - val_los
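Two details in this log are worth noting:

*   The training accuracy settles at 0.9429, which is 66/70. That is exactly the score of predicting every patient as negative.
*   `val_auc` stays at 0 because the validation split contains no positives, so ROC AUC is undefined there.

One standard mitigation for the imbalance is class weighting through the `class_weight` argument of `model.fit`. The notebook does not use it. The sketch below computes "balanced" weights, n_samples / (n_classes × class_count), from toy labels with the same 66/4 counts as `y_train`.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Toy labels with the same counts as y_train above: 66 negatives, 4 positives
y_train_demo = np.array([0.0] * 66 + [1.0] * 4, dtype=np.float32)
classes = np.unique(y_train_demo)

weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_train_demo)

# Keras expects integer class indices as dictionary keys
class_weight = {int(c): float(w) for c, w in zip(classes, weights)}
print(class_weight)  # roughly {0: 0.53, 1: 8.75}
```

The dict would then be passed as `model.fit(X_train, y_train, ..., class_weight=class_weight)`. With these weights, each of the 4 positives counts about 16 times as much as a negative in the loss.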

## *Model Evaluation*

In [90]:
# EVALUATE
test_loss, test_acc, test_auc = model.evaluate(X_test, y_test, verbose=0)
print(f"\nTest Loss: {test_loss:.3f}  Accuracy: {test_acc:.3f}  AUC: {test_auc:.3f}")

y_pred_prob = model.predict(X_test).ravel()
y_pred      = (y_pred_prob>0.5).astype(int)

print("\nTest labels distribution:", Counter(y_test))
print("\nConfusion matrix:\n", confusion_matrix(y_test, y_pred))
# zero_division=0 reports 0.0 (without a warning) when a class is never predicted
print("\nClassification report:\n", classification_report(y_test, y_pred, zero_division=0))


Test Loss: 0.323  Accuracy: 0.933  AUC: 0.679
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 69ms/step

Test labels distribution: Counter({np.float32(0.0): 14, np.float32(1.0): 1})

Confusion matrix:
 [[14  0]
 [ 1  0]]

Classification report:
               precision    recall  f1-score   support

         0.0       0.93      1.00      0.97        14
         1.0       0.00      0.00      0.00         1

    accuracy                           0.93        15
   macro avg       0.47      0.50      0.48        15
weighted avg       0.87      0.93      0.90        15
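With one positive among 15 test cases, the 0.933 accuracy equals 14/15, the score of predicting all negatives, and the confusion matrix confirms this. The test AUC is also fragile. With a single positive, ROC AUC reduces to the fraction of the 14 negatives scored below that one patient, with ties counting one half. So 0.679 mostly reflects where one score happens to rank, not a stable estimate. The sketch below uses hypothetical scores to check this identity against `roc_auc_score`.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# 14 negatives and 1 positive, as in the printed test split
y_true = np.array([0] * 14 + [1])
neg_scores = np.arange(14) / 14   # hypothetical scores: 0.0, 0.071, ..., 0.929
pos_score = 0.6                   # hypothetical score for the single positive
scores = np.append(neg_scores, pos_score)

auc = roc_auc_score(y_true, scores)

# With one positive: share of negatives ranked below it (ties count 1/2)
manual = (np.sum(neg_scores < pos_score) + 0.5 * np.sum(neg_scores == pos_score)) / 14
print(auc, manual)
```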

