In [None]:

import os
import torch
import argparse
import itertools
import numpy as np
from tqdm import tqdm
from urllib.request import urlopen
from PIL import Image
import timm
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from glob import glob
from sklearn.model_selection import train_test_split
import torch.nn as nn
import random
from torch.nn.modules.batchnorm import _BatchNorm
import torchmetrics
import matplotlib.pyplot as plt
import torch.nn.functional as F
# Report the visible GPUs and choose the device used by every later cell.
print(f"GPUs used:\t{torch.cuda.device_count()}")
# Fall back to the CPU when CUDA is unavailable so that later `.to(device)` calls
# do not fail on machines without a GPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
print(f"Device:\t\t{device}")

In [None]:

# Class folder names; the loading cell uses a folder's position in this list as the label.
class_list = ['BRNT', 'BRID', 'BRIL', 'BRLC', 'BRDC']

# Hyper-parameters and paths shared by the cells below.
params = dict(
    image_size=512,                   # images are resized to image_size x image_size
    lr=2e-4,                          # learning rate handed to the optimizer
    beta1=0.5,
    beta2=0.999,
    batch_size=8,                     # training batch size
    epochs=1000,
    n_classes=5,
    data_path='../../../data/NIPA/',  # root folder holding one sub-folder per class
    inch=3,                           # channel count of the image tensor
)
def create_dir(path):
    """Create directory ``path`` (including missing parents) if it does not exist yet.

    Args:
        path: Directory path to create.
    """
    # exist_ok=True removes the race between testing os.path.exists() and creating.
    os.makedirs(path, exist_ok=True)

# Add this to a new cell and try it out
import torch.nn as nn

# 1. Check the BatchNorm state
def check_batchnorm_stats(model):
    """Print a report for every BatchNorm2d layer whose running statistics contain NaN.

    Layers that keep no running statistics (``running_mean`` / ``running_var`` are
    ``None``, e.g. with ``track_running_stats=False``) are skipped instead of crashing.

    Args:
        model: Module whose sub-modules are scanned via ``named_modules()``.
    """
    for name, module in model.named_modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        if module.running_mean is None or module.running_var is None:
            continue

        # Evaluate each NaN check once and reuse it for the condition and the report.
        mean_has_nan = torch.isnan(module.running_mean).any()
        var_has_nan = torch.isnan(module.running_var).any()
        if mean_has_nan or var_has_nan:
            print(f"NaN detected in {name}")
            print(f"Running mean has NaN: {mean_has_nan}")
            print(f"Running var has NaN: {var_has_nan}")

# 2. BatchNorm reset function
def reset_batchnorm_stats(model):
    """Reset the running statistics of every BatchNorm2d layer inside ``model``."""
    bn_layers = (layer for layer in model.modules() if isinstance(layer, nn.BatchNorm2d))
    for layer in bn_layers:
        layer.reset_running_stats()

# 3. Safe switch to eval mode
def safe_eval_mode(model):
    """Switch ``model`` to eval mode and restore any backed-up BatchNorm2d momentum."""
    model.eval()
    # Put the momentum of BatchNorm layers back to the value saved in `backup_momentum`.
    bn_layers = [layer for layer in model.modules() if isinstance(layer, nn.BatchNorm2d)]
    for layer in bn_layers:
        if hasattr(layer, 'backup_momentum'):
            layer.momentum = layer.backup_momentum



In [None]:

# Preprocessing applied to every loaded image: convert to a tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

def transback(data: Tensor) -> Tensor:
    """Undo the mean-0.5 / std-0.5 normalisation: returns ``data / 2 + 0.5``."""
    halved = data / 2
    return halved + 0.5

class CustomDataset(Dataset):
    """In-memory image/label dataset that randomly flips an image when it is fetched.

    Compatible with torch.utils.data.DataLoader.
    """

    def __init__(self, parmas, images, label):
        """Store the inputs unchanged.

        Args:
            parmas: Configuration object, kept as ``self.args``.
            images: Indexable collection of images.
            label: Indexable collection of labels aligned with ``images``.
        """
        self.args = parmas
        self.images = images
        self.label = label

    def trans(self, image):
        """Flip ``image`` horizontally and/or vertically, each about half of the time."""
        # The flip transforms are built with p=1 (always flip); the coin toss happens here.
        if random.random() > 0.5:
            image = transforms.RandomHorizontalFlip(1)(image)

        if random.random() > 0.5:
            image = transforms.RandomVerticalFlip(1)(image)

        return image

    def __getitem__(self, index):
        """Return ``(randomly flipped image, label)`` for position ``index``."""
        sample, target = self.images[index], self.label[index]
        return self.trans(sample), target

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)


# Collect every image path together with its integer label (the index of its class folder).
image_label = []
image_path = []
for class_idx, class_name in enumerate(tqdm(class_list)):
    # glob() returns files in arbitrary order; sort so the seeded split below is reproducible.
    for path in sorted(glob(params['data_path'] + class_name + '/*.jpeg')):
        image_path.append(path)
        image_label.append(class_idx)

# Load, resize and normalise all images into one pre-allocated float tensor.
target_size = (params['image_size'], params['image_size'])
train_images = torch.zeros((len(image_path), params['inch'], params['image_size'], params['image_size']))
for idx, path in enumerate(tqdm(image_path)):
    # The context manager closes the file handle as soon as the image is converted.
    with Image.open(path) as img:
        train_images[idx] = trans(img.convert('RGB').resize(target_size))

X_train, X_test, y_train, y_test = train_test_split(train_images, image_label, test_size=0.2, random_state=42)

# Train only on the training split. Building the training set from *all* images would put
# every validation image into the training data and make the validation metrics meaningless.
# num_classes is passed explicitly so both label tensors have the same width even if a
# split happens not to contain the highest class index.
train_dataset = CustomDataset(
    params, X_train,
    F.one_hot(torch.tensor(y_train), num_classes=len(class_list)).to(torch.int64))
# NOTE(review): CustomDataset applies random flips to every item it returns, so the
# validation images are randomly flipped as well.
val_dataset = CustomDataset(
    params, X_test,
    F.one_hot(torch.tensor(y_test), num_classes=len(class_list)).to(torch.int64))

dataloader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)
# Shuffling is unnecessary for evaluation; a fixed order keeps runs comparable.
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [None]:
class FeatureExtractor(nn.Module):
    """Feature extractor block: a pretrained timm model without its last child module."""

    def __init__(self):
        super().__init__()
        backbone = timm.create_model('tf_efficientnetv2_xl', pretrained=True)
        # Keep every child module except the final one
        # (presumably the classification head — TODO confirm against timm).
        kept_layers = list(backbone.children())[:-1]
        self.feature_ex = nn.Sequential(*kept_layers)

    def forward(self, inputs):
        """Run ``inputs`` through the truncated backbone and return its output."""
        return self.feature_ex(inputs)
class custom_model(nn.Module):
    """Image classifier: a feature extractor followed by one linear classification layer.

    Args:
        num_classes: Number of output logits.
        image_feature_dim: Size of the last dimension of the feature extractor's output;
            it is the input width of the linear layer.
        feature_extractor_scale1: Module that maps an input batch to features whose last
            dimension is ``image_feature_dim``.
    """

    def __init__(self, num_classes, image_feature_dim, feature_extractor_scale1: nn.Module):
        super().__init__()
        self.num_classes = num_classes
        self.image_feature_dim = image_feature_dim

        # Feature extractor supplied by the caller
        self.feature_extractor = feature_extractor_scale1
        # Classification layer
        self.classification_layer = nn.Linear(image_feature_dim, num_classes)

    def forward(self, inputs):
        """Return class logits for ``inputs``.

        The linear layer acts on the last dimension of the extracted features, so that
        dimension must equal ``image_feature_dim``; leading dimensions are kept.
        """
        # No unpacking of inputs.size(): the sizes were never used, and the 4-way unpack
        # rejected any input that was not exactly 4-D.
        features = self.feature_extractor(inputs)
        logits = self.classification_layer(features)
        return logits
    
class SAM(torch.optim.Optimizer):
    """Sharpness Aware Minimization optimizer that wraps a base optimizer.

    One update uses two phases: ``first_step`` saves the current weights and moves
    them along the gradient by a step scaled with ``rho``; ``second_step`` restores
    the saved weights and lets the base optimizer step with the gradients present at
    that point (the caller is expected to run a second forward/backward pass between
    the two calls, or to use ``step`` with a closure).
    """
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        """Set up the wrapper and build the base optimizer.

        Args:
            params: Parameters (or parameter groups) to optimize.
            base_optimizer: Optimizer class/callable; invoked as
                ``base_optimizer(self.param_groups, **kwargs)``.
            rho: Non-negative scale of the weight perturbation applied in ``first_step``.
            adaptive: If true, the perturbation is additionally scaled element-wise by
                the squared parameter values.
            **kwargs: Stored in the defaults and forwarded to ``base_optimizer``.

        Raises:
            AssertionError: If ``rho`` is negative.
        """
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Share one set of param_groups between this wrapper and the base optimizer.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Save each parameter as ``old_p`` and perturb it along its gradient.

        Args:
            zero_grad: If true, clear the gradients after the perturbation.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # 1e-12 guards against division by a zero gradient norm.
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the weights saved by ``first_step`` and run the base optimizer step.

        Args:
            zero_grad: If true, clear the gradients after the update.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full update: ``first_step``, then ``closure()``, then ``second_step``.

        Args:
            closure: Callable that performs a full forward-backward pass; required.

        Raises:
            AssertionError: If ``closure`` is ``None``.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the 2-norm over the per-parameter gradient 2-norms.

        Only parameters that have a gradient contribute; with ``adaptive`` set, each
        gradient is scaled element-wise by ``|p|`` before its norm is taken.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` and point the base optimizer at this optimizer's param groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
        
def disable_running_stats(model):
    """Set the momentum of every BatchNorm layer in ``model`` to 0.

    The previous momentum is stored on the module as ``backup_momentum`` so it can be
    restored later. The call is idempotent: repeating it while the layers are already
    disabled keeps the stored value instead of overwriting it with 0.

    Args:
        model: Module whose sub-modules are visited with ``model.apply``.
    """
    def _disable(module):
        if isinstance(module, _BatchNorm):
            # A layer that already has a backup and a momentum of 0 is already disabled;
            # backing up again would replace the real momentum with 0 for good.
            already_disabled = hasattr(module, "backup_momentum") and module.momentum == 0
            if not already_disabled:
                module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_disable)

def enable_running_stats(model):
    """Restore the momentum of every BatchNorm layer from its ``backup_momentum``.

    Layers without a ``backup_momentum`` attribute are left untouched.

    Args:
        model: Module whose sub-modules are visited with ``model.apply``.
    """
    def _enable(module):
        if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum

    # The helper has to be applied to the sub-modules; defining it alone restores nothing.
    model.apply(_enable)
            
import transformers

# Build the classifier (feature extractor + linear layer with one logit per class)
# and move it to the compute device.
Feature_Extractor = FeatureExtractor()
model = custom_model(
    num_classes=len(class_list),
    image_feature_dim=1280,
    feature_extractor_scale1=Feature_Extractor,
).to(device)

# SAM wraps AdamW as its base optimizer; `lr` is forwarded to it.
base_optimizer = torch.optim.AdamW
optimizer = SAM(model.parameters(), base_optimizer=base_optimizer, lr=params['lr'])

# Multiclass accuracy metric kept on the same device as the model.
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=len(class_list)).to(device)


In [None]:
# ---- Training with SAM, followed by validation after every epoch ----
n_epochs = params['epochs']   # taken from the config instead of repeating a literal
MIN_loss = float('inf')       # best validation loss seen so far
train_loss_list = []
val_loss_list = []
val_acc_list = []
sig = nn.Sigmoid()
model_path = '../../../model/NIPA_classification/Breast/'
create_dir(model_path)

for epoch in range(n_epochs):
    train = tqdm(dataloader)
    count = 0
    running_loss = 0.0
    model.train()
    for x, y in train:
        x = x.to(device).float()
        y = y.to(device).float()
        count += 1

        # SAM phase 1: gradient at the current weights, then move to the perturbed weights.
        enable_running_stats(model)
        optimizer.zero_grad()  # reset the gradients
        predict = model(x)
        cost = F.cross_entropy(predict, y)  # loss at the current weights
        cost.backward()
        optimizer.first_step(zero_grad=True)

        # SAM phase 2: gradient at the perturbed weights, then the actual update.
        disable_running_stats(model)
        predict = model(x)
        cost1 = F.cross_entropy(predict, y)  # loss at the perturbed weights
        cost1.backward()
        optimizer.second_step(zero_grad=True)

        running_loss += cost.item()
        # `count` already includes the current batch, so it is the 1-based step number.
        train.set_description(f"epoch: {epoch+1}/{n_epochs} Step: {count} loss : {running_loss/count:.4f}")
    train_loss_list.append(running_loss / count)

    # ---- validation ----
    # Evaluate in eval mode so BatchNorm uses its running statistics rather than the
    # statistics of the current validation batch; model.train() above switches back
    # at the start of the next epoch.
    # NOTE(review): if validation metrics collapse in eval mode, inspect the BatchNorm
    # running statistics and momentum (e.g. check_batchnorm_stats); the SAM helpers
    # toggle the momentum during training.
    model.eval()
    val = tqdm(val_dataloader)
    count = 0
    val_running_loss = 0.0
    acc_loss = 0
    with torch.no_grad():
        for x, y in val:
            x = x.to(device).float()
            y = y.to(device).float()
            count += 1
            predict = model(x)
            cost = F.cross_entropy(predict, y)
            acc = accuracy(predict.argmax(dim=1), y.argmax(dim=1))
            val_running_loss += cost.item()
            acc_loss += acc
            val.set_description(f"Validation epoch: {epoch+1}/{n_epochs} Step: {count} loss : {val_running_loss/count:.4f}  accuracy: {acc_loss/count:.4f}")
    val_loss_list.append(val_running_loss / count)
    val_acc_list.append((acc_loss / count).cpu().detach().numpy())

    # Plot the learning curves at epochs 5, 105, 205, ...
    if epoch % 100 == 5:
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.title('loss_graph')
        plt.plot(np.arange(epoch+1), train_loss_list, label='train_loss')
        plt.plot(np.arange(epoch+1), val_loss_list, label='validation_loss')
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.title('acc_graph')
        plt.plot(np.arange(epoch+1), val_acc_list, label='validation_acc')
        plt.xlabel('epoch')
        plt.ylabel('accuracy')
        plt.legend()
        plt.show()

    # Save a checkpoint whenever the mean validation loss improves.
    if MIN_loss > (val_running_loss / count):
        torch.save(model.state_dict(), f'{model_path}modelEff_v2_XL_SAM_{epoch}.pt')
        MIN_loss = val_running_loss / count

# Always keep the weights of the final epoch as well.
torch.save(model.state_dict(), f'{model_path}modelEff_v2_XL_SAM.pt')

In [None]:
# Test Set Performance Evaluation
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_recall_fscore_support, roc_curve, auc
import matplotlib.pyplot as plt
import matplotlib
import scipy.stats as stats
from sklearn.preprocessing import label_binarize
from itertools import cycle
import warnings
warnings.filterwarnings('ignore')

model_path='../../../model/NIPA_classification/Breast/'
matplotlib.rcParams['font.size'] = 10

# Load the best model
# best_model_path = f'{model_path}modelEff_v2_XL_SAM_13.pt'
best_model_path = f'{model_path}modelEff_v2_XL_SAM_127.pt'
# map_location lets the checkpoint load on whichever device this session uses.
model.load_state_dict(torch.load(best_model_path, map_location=device))
# Inference must run in evaluation mode (the code previously called model.train()).
# NOTE(review): eval mode relies on the BatchNorm running statistics stored in the
# checkpoint; if the metrics differ sharply from those seen during training, inspect
# them with check_batchnorm_stats and consider retraining.
model.eval()

# Evaluation on the held-out split.
# NOTE(review): `val_dataloader` is the split made by train_test_split, which the training
# cell also uses to pick checkpoints, so it is not an independent test set.
test_predictions = []
test_labels = []
test_probabilities = []

print("Evaluating on Test Set...")
print("="*50)

with torch.no_grad():
    for x, y in tqdm(val_dataloader, desc="Testing"):
        x = x.to(device).float()
        y = y.to(device).float()

        outputs = model(x)
        probabilities = F.softmax(outputs, dim=1)
        predictions = torch.argmax(probabilities, dim=1)

        test_predictions.extend(predictions.cpu().numpy())
        # Labels are one-hot encoded; argmax recovers the class index.
        test_labels.extend(torch.argmax(y, dim=1).cpu().numpy())
        test_probabilities.extend(probabilities.cpu().numpy())

# Convert to numpy arrays
test_predictions = np.array(test_predictions)
test_labels = np.array(test_labels)
test_probabilities = np.array(test_probabilities)

# Calculate basic metrics
# Pass the full label set explicitly so every per-class array and the confusion matrix
# always have one entry per class in `class_list`, even when a class is missing from
# both the labels and the predictions.
label_ids = list(range(len(class_list)))

# NOTE(review): this rebinds `accuracy`, which the training cell uses for the torchmetrics
# metric; re-create that metric before re-running the training cell.
accuracy = accuracy_score(test_labels, test_predictions)
precision, recall, f1, support = precision_recall_fscore_support(
    test_labels, test_predictions, labels=label_ids, average=None, zero_division=0)
macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
    test_labels, test_predictions, labels=label_ids, average='macro', zero_division=0)
weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
    test_labels, test_predictions, labels=label_ids, average='weighted', zero_division=0)

# Get confusion matrix (rows: actual class, columns: predicted class)
cm = confusion_matrix(test_labels, test_predictions, labels=label_ids)

print("\n📊 Basic Test Set Performance Results")
print("="*50)
print(f"Overall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Macro Average - Precision: {macro_precision:.4f}, Recall: {macro_recall:.4f}, F1-Score: {macro_f1:.4f}")
print(f"Weighted Average - Precision: {weighted_precision:.4f}, Recall: {weighted_recall:.4f}, F1-Score: {weighted_f1:.4f}")

print("\n🔄 Confusion Matrix:")
print("-"*30)
cm_df = pd.DataFrame(cm, index=class_list, columns=class_list)
print(cm_df.to_string())

# Calculate per-class metrics with confidence intervals
def calculate_wilson_ci(successes, trials, confidence=0.95):
    """Wilson score interval for a binomial proportion.

    Args:
        successes: Number of successes.
        trials: Number of trials.
        confidence: Two-sided confidence level (default 0.95).

    Returns:
        Tuple ``(p, ci_lower, ci_upper)`` with the observed proportion and the interval
        bounds clamped to [0, 1]; ``(0, 0, 0)`` when ``trials`` is 0.
    """
    if trials == 0:
        return 0, 0, 0

    # Two-sided normal quantile for the requested confidence level.
    z = stats.norm.ppf((1 + confidence) / 2)
    z_sq = z**2
    p = successes / trials

    shrink = 1 + z_sq / trials
    centre = (p + z_sq / (2 * trials)) / shrink
    half_width = z * np.sqrt((p * (1 - p) + z_sq / (4 * trials)) / trials) / shrink

    return p, max(0, centre - half_width), min(1, centre + half_width)

# Calculate per-class metrics
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


# One-vs-rest metrics for every class, derived from the confusion matrix `cm`.
z_975 = stats.norm.ppf(0.975)
per_class_results = []
for class_idx, class_name in enumerate(class_list):
    tp = cm[class_idx, class_idx]            # predicted as this class and correct
    fp = np.sum(cm[:, class_idx]) - tp       # predicted as this class, actually another
    fn = np.sum(cm[class_idx, :]) - tp       # actually this class, predicted as another
    tn = np.sum(cm) - tp - fp - fn           # everything else
    total = tp + fp + tn + fn

    # Point estimates
    precision_val = _safe_ratio(tp, tp + fp)
    recall_val = _safe_ratio(tp, tp + fn)
    specificity_val = _safe_ratio(tn, tn + fp)
    f1_val = _safe_ratio(2 * precision_val * recall_val, precision_val + recall_val)
    binary_accuracy = _safe_ratio(tp + tn, total)  # this class vs. all others

    # Wilson intervals for the proportion-type metrics
    precision_ci = calculate_wilson_ci(tp, tp + fp)
    recall_ci = calculate_wilson_ci(tp, tp + fn)
    specificity_ci = calculate_wilson_ci(tn, tn + fp)
    accuracy_ci = calculate_wilson_ci(tp + tn, total)

    # F1 interval: normal approximation that treats F1 like a proportion
    # over the class's positive count (tp + fn).
    if f1_val > 0 and (tp + fn) > 0:
        se_f1 = np.sqrt(f1_val * (1 - f1_val) / (tp + fn))
        f1_ci_lower = max(0, f1_val - z_975 * se_f1)
        f1_ci_upper = min(1, f1_val + z_975 * se_f1)
    else:
        f1_ci_lower = f1_ci_upper = 0

    per_class_results.append({
        'Class': class_name,
        'TP': int(tp),
        'FP': int(fp),
        'TN': int(tn),
        'FN': int(fn),
        'Support': int(support[class_idx]),
        'Accuracy': binary_accuracy,
        'Accuracy_CI_Lower': accuracy_ci[1],
        'Accuracy_CI_Upper': accuracy_ci[2],
        'Precision': precision_val,
        'Precision_CI_Lower': precision_ci[1],
        'Precision_CI_Upper': precision_ci[2],
        'Recall': recall_val,
        'Recall_CI_Lower': recall_ci[1],
        'Recall_CI_Upper': recall_ci[2],
        'Specificity': specificity_val,
        'Specificity_CI_Lower': specificity_ci[1],
        'Specificity_CI_Upper': specificity_ci[2],
        'F1_Score': f1_val,
        'F1_Score_CI_Lower': f1_ci_lower,
        'F1_Score_CI_Upper': f1_ci_upper,
    })

# Collect the per-class rows into a DataFrame
results_df = pd.DataFrame(per_class_results)

# Calculate AUC for multiclass (One-vs-Rest approach)
n_classes = len(class_list)
fpr, tpr, roc_auc = {}, {}, {}

# One-vs-rest ROC curve and AUC for every class
for class_idx in range(n_classes):
    is_class = (test_labels == class_idx).astype(int)   # 1 for this class, 0 for the rest
    class_scores = test_probabilities[:, class_idx]
    fpr[class_idx], tpr[class_idx], _ = roc_curve(is_class, class_scores)
    roc_auc[class_idx] = auc(fpr[class_idx], tpr[class_idx])

# Macro average: interpolate every class curve on the union of FPR points, then average
all_fpr = np.unique(np.concatenate([fpr[class_idx] for class_idx in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for class_idx in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[class_idx], tpr[class_idx])
mean_tpr /= n_classes
fpr["macro"], tpr["macro"] = all_fpr, mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Micro average: pool every (sample, class) decision into a single binary problem
y_test_binary = label_binarize(test_labels, classes=range(n_classes))
if n_classes == 2:
    y_test_binary = np.hstack((1-y_test_binary, y_test_binary))
fpr["micro"], tpr["micro"], _ = roc_curve(y_test_binary.ravel(), test_probabilities.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Attach each class's AUC to its row of the results table
results_df['AUC'] = [roc_auc[row_idx] for row_idx in results_df.index]

print(f"\nMacro AUC: {roc_auc['macro']:.4f}, Micro AUC: {roc_auc['micro']:.4f}")

# Metrics that carry a confidence interval, with the short titles used in the header.
ci_metrics = ['Accuracy', 'Precision', 'Recall', 'Specificity', 'F1_Score']
ci_titles = ['Acc', 'Prec', 'Recall', 'Spec', 'F1']
cell_width = 20  # width of "0.000 [0.000, 0.000]"


def _ci_cell(row, metric):
    """Format one metric of a results row as 'value [lower, upper]'.

    Both bounds are printed because the Wilson intervals are not symmetric around the
    point estimate, so a single '±' half-width would misstate the lower bound.
    """
    lower = row[metric + '_CI_Lower']
    upper = row[metric + '_CI_Upper']
    return f"{row[metric]:.3f} [{lower:.3f}, {upper:.3f}]"


def _summary_row(label, values, auc_value, total_support):
    """Format an average row that carries point estimates only (no intervals)."""
    cells = " | ".join(f"{value:.3f}".ljust(cell_width) for value in values)
    return f"{label:>6} | {cells} | {auc_value:>6.3f} | {total_support:>4}"


header = (f"{'Class':>6} | "
          + " | ".join((title + ' [95% CI]').ljust(cell_width) for title in ci_titles)
          + f" | {'AUC':>6} | {'Supp':>4}")
rule = "-" * len(header)

print("\n📋 Per-Class Performance with 95% Confidence Intervals:")
print(rule)
print(header)
print(rule)

for _, row in results_df.iterrows():
    cells = " | ".join(_ci_cell(row, metric) for metric in ci_metrics)
    print(f"{row['Class']:>6} | {cells} | {row['AUC']:>6.3f} | {row['Support']:>4}")

# Add separator line
print(rule)

# Calculate average metrics
avg_accuracy = results_df['Accuracy'].mean()
avg_specificity = results_df['Specificity'].mean()

# Specificity weighted by each class's support
weighted_specificity = np.average(results_df['Specificity'], weights=support)

# Macro-average row (unweighted mean over classes)
print(_summary_row('Macro',
                   [avg_accuracy, macro_precision, macro_recall, avg_specificity, macro_f1],
                   roc_auc['macro'], sum(support)))

# Weighted-average row (overall accuracy, support-weighted metrics, micro AUC)
print(_summary_row('Weight',
                   [accuracy, weighted_precision, weighted_recall, weighted_specificity, weighted_f1],
                   roc_auc['micro'], sum(support)))

# Validation: cross-check the manually computed metrics
print("\n🔍 Validation Check:")
print("-"*50)
for idx, row in results_df.iterrows():
    # F1 recomputed from the row's precision and recall
    pr_sum = row['Precision'] + row['Recall']
    expected_f1 = 2 * row['Precision'] * row['Recall'] / pr_sum if pr_sum > 0 else 0
    f1_diff = abs(row['F1_Score'] - expected_f1)

    print(f"{row['Class']:>6}: TP={row['TP']:>3}, FP={row['FP']:>3}, TN={row['TN']:>3}, FN={row['FN']:>3}")
    print(f"        Expected F1: {expected_f1:.4f}, Actual F1: {row['F1_Score']:.4f}, Diff: {f1_diff:.4f}")

    # Compare against the per-class arrays returned by scikit-learn
    sklearn_precision = precision[idx] if idx < len(precision) else 0
    sklearn_recall = recall[idx] if idx < len(recall) else 0
    sklearn_f1 = f1[idx] if idx < len(f1) else 0

    print(f"        Sklearn - Prec: {sklearn_precision:.4f}, Recall: {sklearn_recall:.4f}, F1: {sklearn_f1:.4f}")

# Create visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Subplot 1: grouped bar chart of the per-class metrics with CI error bars
ax1 = axes[0]
x_pos = np.arange(len(class_list))
width = 0.15

metrics = ['Accuracy', 'Precision', 'Recall', 'Specificity', 'F1_Score']
colors = ['skyblue', 'orange', 'lightgreen', 'salmon', 'gold']

for i, (metric, color) in enumerate(zip(metrics, colors)):
    values = results_df[metric].values
    ci_lower = results_df[f'{metric}_CI_Lower'].values
    ci_upper = results_df[f'{metric}_CI_Upper'].values
    # Error-bar lengths must be non-negative; floating-point round-off at the 0/1
    # boundaries can make a bound cross the point estimate by a tiny amount, so clip at 0.
    errors = np.clip([values - ci_lower, ci_upper - values], 0, None)

    # Offsets centre the five bars of a class on its tick.
    ax1.bar(x_pos + i*width - 2*width, values, width,
            yerr=errors, capsize=3, label=metric,
            alpha=0.8, color=color)

ax1.set_xlabel('Classes')
ax1.set_ylabel('Score')
ax1.set_title('Per-Class Performance Metrics with 95% CI')
ax1.set_xticks(x_pos)
ax1.set_xticklabels(class_list)
ax1.legend()
ax1.set_ylim(0, 1.1)
ax1.grid(axis='y', alpha=0.3)

# Subplot 2: confusion matrix heat map
ax2 = axes[1]
im = ax2.imshow(cm, cmap='Blues', interpolation='nearest')
ax2.set_title('Confusion Matrix')
ax2.set_xlabel('Predicted')
ax2.set_ylabel('Actual')

# Add colorbar
plt.colorbar(im, ax=ax2)

# Write each count into its cell; white text on the darker cells keeps it readable
for i in range(len(class_list)):
    for j in range(len(class_list)):
        ax2.text(j, i, str(cm[i, j]),
                 ha="center", va="center",
                 color="white" if cm[i, j] > cm.max()/2 else "black")

ax2.set_xticks(range(len(class_list)))
ax2.set_yticks(range(len(class_list)))
ax2.set_xticklabels(class_list)
ax2.set_yticklabels(class_list)

plt.tight_layout()
plt.show()

# Summary for paper
quality_labels = ['poor', 'fair', 'good', 'excellent']


def _quality_label(score, floor=0.0):
    """Map ``score`` onto the four quality labels using equal-width bands on [floor, 1].

    ``floor`` is the score of an uninformative model: 0 for F1, 0.5 for AUC. Rescaling
    from the floor stops a chance-level AUC of 0.5 from being reported as 'good'.
    """
    scaled = (score - floor) / (1 - floor)
    return quality_labels[max(0, min(3, int(scaled * 4)))]


# Summary for paper
print("\n📝 Summary for Paper:")
print("="*50)
print(f"The breast cancer classification model achieved an overall accuracy of {accuracy:.4f} on the test set.")
print(f"The macro-averaged F1-score was {macro_f1:.4f}, indicating {_quality_label(macro_f1)} performance across all classes.")
print(f"The macro-averaged AUC was {roc_auc['macro']:.4f}, demonstrating {_quality_label(roc_auc['macro'], floor=0.5)} discriminative ability.")

# Individual class performance: best and worst class by F1-score
best_f1_idx = results_df['F1_Score'].idxmax()
worst_f1_idx = results_df['F1_Score'].idxmin()
print(f"Best performing class: {results_df.loc[best_f1_idx, 'Class']} (F1-Score: {results_df.loc[best_f1_idx, 'F1_Score']:.4f}, AUC: {results_df.loc[best_f1_idx, 'AUC']:.4f})")
print(f"Most challenging class: {results_df.loc[worst_f1_idx, 'Class']} (F1-Score: {results_df.loc[worst_f1_idx, 'F1_Score']:.4f}, AUC: {results_df.loc[worst_f1_idx, 'AUC']:.4f})")

# Export results
results_summary = {
    'Metric': ['Accuracy', 'Macro Precision', 'Macro Recall', 'Macro F1-Score',
               'Weighted Precision', 'Weighted Recall', 'Weighted F1-Score',
               'Macro AUC', 'Micro AUC'],
    'Value': [accuracy, macro_precision, macro_recall, macro_f1,
              weighted_precision, weighted_recall, weighted_f1,
              roc_auc['macro'], roc_auc['micro']]
}

summary_df = pd.DataFrame(results_summary)
# The CSV files are written next to the model checkpoints.
results_df.to_csv(f'{model_path}detailed_per_class_results.csv', index=False)
summary_df.to_csv(f'{model_path}overall_results.csv', index=False)
cm_df.to_csv(f'{model_path}confusion_matrix.csv')
