In [1]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, accuracy_score, precision_score, recall_score, f1_score

# Create output directories
def create_output_dirs():
    output_dirs = {
        'base': 'output',
        'models': 'output/models',
        'metrics': 'output/metrics',
        'plots': 'output/plots'
    }
    
    for dir_path in output_dirs.values():
        os.makedirs(dir_path, exist_ok=True)
    
    return output_dirs

# Metrics tracking class
class MetricsTracker:
    def __init__(self, output_dirs):
        self.output_dirs = output_dirs
        self.history = {
            'epoch': [],
            'train_loss': [], 'val_loss': [],
            'train_auc': [], 'val_auc': [],
            'train_acc': [], 'val_acc': [],
            'train_precision': [], 'val_precision': [],
            'train_recall': [], 'val_recall': [],
            'train_f1': [], 'val_f1': []
        }
        
    def update(self, epoch, metrics):
        self.history['epoch'].append(epoch)
        for key, value in metrics.items():
            self.history[key].append(value)
    
    def save_metrics(self):
        # Save as CSV
        df = pd.DataFrame(self.history)
        df.to_csv(f"{self.output_dirs['metrics']}/training_history.csv", index=False)
        
        # Save as JSON
        with open(f"{self.output_dirs['metrics']}/training_history.json", 'w') as f:
            json.dump(self.history, f, indent=4)
    
    def plot_metrics(self):
        # Plot loss curves
        self._plot_metric('loss', 'Model Loss')
        
        # Plot AUC curves
        self._plot_metric('auc', 'Model AUC')
        
        # Plot accuracy curves
        self._plot_metric('acc', 'Model Accuracy')
        
        # Plot precision-recall curves
        self._plot_metric('precision', 'Model Precision')
        self._plot_metric('recall', 'Model Recall')
        
        plt.close('all')
    
    def _plot_metric(self, metric_name, title):
        plt.figure(figsize=(10, 6))
        plt.plot(self.history['epoch'], self.history[f'train_{metric_name}'], label=f'Train {metric_name}')
        plt.plot(self.history['epoch'], self.history[f'val_{metric_name}'], label=f'Val {metric_name}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(metric_name.capitalize())
        plt.legend()
        plt.grid(True)
        plt.savefig(f"{self.output_dirs['plots']}/{metric_name}_curves.png")
        plt.close()
    
    def plot_confusion_matrix(self, y_true, y_pred, phase='val'):
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title(f'Confusion Matrix - {phase.capitalize()} Set')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig(f"{self.output_dirs['plots']}/confusion_matrix_{phase}.png")
        plt.close()
    
    def plot_roc_curve(self, fpr, tpr, auc, phase='val'):
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, label=f'ROC curve (AUC = {auc:.3f})')
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'ROC Curve - {phase.capitalize()} Set')
        plt.legend(loc="lower right")
        plt.savefig(f"{self.output_dirs['plots']}/roc_curve_{phase}.png")
        plt.close()

In [None]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, precision_recall_curve, accuracy_score, precision_score, recall_score, f1_score
import glob
from tqdm import tqdm

def create_output_dirs():
    output_dirs = {
        'base': 'output',
        'models': 'output/models',
        'metrics': 'output/metrics',
        'plots': 'output/plots'
    }
    for dir_path in output_dirs.values():
        os.makedirs(dir_path, exist_ok=True)
    return output_dirs

class MetricsTracker:
    def __init__(self, output_dirs):
        self.output_dirs = output_dirs
        self.history = {
            'epoch': [],
            'train_loss': [], 'val_loss': [],
            'train_auc': [], 'val_auc': [],
            'train_acc': [], 'val_acc': [],
            'train_precision': [], 'val_precision': [],
            'train_recall': [], 'val_recall': [],
            'train_f1': [], 'val_f1': []
        }
    
    def update(self, epoch, metrics):
        self.history['epoch'].append(epoch)
        for key, value in metrics.items():
            self.history[key].append(value)
    
    def save_metrics(self):
        df = pd.DataFrame(self.history)
        df.to_csv(f"{self.output_dirs['metrics']}/training_history.csv", index=False)
        with open(f"{self.output_dirs['metrics']}/training_history.json", 'w') as f:
            json.dump(self.history, f, indent=4)
    
    def plot_metrics(self):
        metrics = ['loss', 'auc', 'acc', 'precision', 'recall']
        for metric in metrics:
            self._plot_metric(metric, f'Model {metric.upper()}')
    
    def _plot_metric(self, metric_name, title):
        plt.figure(figsize=(10, 6))
        plt.plot(self.history['epoch'], self.history[f'train_{metric_name}'], label=f'Train {metric_name}')
        plt.plot(self.history['epoch'], self.history[f'val_{metric_name}'], label=f'Val {metric_name}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(metric_name.capitalize())
        plt.legend()
        plt.grid(True)
        plt.savefig(f"{self.output_dirs['plots']}/{metric_name}_curves.png")
        plt.close()
    
    def plot_confusion_matrix(self, y_true, y_pred, phase='val'):
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title(f'Confusion Matrix - {phase.capitalize()} Set')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig(f"{self.output_dirs['plots']}/confusion_matrix_{phase}.png")
        plt.close()
    
    def plot_roc_curve(self, fpr, tpr, auc, phase='val'):
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, label=f'ROC curve (AUC = {auc:.3f})')
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'ROC Curve - {phase.capitalize()} Set')
        plt.legend(loc="lower right")
        plt.savefig(f"{self.output_dirs['plots']}/roc_curve_{phase}.png")
        plt.close()

class BreastHistopathologyDataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.split = split
        
        self.image_paths = []
        self.labels = []
        
        for patient_dir in glob.glob(os.path.join(root_dir, "*/")):
            for label in [0, 1]:
                class_dir = os.path.join(patient_dir, str(label))
                if not os.path.exists(class_dir):
                    continue
                paths = glob.glob(os.path.join(class_dir, "*.png"))
                self.image_paths.extend(paths)
                self.labels.extend([label] * len(paths))
        
        X_train, X_temp, y_train, y_temp = train_test_split(
            self.image_paths, self.labels, 
            test_size=0.3, 
            random_state=42
        )
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp, 
            test_size=0.5, 
            random_state=42
        )
        
        if split == 'train':
            self.image_paths, self.labels = X_train, y_train
        elif split == 'val':
            self.image_paths, self.labels = X_val, y_val
        else:
            self.image_paths, self.labels = X_test, y_test

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

class ViTForHistopathology(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        self.vit.heads = nn.Linear(self.vit.hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.vit(x)
        x = self.dropout(x)
        return x

def train_model(model, train_loader, val_loader, num_epochs=10, device='cuda'):
    output_dirs = create_output_dirs()
    metrics_tracker = MetricsTracker(output_dirs)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    
    best_val_auc = 0
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_preds = []
        train_labels = []
        
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            images = images.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            probs = torch.softmax(outputs, dim=1)[:, 1].detach().cpu().numpy()
            preds = (probs > 0.5).astype(int)
            train_preds.extend(preds)
            train_labels.extend(labels.cpu().numpy())
        
        train_metrics = {
            'train_loss': train_loss / len(train_loader),
            'train_auc': roc_auc_score(train_labels, train_preds),
            'train_acc': accuracy_score(train_labels, train_preds),
            'train_precision': precision_score(train_labels, train_preds),
            'train_recall': recall_score(train_labels, train_preds),
            'train_f1': f1_score(train_labels, train_preds)
        }
        
        model.eval()
        val_loss = 0
        val_preds = []
        val_labels = []
        
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                probs = torch.softmax(outputs, dim=1)[:, 1].detach().cpu().numpy()
                preds = (probs > 0.5).astype(int)
                val_preds.extend(preds)
                val_labels.extend(labels.cpu().numpy())
        
        val_metrics = {
            'val_loss': val_loss / len(val_loader),
            'val_auc': roc_auc_score(val_labels, val_preds),
            'val_acc': accuracy_score(val_labels, val_preds),
            'val_precision': precision_score(val_labels, val_preds),
            'val_recall': recall_score(val_labels, val_preds),
            'val_f1': f1_score(val_labels, val_preds)
        }
        
        metrics = {**train_metrics, **val_metrics}
        metrics_tracker.update(epoch, metrics)
        
        if (epoch + 1) % 5 == 0:
            fpr, tpr, _ = roc_curve(val_labels, val_preds)
            metrics_tracker.plot_roc_curve(fpr, tpr, val_metrics['val_auc'])
            metrics_tracker.plot_confusion_matrix(val_labels, val_preds)
        
        if val_metrics['val_auc'] > best_val_auc:
            best_val_auc = val_metrics['val_auc']
            torch.save(model.state_dict(), f"{output_dirs['models']}/ViT_best.pt")
        
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f"Train - Loss: {metrics['train_loss']:.4f}, AUC: {metrics['train_auc']:.4f}, "
              f"Acc: {metrics['train_acc']:.4f}, F1: {metrics['train_f1']:.4f}")
        print(f"Val - Loss: {metrics['val_loss']:.4f}, AUC: {metrics['val_auc']:.4f}, "
              f"Acc: {metrics['val_acc']:.4f}, F1: {metrics['val_f1']:.4f}")
    
    metrics_tracker.save_metrics()
    metrics_tracker.plot_metrics()
    
    return model, metrics_tracker

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    data_root = "data"
    train_dataset = BreastHistopathologyDataset(data_root, 'train', transform)
    val_dataset = BreastHistopathologyDataset(data_root, 'val', transform)
    
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
    
    model = ViTForHistopathology()
    model = model.to(device)
    
    model, metrics_tracker = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=15,
        device=device
    )

if __name__ == "__main__":
    main()

Using device: cuda


Epoch 1/15: 100%|██████████| 6071/6071 [17:11<00:00,  5.88it/s]


Epoch 1/15:
Train - Loss: 0.2947, AUC: 0.8407, Acc: 0.8787, F1: 0.7791
Val - Loss: 0.2561, AUC: 0.8800, Acc: 0.8908, F1: 0.8174


Epoch 2/15: 100%|██████████| 6071/6071 [17:21<00:00,  5.83it/s]


Epoch 2/15:
Train - Loss: 0.2521, AUC: 0.8679, Acc: 0.8983, F1: 0.8167
Val - Loss: 0.2363, AUC: 0.8724, Acc: 0.9029, F1: 0.8251


Epoch 3/15: 100%|██████████| 6071/6071 [17:25<00:00,  5.81it/s]


Epoch 3/15:
Train - Loss: 0.2331, AUC: 0.8785, Acc: 0.9068, F1: 0.8322
Val - Loss: 0.2301, AUC: 0.8723, Acc: 0.9052, F1: 0.8275


Epoch 4/15: 100%|██████████| 6071/6071 [17:24<00:00,  5.81it/s]


Epoch 4/15:
Train - Loss: 0.2179, AUC: 0.8886, Acc: 0.9140, F1: 0.8458
Val - Loss: 0.2257, AUC: 0.8742, Acc: 0.9080, F1: 0.8317


Epoch 5/15: 100%|██████████| 6071/6071 [17:24<00:00,  5.81it/s]


Epoch 5/15:
Train - Loss: 0.2045, AUC: 0.8951, Acc: 0.9190, F1: 0.8549
Val - Loss: 0.2339, AUC: 0.8671, Acc: 0.9059, F1: 0.8251


Epoch 6/15: 100%|██████████| 6071/6071 [17:22<00:00,  5.82it/s]


Epoch 6/15:
Train - Loss: 0.1899, AUC: 0.9021, Acc: 0.9246, F1: 0.8650
Val - Loss: 0.2120, AUC: 0.8932, Acc: 0.9132, F1: 0.8480


Epoch 7/15: 100%|██████████| 6071/6071 [17:23<00:00,  5.82it/s]


Epoch 7/15:
Train - Loss: 0.1746, AUC: 0.9121, Acc: 0.9319, F1: 0.8784
Val - Loss: 0.2246, AUC: 0.8791, Acc: 0.9110, F1: 0.8380


Epoch 8/15:  35%|███▌      | 2150/6071 [06:09<11:13,  5.82it/s]


In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                       std=[0.229, 0.224, 0.225])
])

# Set data root
data_root = "data"  # Make sure this points to your data directory

# Create test dataset and loader
test_dataset = BreastHistopathologyDataset(data_root, 'test', transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=16)

# Initialize model
model = ViTForHistopathology(num_classes=2, pretrained=True)
model = model.to(device)

# Load the best model
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

# Test evaluation
test_loss = 0
test_preds = []
test_labels = []
criterion = nn.CrossEntropyLoss()

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        test_loss += loss.item()
        probs = torch.softmax(outputs, dim=1)[:, 1].detach().cpu().numpy()
        test_preds.extend(probs)
        test_labels.extend(labels.cpu().numpy())

test_auc = roc_auc_score(test_labels, test_preds)
print(f'\nTest Results:')
print(f'Test Loss: {test_loss/len(test_loader):.4f}, Test AUC: {test_auc:.4f}')