In [17]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import timm
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler

In [2]:
class BrainExpert(nn.Module):
    """Brain-MRI classifier: timm EfficientNet-B2 feature extractor + MLP head.

    Args:
        num_classes: number of logits produced by the final Linear layer.

    ``forward`` returns raw (un-softmaxed) logits.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # num_classes=0 strips timm's classifier so the backbone returns pooled features.
        self.backbone = timm.create_model('efficientnet_b2', pretrained=True, num_classes=0, global_pool='avg')
        # Take the feature width from the backbone (1408 for efficientnet_b2)
        # instead of hard-coding it. Swapping the backbone then cannot silently
        # break the first Linear layer's input size.
        in_features = self.backbone.num_features
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.Hardswish(),
            nn.Dropout(p=0.4),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Map a batch of images to class logits."""
        features = self.backbone(x)
        return self.fc(features)

In [3]:
class LungExpert(nn.Module):
    """Lung-disease classifier: ImageNet DenseNet-121 features + MLP head.

    Args:
        num_classes: width of the final logit layer.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        densenet = models.densenet121(weights='IMAGENET1K_V1')
        feature_dim = densenet.classifier.in_features
        # Drop torchvision's ImageNet classifier; the network now emits pooled features.
        densenet.classifier = nn.Identity()
        self.backbone = densenet
        self.head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        return self.head(self.backbone(x))

In [4]:
class SkinExpert(nn.Module):
    """Skin-disease classifier: ImageNet ResNet-50 features + MLP head.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=9):
        super().__init__()
        resnet = models.resnet50(weights='IMAGENET1K_V1')
        width = resnet.fc.in_features
        resnet.fc = nn.Identity()  # expose pooled features instead of ImageNet logits
        self.backbone = resnet
        self.head = nn.Sequential(
            nn.Linear(width, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=0.45),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Map a batch of images to class logits."""
        pooled = self.backbone(x)
        return self.head(pooled)

In [5]:
class ECGExpert(nn.Module):
    """ECG-image classifier: timm EfficientNet-B0 feature extractor + MLP head.

    Args:
        num_classes: number of logits produced by the final Linear layer.

    ``forward`` returns raw (un-softmaxed) logits.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # num_classes=0 removes timm's classifier; the backbone then returns pooled features.
        self.backbone = timm.create_model('efficientnet_b0', pretrained=True, num_classes=0)
        # Read the feature width from the backbone (1280 for efficientnet_b0)
        # rather than hard-coding it, so the head always matches the backbone.
        self.head = nn.Sequential(
            nn.Linear(self.backbone.num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Map a batch of images to class logits."""
        features = self.backbone(x)
        return self.head(features)

In [6]:
class ModalityRouter(nn.Module):
    """ImageNet ResNet-34 that predicts which modality expert should handle an image.

    Args:
        num_classes: number of modalities (output logits). The default of 4
            matches the brain/lung/skin/ecg experts used in this notebook.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        self.model = models.resnet34(weights='IMAGENET1K_V1')
        # Replace the 1000-way ImageNet head with a num_classes-way modality head.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return modality logits, one column per modality."""
        return self.model(x)

In [25]:
class ClinicalAIDiagnosticSystem(nn.Module):
    """Two-stage model: a modality router followed by one expert per modality.

    ``perform_inference`` routes a single image to one expert and estimates that
    expert's predictive uncertainty with Monte-Carlo (MC) dropout.

    Args:
        class_counts: mapping with keys ``'brain'``, ``'lung'``, ``'skin'`` and
            ``'ecg'`` giving the number of diagnostic classes of each expert.
    """

    def __init__(self, class_counts):
        super().__init__()
        self.router = ModalityRouter()
        self.experts = nn.ModuleDict({
            'brain': BrainExpert(class_counts['brain']),
            'lung': LungExpert(class_counts['lung']),
            'skin': SkinExpert(class_counts['skin']),
            'ecg': ECGExpert(class_counts['ecg'])
        })
        # Router output index -> expert key. Must match the label order the
        # router was trained with.
        self.labels = {0: 'brain', 1: 'lung', 2: 'skin', 3: 'ecg'}

    @staticmethod
    def _snapshot_modes(module):
        """Record the ``training`` flag of ``module`` and every submodule."""
        return [(m, m.training) for m in module.modules()]

    @staticmethod
    def _restore_modes(snapshot):
        """Reapply the per-module ``training`` flags captured by ``_snapshot_modes``."""
        for m, was_training in snapshot:
            m.training = was_training

    def perform_inference(self, x: torch.Tensor, samples: int = 25, threshold: float = 0.15) -> dict:
        """Route ``x`` to an expert and return an MC-dropout prediction.

        The train/eval mode of the router and the chosen expert is restored on
        exit. Calling this does not leave dropout enabled for later direct
        forward passes.

        Args:
            x: a batch holding exactly one preprocessed image (``x.shape[0] == 1``).
            samples: number of stochastic forward passes. Must be >= 2 so that a
                standard deviation across passes is defined.
            threshold: reject when the mean per-class standard deviation of the
                sampled softmax outputs exceeds this value.

        Returns:
            On rejection: ``{"status": "REJECTED", "reason", "score", "modality"}``.
            Otherwise: ``{"status": "ACCEPTED", "modality", "diagnosis"
            (class index), "confidence", "uncertainty"}``.

        Raises:
            ValueError: if ``samples < 2`` or ``x`` does not contain exactly one image.
        """
        if samples < 2:
            # With a single sample std() is NaN, and NaN > threshold is False,
            # which would silently ACCEPT with an undefined uncertainty.
            raise ValueError(f"samples must be >= 2 to estimate uncertainty, got {samples}")
        if x.shape[0] != 1:
            raise ValueError(f"perform_inference expects exactly one image, got a batch of {x.shape[0]}")

        # 1. Router inference (deterministic).
        router_modes = self._snapshot_modes(self.router)
        try:
            self.router.eval()
            with torch.no_grad():
                modality_idx = torch.argmax(self.router(x), dim=1).item()
        finally:
            self._restore_modes(router_modes)
        modality = self.labels[modality_idx]

        # 2. MC-dropout inference with the selected expert.
        expert = self.experts[modality]
        expert_modes = self._snapshot_modes(expert)
        try:
            # eval() freezes BatchNorm (it uses running stats, which also makes batch size 1 valid) ...
            expert.eval()
            # ... then re-enable only the dropout layers so repeated passes differ.
            for m in expert.modules():
                if isinstance(m, nn.Dropout):
                    m.train()
            with torch.no_grad():
                preds_tensor = torch.stack(
                    [F.softmax(expert(x), dim=1) for _ in range(samples)]
                )
        finally:
            self._restore_modes(expert_modes)

        mean_prediction = preds_tensor.mean(dim=0)
        uncertainty = preds_tensor.std(dim=0).mean().item()
        conf, class_idx = torch.max(mean_prediction, dim=1)

        if uncertainty > threshold:
            return {"status": "REJECTED", "reason": "High Uncertainty", "score": uncertainty, "modality": modality}

        return {
            "status": "ACCEPTED",
            "modality": modality,
            "diagnosis": class_idx.item(),
            "confidence": conf.item(),
            "uncertainty": uncertainty
        }

In [8]:
def get_train_params(model, lr=1e-4, weight_decay=0.05, label_smoothing=0.1):
    """Build the optimizer and loss used for router and expert training.

    Args:
        model: module whose parameters are optimized.
        lr: AdamW base learning rate. ``train_expert_full`` wraps the optimizer
            in ``OneCycleLR``, which then drives the learning rate itself.
        weight_decay: AdamW decoupled weight decay.
        label_smoothing: label smoothing for the cross-entropy loss.

    Returns:
        ``(optimizer, criterion)``.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    return optimizer, criterion

In [9]:
def _train_one_epoch(model, loader, optimizer, scheduler, criterion, device):
    """Run one optimization pass over ``loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # OneCycleLR is a per-batch schedule.
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate_accuracy(model, loader, device):
    """Return top-1 accuracy of ``model`` over every sample in ``loader.dataset``."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            correct += (model(images).argmax(1) == labels).sum().item()
    return correct / len(loader.dataset)


def train_expert_full(model, train_loader, val_loader, device, epochs=30,
                      checkpoint_path=None, restore_best=False):
    """Train ``model`` with AdamW + OneCycleLR, checkpointing the best validation epoch.

    Args:
        model: classifier; its outputs are assumed to be logits of shape
            (batch, num_classes) — TODO confirm for new models.
        train_loader: yields ``(images, labels)``; must be non-empty.
        val_loader: yields ``(images, labels)``; ``val_loader.dataset`` must be non-empty.
        device: device to train on; ``model`` is moved there.
        epochs: number of passes over ``train_loader`` (must be >= 1).
        checkpoint_path: where the best ``state_dict`` is saved. Defaults to
            ``best_<ModelClassName>.pth`` in the working directory.
        restore_best: if True, reload the best checkpoint into ``model`` before
            returning. Otherwise ``model`` keeps its final-epoch weights.

    Returns:
        The best validation accuracy seen.

    Raises:
        ValueError: if ``train_loader`` or ``val_loader.dataset`` is empty.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty")
    if len(val_loader.dataset) == 0:
        raise ValueError("val_loader.dataset is empty")
    if checkpoint_path is None:
        checkpoint_path = f'best_{type(model).__name__}.pth'

    model.to(device)
    optimizer, criterion = get_train_params(model)
    # OneCycleLR overrides the optimizer's base lr and is stepped once per batch.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, steps_per_epoch=len(train_loader), epochs=epochs
    )

    # -inf guarantees the first epoch is checkpointed even at 0% accuracy.
    best_acc = float('-inf')
    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, scheduler, criterion, device)
        val_acc = _evaluate_accuracy(model, val_loader, device)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {train_loss:.4f} | Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

    if restore_best:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return best_acc

In [10]:
def get_dataloaders(root_path, batch_size=32):
    """Build a class-balanced training loader and a plain validation loader.

    Expects ``<root_path>/train`` and ``<root_path>/val`` to be ImageFolder
    directories. Training samples are drawn by a ``WeightedRandomSampler`` whose
    per-sample weight is ``1 / (count of that sample's class)``, so each class is
    drawn at roughly the same rate.

    Returns:
        ``(train_loader, val_loader)``.
    """
    imagenet_norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    resize = transforms.Resize((224, 224))

    augment = transforms.Compose([
        resize,
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        imagenet_norm,
    ])
    plain = transforms.Compose([resize, transforms.ToTensor(), imagenet_norm])

    train_set = datasets.ImageFolder(root=f"{root_path}/train", transform=augment)
    val_set = datasets.ImageFolder(root=f"{root_path}/val", transform=plain)

    # Inverse-frequency weight per class, then broadcast to each sample by its label.
    labels = np.asarray(train_set.targets)
    _, per_class = np.unique(labels, return_counts=True)
    inverse_freq = 1. / per_class
    sample_weights = torch.from_numpy(inverse_freq[labels])
    balanced_sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

    return (
        DataLoader(train_set, batch_size=batch_size, sampler=balanced_sampler, num_workers=4),
        DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4),
    )


In [12]:
class MedicalDataOrchestrator:
    """Builds DataLoaders for the modality router and for each modality expert.

    Every split is an ImageFolder directory located at ``base_path`` joined with
    the relative paths in ``self.paths``. ECG has only a ``train`` folder, so its
    validation set is carved out of it (80/20). The router's validation set is a
    90/10 hold-out of the pooled training folders. Hold-out subsets always use
    ``val_transform``, never the augmenting ``train_transform``.

    Args:
        base_path: root directory the relative dataset paths are joined onto.
        seed: seed for the random hold-out splits. An integer makes the splits
            reproducible; ``None`` draws a fresh random split on every call.
    """

    # Router label order: images from ROUTER_MODALITIES[i] get target i.
    # Must match ClinicalAIDiagnosticSystem.labels.
    ROUTER_MODALITIES = ('brain', 'lung', 'skin', 'ecg')

    def __init__(self, base_path, seed=42):
        self.base_path = base_path
        self.seed = seed
        self.paths = {
            'brain': {
                'train': 'masoudnickparvar/brain-tumor-mri-dataset/Training',
                'val': 'masoudnickparvar/brain-tumor-mri-dataset/Testing'
            },
            'lung': {
                'train': 'omkarmanohardalvi/lungs-disease-dataset-4-types/Lung Disease Dataset/train',
                'val': 'omkarmanohardalvi/lungs-disease-dataset-4-types/Lung Disease Dataset/val',
                'test': 'omkarmanohardalvi/lungs-disease-dataset-4-types/Lung Disease Dataset/test'
            },
            'skin': {
                'train': 'riyaelizashaju/skin-disease-classification-image-dataset/Split_smol/train',
                'val': 'riyaelizashaju/skin-disease-classification-image-dataset/Split_smol/val'
            },
            'ecg': {
                'train': 'evilspirit05/ecg-analysis/ecg_data_new_version/ecg data new version'
            },
        }
        self.norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self.train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            self.norm
        ])
        self.val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            self.norm
        ])

    def _get_full_path(self, modality, split):
        """Return the filesystem path of ``split`` (e.g. 'train'/'val') for ``modality``."""
        return os.path.join(self.base_path, self.paths[modality][split])

    def _split_indices(self, n, train_fraction):
        """Randomly partition ``range(n)`` into ``(train_indices, val_indices)``.

        The training part has ``int(train_fraction * n)`` indices.
        """
        generator = torch.Generator()
        if self.seed is None:
            generator.seed()  # non-deterministic seed
        else:
            generator.manual_seed(self.seed)
        order = torch.randperm(n, generator=generator).tolist()
        cut = int(train_fraction * n)
        return order[:cut], order[cut:]

    def get_expert_loader(self, modality, batch_size=32):
        """Return ``(train_loader, val_loader)`` for one modality expert.

        For ECG, two ImageFolder views of the same directory are built, one per
        transform, and disjoint index subsets are taken from them. Assigning
        ``val_ds.dataset.transform`` on a ``random_split`` subset would instead
        change the transform of the single dataset both subsets share, which
        silently disables training augmentation.
        """
        train_path = self._get_full_path(modality, 'train')
        if modality == 'ecg':
            train_view = datasets.ImageFolder(train_path, transform=self.train_transform)
            val_view = datasets.ImageFolder(train_path, transform=self.val_transform)
            train_idx, val_idx = self._split_indices(len(train_view), 0.8)
            train_ds = torch.utils.data.Subset(train_view, train_idx)
            val_ds = torch.utils.data.Subset(val_view, val_idx)
            classes = train_view.classes
        else:
            val_path = self._get_full_path(modality, 'val')
            train_ds = datasets.ImageFolder(train_path, transform=self.train_transform)
            val_ds = datasets.ImageFolder(val_path, transform=self.val_transform)
            classes = train_ds.classes

        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)

        print(f"Expert [{modality.upper()}] | Classes: {len(classes)}")
        return train_loader, val_loader

    def _router_folder(self, modality, label, transform):
        """ImageFolder over ``modality``'s train split with every target replaced by ``label``."""
        ds = datasets.ImageFolder(self._get_full_path(modality, 'train'), transform=transform)
        # Relabel both the (path, target) pairs in ``samples`` and ``targets``.
        ds.samples = [(path, label) for path, _ in ds.samples]
        ds.targets = [label] * len(ds)
        return ds

    def get_router_loaders(self, batch_size=64):
        """Return ``(train_loader, val_loader)`` for the modality router.

        The training folders of all modalities are pooled, labelled by their
        index in ``ROUTER_MODALITIES``, and split 90/10. The training subset reads
        through ``train_transform`` and the validation subset through
        ``val_transform``.
        """
        train_parts, val_parts = [], []
        for label, modality in enumerate(self.ROUTER_MODALITIES):
            train_parts.append(self._router_folder(modality, label, self.train_transform))
            val_parts.append(self._router_folder(modality, label, self.val_transform))
        pooled_train = ConcatDataset(train_parts)
        pooled_val = ConcatDataset(val_parts)

        train_idx, val_idx = self._split_indices(len(pooled_train), 0.9)
        router_train = torch.utils.data.Subset(pooled_train, train_idx)
        router_val = torch.utils.data.Subset(pooled_val, val_idx)

        train_loader = DataLoader(router_train, batch_size=batch_size, shuffle=True, num_workers=4)
        val_loader = DataLoader(router_val, batch_size=batch_size, shuffle=False, num_workers=4)

        print(f"Router Dataset | Total Images: {len(pooled_train)}")
        return train_loader, val_loader

In [14]:
import os

In [19]:
if __name__ == "__main__":
    import traceback

    DATA_ROOT = "/kaggle/input/datasets"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    configs = {'brain': 4, 'lung': 5, 'skin': 9, 'ecg': 4}
    # Seed torch's global RNG (head initialisation, augmentation, shuffling)
    # so runs are more repeatable.
    torch.manual_seed(42)

    print("🚀 Initializing Medical AI Orchestrator...")
    orchestrator = MedicalDataOrchestrator(DATA_ROOT)
    system = ClinicalAIDiagnosticSystem(configs).to(device)

    print("\n--- Phase 1: Training Modality Router ---")
    r_train, r_val = orchestrator.get_router_loaders(batch_size=64)
    train_expert_full(system.router, r_train, r_val, device, epochs=5)

    failed = {}
    for modality in ['brain', 'lung', 'skin', 'ecg']:
        print(f"\n--- Phase 2: Training {modality.upper()} Expert ---")
        try:
            t_loader, v_loader = orchestrator.get_expert_loader(modality, batch_size=32)
            expert_model = system.experts[modality]
            train_expert_full(expert_model, t_loader, v_loader, device, epochs=30)
        except Exception as e:
            # Keep going so one broken dataset does not cost the other experts,
            # but keep the traceback and remember which expert is untrained.
            print(f"Error training {modality}: {e}")
            traceback.print_exc()
            failed[modality] = e

    if failed:
        print(f"\nWARNING: saving system with untrained experts: {', '.join(failed)}")

    print("\n📦 Saving full clinical system...")
    save_path = "medical_ai_system_final.pth"
    torch.save(system.state_dict(), save_path)
    print(f"Successfully saved to {save_path}")

🚀 Initializing Medical AI Orchestrator...

--- Phase 1: Training Modality Router ---
Router Dataset | Total Images: 13391
Epoch 1/5 | Loss: 0.3787 | Acc: 0.9836
Epoch 2/5 | Loss: 0.3613 | Acc: 0.9993
Epoch 3/5 | Loss: 0.3591 | Acc: 1.0000
Epoch 4/5 | Loss: 0.3518 | Acc: 1.0000
Epoch 5/5 | Loss: 0.3501 | Acc: 1.0000

--- Phase 2: Training BRAIN Expert ---
Expert [BRAIN] | Classes: 4
Epoch 1/30 | Loss: 0.6576 | Acc: 0.9252
Epoch 2/30 | Loss: 0.4614 | Acc: 0.9794
Epoch 3/30 | Loss: 0.4282 | Acc: 0.9825
Epoch 4/30 | Loss: 0.4232 | Acc: 0.9886
Epoch 5/30 | Loss: 0.4108 | Acc: 0.9680
Epoch 6/30 | Loss: 0.4156 | Acc: 0.9687
Epoch 7/30 | Loss: 0.4000 | Acc: 0.9748
Epoch 8/30 | Loss: 0.4115 | Acc: 0.9649
Epoch 11/30 | Loss: 0.3921 | Acc: 0.9924
Epoch 12/30 | Loss: 0.3821 | Acc: 0.9916
Epoch 13/30 | Loss: 0.3784 | Acc: 0.9878
Epoch 14/30 | Loss: 0.3756 | Acc: 0.9931
Epoch 15/30 | Loss: 0.3720 | Acc: 0.9870
Epoch 16/30 | Loss: 0.3719 | Acc: 0.9977
Epoch 17/30 | Loss: 0.3652 | Acc: 0.9939
Epoch

In [20]:
from PIL import Image

In [23]:
def predict_external_image(image_path, system, device):
    """Diagnose one image file with the routed MC-dropout system and print a summary.

    Args:
        image_path: path to an image PIL can open; it is converted to RGB.
        system: object exposing ``perform_inference(x)`` returning the result
            dict of ``ClinicalAIDiagnosticSystem``.
        device: device the input tensor is moved to; should match ``system``.

    Returns:
        The result dict from ``system.perform_inference``. When the status is
        ``"ACCEPTED"``, a human-readable ``"label"`` key is added.
    """
    # Index -> name per expert. These must follow the class order used in
    # training, which ImageFolder derives from the sorted class-folder names.
    # NOTE(review): the 'ecg' list (and possibly 'skin') is not in alphabetical
    # order. Confirm it against the dataset's folder names before trusting labels.
    class_map = {
        'brain': ['glioma', 'meningioma', 'notumor', 'pituitary'],
        'lung': ['Bacterial Pneumonia', 'Corona Virus Disease', 'Normal', 'Tuberculosis', 'Viral Pneumonia'],
        'ecg': ['Abnormal', 'Infarction', 'Normal', 'History of MI'],
        'skin': ['Actinic keratosis', 'Atopic Dermatitis', 'Benign keratosis', 'Dermatofibroma',
                 'Melanocytic nevus', 'Melanoma', 'Squamous cell carcinoma', 'Tinea Ringworm', 'Vascular lesion']
    }

    # Same preprocessing as the validation transform used in training.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Close the file handle promptly; convert() returns an independent in-memory image.
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')
    x = transform(img).unsqueeze(0).to(device)

    res = system.perform_inference(x)

    print(f"\nMODALITY: {res['modality'].upper()}")
    if res['status'] == "ACCEPTED":
        diag_label = class_map[res['modality']][res['diagnosis']]
        res['label'] = diag_label
        print(f"DIAGNOSIS: {diag_label}")
        print(f"CONFIDENCE: {res['confidence']*100:.2f}%")
    else:
        print(f"STATUS: {res['status']} ({res['reason']})")
    return res

In [26]:
# Rebuild the system and load the weights saved by the training cell
# (``configs`` and ``device`` are defined there). map_location lets a
# checkpoint saved on GPU load on a CPU-only runtime.
system = ClinicalAIDiagnosticSystem(configs).to(device)
system.load_state_dict(torch.load("medical_ai_system_final.pth", map_location=device))
result = predict_external_image("/kaggle/input/datasets/evilspirit05/ecg-analysis/ecg_data_new_version/ecg data new version/abnormal_heartbeat_ecg_images/HB(10).jpg", system, device)


MODALITY: ECG
DIAGNOSIS: Abnormal
CONFIDENCE: 89.81%
