# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, LayerNormalization, MultiHeadAttention, Add
from tensorflow.keras.optimizers import AdamW
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import seaborn as sns

# -------------------------------
# PARAMETERS (OPTIMIZED FOR MEMORY)
# -------------------------------
# Every image is resized to this resolution before it reaches the model.
img_height = img_width = 224
# Kept at 16 (previously 32) to lower memory use.
batch_size = 16
num_classes = 2
epochs = 20

# Multiscale settings: index i of each list below configures scale i.
# A former third scale (4x4 patches) was dropped to reduce memory.
patch_sizes = [16, 8]      # side length of the square patches, in pixels
depths = [2, 4]            # number of transformer blocks
dims = [96, 192]           # token embedding width
num_heads_list = [3, 6]    # attention heads per block
mlp_dims = [384, 768]      # hidden width of each block's MLP
dropout_rate = 0.1

# -------------------------------
# DEFINE DIRECTORIES FOR ALL CANCER TYPES
# -------------------------------
# Root folder holding one "Malignant" and one "Benign" sub-tree.
_DATA_ROOT = r"E:\LY Project\Multi Cancer\Data"

# Folders whose images are loaded with label 1 (malignant).
malignant_dirs = [
    _DATA_ROOT + "\\Malignant\\" + folder
    for folder in (
        "all_early",
        "all_pre",
        "all_pro",
        "breast_malignant",
        "colon_aca",
        "lung_aca",
        "lung_scc",
        "oral_scc",
    )
]

# Folders whose images are loaded with label 0 (benign).
benign_dirs = [
    _DATA_ROOT + "\\Benign\\" + folder
    for folder in (
        "all_benign",
        "breast_benign",
        "colon_bnt",
        "lung_bnt",
        "oral_normal",
    )
]

# -------------------------------
# LOAD DATA FROM MULTIPLE DIRECTORIES
# -------------------------------
print("Loading data from multiple directories...")
print("="*60)

# File extensions (matched case-insensitively) that count as images.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')


def load_images_from_directories(directories, label):
    """Collect image file paths from several folders and pair each with `label`.

    Args:
        directories: Iterable of folder paths to scan (non-recursively).
        label: Value repeated once per image found.

    Returns:
        A tuple ``(image_paths, labels)`` of two equally long lists. Paths are
        grouped by folder in the order given and sorted within each folder.

    A folder that does not exist (or is not a directory) is reported with a
    printed warning and skipped. A per-folder image count is printed for every
    folder that is scanned.
    """
    image_paths = []
    labels = []

    for directory in directories:
        # isdir (rather than exists) so a path pointing at a regular file is
        # skipped with the warning instead of failing inside the scan.
        if not os.path.isdir(directory):
            print(f"\u26a0 Warning: Directory not found - {directory}")
            continue

        # Sorting makes the result independent of the OS listing order, which
        # keeps any later seeded shuffle reproducible. Only regular files are
        # kept, so a sub-folder named like an image is ignored.
        with os.scandir(directory) as entries:
            files = sorted(
                entry.path
                for entry in entries
                if entry.is_file()
                and entry.name.lower().endswith(_IMAGE_EXTENSIONS)
            )

        # normpath drops a trailing separator so basename is never empty.
        folder_name = os.path.basename(os.path.normpath(directory))
        print(f"  {folder_name}: {len(files)} images")

        image_paths.extend(files)
        labels.extend([label] * len(files))

    return image_paths, labels

# Scan every configured folder; malignant images get label 1, benign label 0.
print("\nMALIGNANT Cancer Types:")
malignant_paths, malignant_labels = load_images_from_directories(malignant_dirs, label=1)

print("\nBENIGN Cancer Types:")
benign_paths, benign_labels = load_images_from_directories(benign_dirs, label=0)

# Malignant samples come first, benign second; the order is randomized later.
all_image_paths = malignant_paths + benign_paths
all_labels = malignant_labels + benign_labels

# Fail fast with an actionable message: the percentage lines below would
# otherwise stop with an unexplained ZeroDivisionError.
if not all_image_paths:
    raise FileNotFoundError(
        "No images were found in any of the configured directories; "
        "check malignant_dirs and benign_dirs."
    )

total_images = len(all_image_paths)

print(f"\n{'='*60}")
print("DATASET SUMMARY")
print(f"{'='*60}")
print(f"Total images: {total_images}")
print(f"  Malignant: {len(malignant_paths)} ({100*len(malignant_paths)/total_images:.1f}%)")
print(f"  Benign: {len(benign_paths)} ({100*len(benign_paths)/total_images:.1f}%)")
print(f"{'='*60}\n")

# -------------------------------
# CREATE TF.DATA.DATASET
# -------------------------------
def load_and_preprocess_image(path, label):
    """Read one image file and turn it into a model-ready (image, label) pair.

    Args:
        path: Scalar string tensor holding the image file path.
        label: Scalar integer class index.

    Returns:
        ``(image, label)`` where ``image`` is resized to
        ``(img_height, img_width)`` with 3 channels and scaled to [0, 1], and
        ``label`` is a one-hot vector of length ``num_classes``.
    """
    raw_bytes = tf.io.read_file(path)
    # decode_image detects the format from the file contents (BMP, GIF, JPEG,
    # PNG), whereas the JPEG-specific decoder rejects BMP input.
    # expand_animations=False keeps the result 3-D (first frame of a GIF) so
    # it can be resized below.
    # NOTE(review): TIFF is not among the formats decode_image documents, so
    # .tif/.tiff files admitted by the directory scan will still fail here.
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    # Pin the static rank/channel count; resize needs a tensor of known rank.
    image.set_shape([None, None, 3])
    image = tf.image.resize(image, [img_height, img_width])
    image = image / 255.0
    label = tf.one_hot(label, depth=num_classes)
    return image, label

AUTOTUNE = tf.data.AUTOTUNE
path_ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_labels))
# Shuffle ONCE with a fixed order. With the default
# reshuffle_each_iteration=True the order changes on every pass over the
# dataset, so take()/skip() would select different samples each epoch and
# validation images would leak into training.
path_ds = path_ds.shuffle(
    buffer_size=len(all_image_paths),
    seed=42,
    reshuffle_each_iteration=False,
)

# 80/20 split on the (now fixed) shuffled order.
train_size = int(0.8 * len(all_image_paths))
val_size = len(all_image_paths) - train_size

train_ds = path_ds.take(train_size)
val_ds = path_ds.skip(train_size)

print(f"Training samples: {train_size}")
print(f"Validation samples: {val_size}")
print(f"Batch size: {batch_size}\n")

# Re-randomize only the training order every epoch. This runs on the
# (path, label) pairs before decoding, so the buffer holds no image data.
train_ds = train_ds.shuffle(buffer_size=max(train_size, 1))
train_ds = train_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

# Validation keeps its fixed order so repeated passes see identical batches.
val_ds = val_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.batch(batch_size)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

# -------------------------------
# MULTISCALE VISION TRANSFORMER COMPONENTS
# -------------------------------

class PatchEmbedding(layers.Layer):
    """Cut an image batch into square patches and linearly project each one.

    Args:
        patch_size: Side length of the square patches; also used as the
            stride, so patches do not overlap.
        embed_dim: Output width of the per-patch Dense projection.
    """

    def __init__(self, patch_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.projection = Dense(embed_dim)

    def call(self, x):
        """Return one embedded token per patch: (batch, num_patches, embed_dim)."""
        # Window and stride are the same, so the patches tile the image;
        # "VALID" padding drops any partial patch at the border.
        window = [1, self.patch_size, self.patch_size, 1]
        tiles = tf.image.extract_patches(
            x,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Collapse the patch grid into a single token axis. The batch size is
        # read dynamically; the flattened patch length is known statically.
        token_shape = [tf.shape(x)[0], -1, tiles.shape[-1]]
        flat_tiles = tf.reshape(tiles, token_shape)
        return self.projection(flat_tiles)

class TransformerBlock(layers.Layer):
    """Pre-norm transformer block: self-attention then an MLP, each with a residual.

    Args:
        dim: Width of the incoming tokens (and of the MLP output).
        num_heads: Number of attention heads; each head uses
            ``dim // num_heads`` as its key dimension.
        mlp_dim: Hidden width of the two-layer MLP.
        dropout: Rate used for the attention layer, the dropout after
            attention and both dropouts inside the MLP.
    """

    def __init__(self, dim, num_heads, mlp_dim, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        # Attention sub-block.
        self.norm1 = LayerNormalization(epsilon=1e-6)
        self.attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=dim // num_heads, dropout=dropout
        )
        self.dropout1 = Dropout(dropout)

        # Feed-forward sub-block.
        self.norm2 = LayerNormalization(epsilon=1e-6)
        self.mlp = keras.Sequential(
            [
                Dense(mlp_dim, activation='gelu'),
                Dropout(dropout),
                Dense(dim),
                Dropout(dropout),
            ]
        )

        # Residual connections.
        self.add1 = Add()
        self.add2 = Add()

    def call(self, x, training=False):
        """Apply the block to `x` and return a tensor of the same shape."""
        # Self-attention: the normalized tokens are both query and value.
        normed = self.norm1(x)
        attended = self.attention(normed, normed, training=training)
        attended = self.dropout1(attended, training=training)
        x = self.add1([x, attended])

        # Feed-forward on the normalized result, added back as a residual.
        fed_forward = self.mlp(self.norm2(x), training=training)
        return self.add2([x, fed_forward])

class MultiScaleViT(keras.Model):
    """Memory-optimized Multiscale Vision Transformer.

    The same input image is tokenized independently at every scale, each
    token sequence runs through its own stack of transformer blocks, the
    sequences are mean-pooled, and the pooled vectors are concatenated and
    passed to a small dense classification head with a softmax output.

    Args:
        patch_sizes: Patch side length for each scale.
        depths: Number of transformer blocks for each scale.
        dims: Token embedding width for each scale.
        num_heads_list: Attention heads per block for each scale.
        mlp_dims: MLP hidden width per block for each scale.
        num_classes: Size of the softmax output.
        dropout: Dropout rate passed to the transformer blocks. The head
            uses its own fixed rates (0.5 and 0.3).
        image_size: Optional ``(height, width)`` of the input images, used to
            size the positional embeddings. Defaults to the module-level
            ``(img_height, img_width)``.

    Raises:
        ValueError: If the per-scale lists do not all have the same length.
    """
    def __init__(self, patch_sizes, depths, dims, num_heads_list, mlp_dims,
                 num_classes, dropout=0.1, image_size=None, **kwargs):
        super().__init__(**kwargs)

        # Every per-scale list is indexed in lockstep below; a mismatch would
        # otherwise surface as an IndexError or a silently dropped scale.
        per_scale_settings = (patch_sizes, depths, dims, num_heads_list, mlp_dims)
        if len({len(setting) for setting in per_scale_settings}) != 1:
            raise ValueError(
                "patch_sizes, depths, dims, num_heads_list and mlp_dims "
                "must all have one entry per scale (same length)"
            )

        self.patch_sizes = patch_sizes
        self.depths = depths
        self.dims = dims
        self.num_heads_list = num_heads_list
        self.mlp_dims = mlp_dims
        self.num_classes = num_classes

        if image_size is None:
            image_size = (img_height, img_width)
        self.image_size = tuple(image_size)
        input_height, input_width = self.image_size

        # Patch embeddings for each scale
        self.patch_embeddings = [
            PatchEmbedding(patch_size, dim, name=f"patch_embed_{i}")
            for i, (patch_size, dim) in enumerate(zip(patch_sizes, dims))
        ]

        # Pre-compute positional embeddings. The token count is rows * cols of
        # the patch grid (floor division matches the "VALID" patch
        # extraction), so non-square inputs are sized correctly too.
        self.pos_embeddings = []
        for i, (patch_size, dim) in enumerate(zip(patch_sizes, dims)):
            num_patches = (input_height // patch_size) * (input_width // patch_size)
            pos_embed = self.add_weight(
                name=f"pos_embed_{i}",
                shape=(1, num_patches, dim),
                initializer="random_normal",
                trainable=True,
                dtype=tf.float32
            )
            self.pos_embeddings.append(pos_embed)

        # Transformer blocks for each scale
        self.transformer_blocks = []
        for depth, dim, num_heads, mlp_dim in zip(
                depths, dims, num_heads_list, mlp_dims):
            blocks = [
                TransformerBlock(dim, num_heads, mlp_dim, dropout)
                for _ in range(depth)
            ]
            self.transformer_blocks.append(blocks)

        # Classification head
        self.norm = LayerNormalization(epsilon=1e-6)
        self.dropout = Dropout(0.5)
        self.fc1 = Dense(512, activation='gelu')
        self.dropout2 = Dropout(0.3)
        self.fc2 = Dense(256, activation='gelu')
        self.classifier = Dense(num_classes, activation='softmax')

    def call(self, x, training=False):
        """Return class probabilities of shape (batch, num_classes) for images `x`."""
        scale_features = []

        # Process the same input independently at every scale
        for embed, pos_embed, blocks in zip(
                self.patch_embeddings, self.pos_embeddings, self.transformer_blocks):
            # Extract and embed patches, then add the learned positions
            tokens = embed(x)
            tokens = tokens + pos_embed

            # Apply this scale's transformer blocks
            for block in blocks:
                tokens = block(tokens, training=training)

            # Global average pooling over the token axis
            scale_features.append(tf.reduce_mean(tokens, axis=1))

        # Concatenate features from all scales
        x = tf.concat(scale_features, axis=-1)

        # Classification head. self.dropout (rate 0.5) is applied twice:
        # before fc1 and again before the classifier.
        x = self.norm(x)
        x = self.dropout(x, training=training)
        x = self.fc1(x)
        x = self.dropout2(x, training=training)
        x = self.fc2(x)
        x = self.dropout(x, training=training)
        x = self.classifier(x)

        return x

# -------------------------------
# BUILD MULTISCALE VIT MODEL
# -------------------------------
# Announce the configuration that is about to be instantiated.
print("Building Memory-Optimized Multiscale Vision Transformer...")
patches_per_scale = [(img_height // size) ** 2 for size in patch_sizes]
print(
    f"Patch sizes: {patch_sizes}",
    f"Depths: {depths}",
    f"Dimensions: {dims}",
    f"Attention heads: {num_heads_list}",
    f"Number of patches per scale: {patches_per_scale}\n",
    sep="\n",
)

# The per-scale lists are passed in the constructor's positional order.
model = MultiScaleViT(
    patch_sizes,
    depths,
    dims,
    num_heads_list,
    mlp_dims,
    num_classes,
    dropout=dropout_rate,
)

# Build with a batch of unspecified size so the weights exist before
# summary() is called.
model.build((None, img_height, img_width, 3))
print("\nModel Architecture Summary:")
model.summary()

# -------------------------------
# COMPILE MODEL
# -------------------------------
print("\nCompiling model with AdamW optimizer...")
# AdamW: Adam with decoupled weight decay.
optimizer = AdamW(weight_decay=0.0001, learning_rate=0.0005)

# The labels are one-hot vectors and the model ends in a softmax, hence
# categorical cross-entropy.
compile_settings = {
    "optimizer": optimizer,
    "loss": 'categorical_crossentropy',
    "metrics": ['accuracy'],
}
model.compile(**compile_settings)

# -------------------------------
# CALLBACKS
# -------------------------------
checkpoint_filepath = 'best_mvit_cancer_model.weights.h5'

# Write weights to disk only when validation accuracy improves.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_accuracy',
    save_weights_only=True,
    save_best_only=True,
    verbose=1,
)

# Stop after 5 epochs without a better validation loss and roll the model
# back to the weights of its best-loss epoch. Note this watches val_loss,
# while the checkpoint above watches val_accuracy.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    restore_best_weights=True,
    patience=5,
    verbose=1,
)

# Halve the learning rate after 3 epochs without a better validation loss,
# never going below 1e-7.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    patience=3,
    factor=0.5,
    min_lr=1e-7,
    verbose=1,
)

# -------------------------------
# TRAIN MODEL
# -------------------------------
_rule = "=" * 60

print(f"\nTraining for {epochs} epochs...")
print(_rule)

# Train with per-epoch validation; the callbacks handle checkpointing,
# early stopping and learning-rate reduction.
training_callbacks = [checkpoint, early_stopping, reduce_lr]
history = model.fit(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=training_callbacks,
    verbose=1,
)

print("\n" + _rule)
print("Training completed.")
print(f"Best model saved at {checkpoint_filepath}")
print(_rule)

# -------------------------------
# PLOT TRAINING HISTORY
# -------------------------------
print("\nPlotting training history...")
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One panel per metric: (key in history.history, display name).
history_panels = (("accuracy", "Accuracy"), ("loss", "Loss"))

for axis, (metric_key, metric_name) in zip(axes, history_panels):
    # Training curve and validation curve share the panel.
    axis.plot(history.history[metric_key],
              label=f'Training {metric_name}', marker='o')
    axis.plot(history.history[f'val_{metric_key}'],
              label=f'Validation {metric_name}', marker='s')
    axis.set_title(f'Multiscale ViT - Model {metric_name}', fontweight='bold')
    axis.set_xlabel('Epoch')
    axis.set_ylabel(metric_name)
    axis.legend()
    axis.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('mvit_training_history.png', dpi=300, bbox_inches='tight')
plt.show()

# -------------------------------
# EVALUATION
# -------------------------------
# Evaluate the checkpointed weights (best validation accuracy), which
# replace whatever weights training ended with.
print("\nLoading best model for evaluation...")
model.load_weights(checkpoint_filepath)

print("\nEvaluating model on validation set...")
val_loss, val_accuracy = model.evaluate(val_ds)
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.4f}")

print("\nPredicting classes on validation data...")

# Collect predictions and labels in the SAME pass so they stay aligned.
# predict_on_batch is used because Model.predict is documented as not
# intended for calls inside a loop over small batches.
batch_probs = []
batch_targets = []

for images, labels in val_ds:
    batch_probs.append(model.predict_on_batch(images))
    # Labels are one-hot; recover the integer class index.
    batch_targets.append(np.argmax(labels.numpy(), axis=1))

y_pred_probs = np.concatenate(batch_probs, axis=0)   # (num_samples, num_classes)
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = np.concatenate(batch_targets, axis=0)

# Display names, indexed by integer label (0 = Benign, 1 = Malignant).
class_labels = ['Benign', 'Malignant']
class_ids = [0, 1]

# Passing labels explicitly keeps the matrix 2x2 (rows/columns in class_ids
# order) even if one class is missing from both y_true and y_pred, so it
# always matches the two tick labels below.
cm = confusion_matrix(y_true, y_pred, labels=class_ids)

plt.figure(figsize=(10, 8))
sns.heatmap(
    cm, 
    annot=True, 
    fmt='d', 
    cmap='Blues', 
    xticklabels=class_labels, 
    yticklabels=class_labels,
    cbar_kws={'label': 'Count'}
)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.title('Confusion Matrix - Multiscale ViT\n(Benign vs Malignant)', 
          fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('mvit_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n" + "="*60)
print("CLASSIFICATION REPORT - MULTISCALE VIT")
print("="*60)
# labels= keeps the report aligned with target_names; without it the call
# raises when only one class is present in the data.
print(classification_report(
    y_true, y_pred, labels=class_ids, target_names=class_labels, digits=4
))

from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score

# Column 1 of the softmax output is the score for label 1 (Malignant).
y_pred_probs_positive = y_pred_probs[:, 1]

# ROC points over all score thresholds, and the area under that curve.
fpr, tpr, thresholds = roc_curve(y_true, y_pred_probs_positive)
roc_auc = auc(fpr, tpr)

roc_fig, roc_ax = plt.subplots(figsize=(10, 8))
roc_ax.plot(fpr, tpr, color='darkorange', lw=2,
            label=f'ROC curve (AUC = {roc_auc:.4f})')
# Diagonal reference line for a classifier that guesses at random.
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
            label='Random Classifier')
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate', fontsize=12)
roc_ax.set_ylabel('True Positive Rate', fontsize=12)
roc_ax.set_title('ROC Curve - Multiscale ViT', fontsize=14, fontweight='bold')
roc_ax.legend(loc="lower right")
roc_ax.grid(alpha=0.3)
roc_fig.tight_layout()
roc_fig.savefig('mvit_roc_curve.png', dpi=300, bbox_inches='tight')
plt.show()

accuracy = accuracy_score(y_true, y_pred)
# average=None returns one value per class, ordered as in labels.
precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred, average=None, labels=[0, 1]
)

# Build the 2x2 counts with explicit labels (consistent with the call above)
# so the four-way unpacking cannot fail when a class is absent from the data.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
# Sensitivity = recall on Malignant, specificity = recall on Benign.
# Undefined (NaN) when the corresponding class has no samples.
sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
specificity = tn / (tn + fp) if (tn + fp) else float("nan")

print("\n" + "="*60)
print("DETAILED METRICS - MULTISCALE VIT")
print("="*60)
print(f"\nAccuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"ROC-AUC: {roc_auc:.4f}")
for class_index, class_name in enumerate(('Benign', 'Malignant')):
    print(f"Precision ({class_name}): {precision[class_index]:.4f}")
    print(f"Recall ({class_name}): {recall[class_index]:.4f}")
    print(f"F1-Score ({class_name}): {f1[class_index]:.4f}")
print(f"Sensitivity: {sensitivity:.4f}")
print(f"Specificity: {specificity:.4f}")

print("\n" + "="*60)
print("MULTISCALE VISION TRANSFORMER COMPLETED!")
print("="*60)
 