In [3]:
#Early fusion 
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix, roc_curve, classification_report
from pytorch_tabnet.tab_model import TabNetClassifier
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from medclip import MedCLIPModel  # Make sure this import works
from PIL import Image
import torch.nn as nn

# === Output folder for every artefact written below (created if missing) ===
os.makedirs("results", exist_ok=True)

# === Class name -> integer target for the 3-way classification ===
# Insertion order (normal, benign, malignant) matches the integer codes; the
# plotting/report code below lists class names by iterating this dict.
label_map = dict(normal=0, benign=1, malignant=2)

# === Custom Dataset for Ultrasound Imagery ===
class UltrasoundDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for ultrasound images.

    Args:
        image_paths: Sequence of image file paths, indexed in parallel with ``labels``.
        labels: Sequence of labels; ``labels[i]`` belongs to ``image_paths[i]``.
        transform: Optional callable applied to the RGB PIL image before it is returned.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # The number of samples is the number of image paths.
        return len(self.image_paths)

    def __getitem__(self, index):
        # Convert to RGB so every sample has exactly 3 channels; the context
        # manager releases the file handle as soon as the pixels are copied.
        with Image.open(self.image_paths[index]) as raw_image:
            sample = raw_image.convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[index]

# === Load Clinical + Radiometric Data ===
clinical_df = pd.read_csv("preprocessed_text_data.csv")

# Normalize clinical Image_id column: trim, lowercase and strip a trailing
# image-file extension so ids compare equal to on-disk file stems.
clinical_df['Image_id_clean'] = (
    clinical_df['Image_id']
    .astype(str)
    .str.strip()
    .str.lower()
    .str.replace(r'\.(png|jpe?g|bmp|tiff?)$', '', regex=True)
)

# --- Match images to clinical data by Image_id ---
image_dir = "BrEaST-Lesions_USG-images_and_masks"
image_extension = ".png"  # Adjust if your images are ".jpg" or others

# Lowercased filename -> real on-disk filename. Matching is case-insensitive,
# but the path that is opened later must keep the real casing, otherwise it
# breaks on case-sensitive filesystems (e.g. Linux) for names with capitals.
filename_lookup = {
    f.lower(): f
    for f in os.listdir(image_dir)
    if f.lower().endswith(image_extension.lower())
}
valid_filenames = set(filename_lookup)

matched_image_paths = []
matched_labels = []

for _, row in clinical_df.iterrows():
    img_filename = str(row['Image_id_clean']).lower() + image_extension.lower()
    on_disk_name = filename_lookup.get(img_filename)
    if on_disk_name is None:
        print(f"Warning: Image file not found for: {img_filename}")
        continue
    # Tolerate stray whitespace/casing in the label column; fail loudly on unknown labels.
    label_key = str(row['Classification']).strip().lower()
    if label_key not in label_map:
        raise ValueError(
            f"Unexpected Classification {row['Classification']!r} for {img_filename}; "
            f"expected one of {list(label_map)}"
        )
    matched_image_paths.append(os.path.join(image_dir, on_disk_name))
    matched_labels.append(label_map[label_key])

matched_labels = np.array(matched_labels)

print(f"Total matched images: {len(matched_image_paths)}")
print(f"Total matched labels: {len(matched_labels)}")

# Nothing downstream can work without at least one matched sample.
if not matched_image_paths:
    raise RuntimeError(
        f"No clinical Image_id matched a '{image_extension}' file in {image_dir!r}"
    )

# === Prepare image dataset and dataloader ===
# Images are only resized and converted to tensors here.
# NOTE(review): no mean/std normalization is applied — confirm this matches the
# preprocessing the MedCLIP image encoder expects.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

image_dataset = UltrasoundDataset(
    image_paths=matched_image_paths,
    labels=matched_labels,
    transform=transform,
)
# shuffle=False keeps batches in matched_image_paths order, which is the
# order the clinical feature rows follow.
image_loader = DataLoader(image_dataset, batch_size=16, shuffle=False)

# === Prepare clinical data matrix for matched images only ===
# Extract image IDs without extension from matched_image_paths
matched_ids = [os.path.splitext(os.path.basename(p))[0].lower() for p in matched_image_paths]

# Filter clinical_df using the cleaned Image_id column (keeps clinical_df row order)
clinical_matched_df = clinical_df[clinical_df['Image_id_clean'].isin(matched_ids)]

# Image rows and clinical rows are later paired purely by position (np.hstack),
# so make sure the filtered frame lists exactly the matched ids, in the same order.
clinical_ids = clinical_matched_df['Image_id_clean'].tolist()
if clinical_ids != matched_ids:
    raise ValueError(
        f"Clinical rows ({len(clinical_ids)}) are not aligned with matched images "
        f"({len(matched_ids)}); check for ids whose cleaned form differs from the file stem."
    )

# Drop identifier / filename columns, then keep numeric columns only
clinical_feature_df = clinical_matched_df.drop(
    columns=['Classification', 'Image_id', 'CaseID', 'Mask_tumor_filename', 'Mask_other_filename', 'Image_id_clean']
).select_dtypes(include=[np.number])

# StandardScaler passes NaNs straight through, so missing values would only
# surface later (presumably as an input-validation error during training).
nan_columns = clinical_feature_df.columns[clinical_feature_df.isna().any()].tolist()
if nan_columns:
    raise ValueError(f"Clinical features contain missing values in columns: {nan_columns}")

# Column names kept for interpreting feature indices later.
clinical_feature_names = clinical_feature_df.columns.tolist()
clinical_features = clinical_feature_df.values

# === Normalize clinical features ===
scaler_clinical = StandardScaler()
clinical_features_scaled = scaler_clinical.fit_transform(clinical_features)

# === Load MedCLIP model ===
# Choose the compute device up front; the model is moved there once weights are loaded.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): in the recorded run this cell stopped with a ConnectionError to
# huggingface.co (emilyalsentzer/Bio_ClinicalBERT) right after the match summary
# was printed; constructing MedCLIPModel presumably downloads its text encoder,
# so run with network access or a populated Hugging Face cache.
medclip = MedCLIPModel()
medclip.load_state_dict(torch.load("model_weights.pth", map_location="cpu"))
medclip = medclip.to(device)
medclip.eval()  # inference only: disable dropout / use running BN statistics

# === Extract Image Embeddings ===
# Run the MedCLIP image encoder over every matched image without gradients.
# The loader is unshuffled, so embedding rows follow matched_image_paths order.
image_embeddings = []
image_labels = []

with torch.no_grad():
    for batch_images, batch_labels in image_loader:
        batch_embeddings = medclip.encode_image(batch_images.to(device))
        image_embeddings.append(batch_embeddings.cpu().numpy())
        # Labels normally arrive collated into a tensor; other containers are
        # converted with np.array instead.
        label_array = (
            batch_labels.cpu().numpy()
            if isinstance(batch_labels, torch.Tensor)
            else np.array(batch_labels)
        )
        image_labels.extend(label_array)

# Stack per-batch arrays into one (n_images, embedding_dim) matrix.
image_embeddings = np.vstack(image_embeddings)

# === Normalize image embeddings ===
scaler_image = StandardScaler()
image_embeddings_scaled = scaler_image.fit_transform(image_embeddings)

y_images = np.array(image_labels)

# === Combine Modalities ===
# The two modalities are paired row-by-row, so guard that implicit alignment first.
if image_embeddings.shape[0] != clinical_features.shape[0]:
    raise ValueError(
        f"Row mismatch: {image_embeddings.shape[0]} image embeddings vs "
        f"{clinical_features.shape[0]} clinical rows"
    )
y_combined = y_images  # labels from matched dataset
if not np.array_equal(y_combined, matched_labels):
    raise ValueError("Loader labels differ from matched_labels; image row order has changed")

# === Split row indices (not features) with stratification ===
# Splitting indices with the same n, stratify labels and random_state yields the
# same train/test partition as splitting the feature matrix directly, but lets the
# scalers below be fit without ever seeing the test rows.
row_idx = np.arange(len(y_combined))
train_idx, test_idx = train_test_split(
    row_idx, test_size=0.2, random_state=42, stratify=y_combined)

# Carve a validation split out of the training rows for early stopping, so the
# test set is never used for model selection (TabNet restores the best eval weights).
fit_idx, val_idx = train_test_split(
    train_idx, test_size=0.2, random_state=42, stratify=y_combined[train_idx])

# === Normalize each modality with statistics from the fitting rows only ===
scaler_image.fit(image_embeddings[fit_idx])
scaler_clinical.fit(clinical_features[fit_idx])
image_embeddings_scaled = scaler_image.transform(image_embeddings)
clinical_features_scaled = scaler_clinical.transform(clinical_features)

# Early fusion: image-embedding columns first, then clinical columns.
X_combined = np.hstack([image_embeddings_scaled, clinical_features_scaled])

X_train_comb, y_train_comb = X_combined[fit_idx], y_combined[fit_idx]
X_val_comb, y_val_comb = X_combined[val_idx], y_combined[val_idx]
X_test_comb, y_test_comb = X_combined[test_idx], y_combined[test_idx]

# === Train TabNet on Combined Data ===
TABNET_BATCH_SIZE = 32

tabnet = TabNetClassifier(
    optimizer_params=dict(lr=2e-2),
    scheduler_params={"step_size": 50, "gamma": 0.9},
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    mask_type='entmax',
)

tabnet.fit(
    X_train_comb, y_train_comb,
    eval_set=[(X_val_comb, y_val_comb)],
    eval_name=['valid'],
    max_epochs=50000,  # effectively unbounded; early stopping (patience) ends training
    patience=30,
    batch_size=TABNET_BATCH_SIZE,
    virtual_batch_size=16,
    num_workers=0,
    # A trailing 1-sample batch cannot go through BatchNorm in training mode; drop it only then.
    drop_last=len(fit_idx) % TABNET_BATCH_SIZE == 1,
)

# === Predict and Evaluate (held-out test split) ===
y_pred = tabnet.predict(X_test_comb)
y_proba = tabnet.predict_proba(X_test_comb)

accuracy = accuracy_score(y_test_comb, y_pred)
f1 = f1_score(y_test_comb, y_pred, average='weighted')
auc = roc_auc_score(y_test_comb, y_proba, multi_class='ovr')

# Pin the label set so the matrix is always square over every class in label_map,
# even if a class is absent from both y_test_comb and y_pred (the downstream
# DataFrame and heatmap assume one row/column per class).
conf_matrix = confusion_matrix(y_test_comb, y_pred, labels=list(label_map.values()))

print("Accuracy:", accuracy)
print("F1 Score (weighted):", f1)
print("AUC-ROC (OvR):", auc)

# === Save Metrics and Results ===
metrics_df = pd.DataFrame({
    'accuracy': [accuracy],
    'f1_score_weighted': [f1],
    'auc_roc_ovr': [auc],
})
metrics_df.to_csv("results/metrics.csv", index=False)

class_names = list(label_map.keys())
class_ids = list(label_map.values())

# Label rows/columns with class names (not integer codes) so the CSV reads the
# same way as the heatmap and the classification report.
conf_df = pd.DataFrame(
    conf_matrix,
    index=[f"Actual {name}" for name in class_names],
    columns=[f"Predicted {name}" for name in class_names],
)
conf_df.to_csv("results/confusion_matrix.csv")

# Passing `labels` keeps target_names aligned even when a class never appears in
# y_test_comb or y_pred (otherwise classification_report raises on the size mismatch).
report_dict = classification_report(
    y_test_comb, y_pred, labels=class_ids, target_names=class_names, output_dict=True
)
report_df = pd.DataFrame(report_dict).transpose()
report_df.to_csv("results/classification_report.csv")

# Annotated confusion-matrix heatmap, axes labelled with the class names.
tick_labels = [name for name in label_map]

fig, ax = plt.subplots(figsize=(6, 6))
sns.heatmap(
    conf_matrix,
    ax=ax,
    annot=True,
    fmt='d',
    cmap='Blues',
    cbar=False,
    square=True,
    xticklabels=tick_labels,
    yticklabels=tick_labels,
)
ax.set_xlabel('Predicted', fontsize=12)
ax.set_ylabel('True', fontsize=12)
ax.set_title('Confusion Matrix', fontsize=14)
fig.tight_layout()
fig.savefig("results/confusion_matrix.png", dpi=300)
plt.close(fig)

# One-vs-rest ROC curve per class; column `col` of y_proba is scored against label `col`.
fig, ax = plt.subplots(figsize=(8, 8))
for col, class_name in enumerate(label_map):
    is_class = y_test_comb == col
    class_scores = y_proba[:, col]
    fpr, tpr, _ = roc_curve(is_class, class_scores)
    class_auc = roc_auc_score(is_class, class_scores)
    ax.plot(fpr, tpr, label=f"{class_name} (AUC = {class_auc:.2f})")

# Chance-level diagonal for reference.
ax.plot([0, 1], [0, 1], linestyle='--', color='gray')
ax.set_xlabel("False Positive Rate", fontsize=12)
ax.set_ylabel("True Positive Rate", fontsize=12)
ax.set_title("ROC Curve (One-vs-Rest)", fontsize=14)
ax.legend(loc="lower right")
fig.tight_layout()
fig.savefig("results/roc_curve.png", dpi=300)
plt.close(fig)

# Global TabNet feature importances, one bar per fused input column
# (image-embedding columns come first, followed by clinical columns, per the hstack order).
importance = tabnet.feature_importances_
bar_positions = range(len(importance))

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(bar_positions, importance, color='teal')
ax.set_title("Feature Importance (TabNet)", fontsize=14)
ax.set_xlabel("Feature Index", fontsize=12)
ax.set_ylabel("Importance", fontsize=12)
fig.tight_layout()
fig.savefig("results/feature_importance.png", dpi=300)
plt.close(fig)
    

Total matched images: 256
Total matched labels: 256




ConnectionError: (MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /api/models/emilyalsentzer/Bio_ClinicalBERT/tree/main/additional_chat_templates?recursive=False&expand=False (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x000001957B42ACF0>: Failed to resolve \'huggingface.co\' ([Errno 11001] getaddrinfo failed)"))'), '(Request ID: d18c9cc0-b59f-4f95-a8c5-5b632b3888ff)')

In [2]:
import joblib

# Pickle the fitted early-fusion TabNet classifier next to the other artefacts.
# NOTE(review): this cell ran as In [2], before the In [3] training cell, so the
# model it saved came from an earlier kernel state — re-run it after training.
tabnet_model_path = "results/tabnet_combined_model.joblib"
joblib.dump(tabnet, tabnet_model_path)

['results/tabnet_combined_model.joblib']

In [4]:
import joblib

# Persist both fitted StandardScalers so inference can reproduce the feature scaling.
for fitted_scaler, scaler_path in (
    (scaler_clinical, "results/scaler_clinical.joblib"),
    (scaler_image, "results/scaler_image.joblib"),
):
    joblib.dump(fitted_scaler, scaler_path)

# Re-save the MedCLIP weights (originally loaded from model_weights.pth) under results/.
torch.save(medclip.state_dict(), "results/Medclip_model_weights.pth")