In [1]:
# %% Cell 1: Enhanced Imports and Setup
import os
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image
from transformers import ViTConfig
from sklearn.model_selection import StratifiedShuffleSplit
from imblearn.over_sampling import RandomOverSampler
import matplotlib.pyplot as plt
from sklearn.metrics import (roc_curve, auc, confusion_matrix, 
                            ConfusionMatrixDisplay, classification_report)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [2]:
# %% Cell 2: Enhanced Dataset Class with Oversampling
class DyslexiaDataset(Dataset):
    def __init__(self, root_dir, sample_size=7000, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        
        # Collect PNG files
        png_files = []
        for foldername, _, filenames in os.walk(root_dir):
            png_files.extend([
                os.path.join(foldername, f) 
                for f in filenames if f.lower().endswith('.png')
            ])
        
        if not png_files:
            raise ValueError(f"No PNG files found in {root_dir}")

        # Random sampling
        sample_size = min(sample_size, len(png_files))
        np.random.seed(42)
        self.samples = np.random.choice(png_files, sample_size, False)
        
        # Create labels
        self.labels = np.array([
            1 if re.search(r'hsf_[7-9]', path) else 0 
            for path in self.samples
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img = Image.open(self.samples[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

In [3]:
def create_transforms():
    from torchvision import transforms as T  # Rename the import
    
    return {
        'train': T.Compose([
            T.RandomRotation(15),
            T.RandomAffine(0, shear=15),
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),
            T.RandomHorizontalFlip(),
            T.ColorJitter(brightness=0.2, contrast=0.2),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

In [4]:
# %% Cell 4: Feature Extractor
class EnhancedFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((14, 14))
        )
        
    def forward(self, x):
        x = self.cnn(x)  # (B, 256, 14, 14)
        return x.flatten(2).transpose(1, 2)  # (B, 196, 256)

In [5]:
# %% Cell 5: Enhanced Transformer Model
class DyslexiaTransformer(nn.Module):
    def __init__(self, num_features=256, num_classes=2):
        super().__init__()
        self.config = ViTConfig(
            hidden_size=512,
            num_hidden_layers=6,
            num_attention_heads=16,
            intermediate_size=1024,
            hidden_dropout_prob=0.2
        )
        
        # Projection layer
        self.proj = nn.Linear(num_features, self.config.hidden_size)
        
        # Positional embeddings
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 197, self.config.hidden_size))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.hidden_size))
        
        # Transformer encoder
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.config.hidden_size,
                nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size,
                dropout=self.config.hidden_dropout_prob,
                batch_first=True
            ),
            num_layers=self.config.num_hidden_layers
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.config.hidden_size),
            nn.Linear(self.config.hidden_size, num_classes)
        )

    def forward(self, x):
        # Project features
        x = self.proj(x)  # (B, 196, 512)
        
        # Add CLS token
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # Add positional embeddings
        x += self.pos_embed
        
        # Transformer encoder
        x = self.encoder(x)
        
        # Classify using CLS token
        return self.classifier(x[:, 0])

In [6]:
# %% Cell 6: Focal Loss Implementation
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.8, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        loss = self.alpha * (1 - pt)**self.gamma * ce_loss
        return loss.mean()

In [7]:
# %% Cell 7: Training Infrastructure
class Trainer:
    def __init__(self, model, extractor, device):
        self.model = model
        self.extractor = extractor
        self.device = device
        
        # Metrics tracking
        self.train_loss = []
        self.val_metrics = {
            'accuracy': [], 'roc_auc': [], 
            'precision': [], 'recall': []}

    def train_epoch(self, loader, optimizer, scheduler):
        self.model.train()
        total_loss = 0
        
        for inputs, labels in loader:
            inputs = inputs.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            
            # Forward pass
            features = self.extractor(inputs)
            outputs = self.model(features)
            
            # Loss calculation
            loss = self.criterion(outputs, labels)
            
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item()
        
        scheduler.step()
        return total_loss / len(loader)

    def evaluate(self, loader):
        self.model.eval()
        all_preds = []
        all_probs = []
        all_labels = []
        
        with torch.no_grad():
            for inputs, labels in loader:
                inputs = inputs.to(self.device)
                labels = labels.cpu().numpy()
                
                features = self.extractor(inputs)
                outputs = self.model(features)
                probs = F.softmax(outputs, dim=1).cpu().numpy()
                
                all_probs.extend(probs[:, 1])
                all_preds.extend(np.argmax(probs, axis=1))
                all_labels.extend(labels)
        
        # Calculate metrics
        accuracy = np.mean(np.array(all_preds) == np.array(all_labels))
        fpr, tpr, _ = roc_curve(all_labels, all_probs)
        roc_auc = auc(fpr, tpr)
        tn, fp, fn, tp = confusion_matrix(all_labels, all_preds).ravel()
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        
        return {
            'accuracy': accuracy,
            'roc_auc': roc_auc,
            'precision': precision,
            'recall': recall
        }

In [8]:
# %% Cell 8: Main Execution (Corrected)
if __name__ == "__main__":
    # Data preparation
    data_transforms = create_transforms()  # Changed variable name
    full_dataset = DyslexiaDataset(
        root_dir='D:\\Dyslexia\\by_class\\by_class',
        sample_size=100,
        transform=None
    )
    
    # Stratified split
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_idx, test_idx = next(sss.split(full_dataset.samples, full_dataset.labels))
    
    # Apply transforms
    class TransformSubset(Dataset):
        def __init__(self, subset, transform):
            self.subset = subset
            self.transform = transform

        def __getitem__(self, idx):
            img, label = self.subset[idx]
            if self.transform:
                img = self.transform(img)
            return img, label

        def __len__(self):
            return len(self.subset)
    
    # Oversampling
    ros = RandomOverSampler(random_state=42)
    resampled_idx, _ = ros.fit_resample(
        train_idx.reshape(-1, 1), full_dataset.labels[train_idx])
    train_idx = resampled_idx.squeeze()
    
    # Create datasets
    train_dataset = TransformSubset(
        Subset(full_dataset, train_idx), data_transforms['train'])  # Use new name
    
    test_dataset = TransformSubset(
        Subset(full_dataset, test_idx), data_transforms['test'])  # Use new name

    # Create data loaders (CRUCIAL FIX)
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Initialize models
    feature_extractor = EnhancedFeatureExtractor().to(device)
    model = DyslexiaTransformer(num_features=256).to(device)
    
    # Multi-GPU support
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        feature_extractor = nn.DataParallel(feature_extractor)
        model = nn.DataParallel(model)

    # Training setup
    optimizer = torch.optim.AdamW([
        {'params': feature_extractor.parameters(), 'lr': 1e-4},
        {'params': model.parameters(), 'lr': 5e-5}
    ], weight_decay=0.01)
    
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-3,
        steps_per_epoch=len(train_loader),
        epochs=20
    )
    
    criterion = FocalLoss(alpha=0.8, gamma=2.0)
    trainer = Trainer(model, feature_extractor, device)
    trainer.criterion = criterion

    # Training loop
    best_roc = 0
    for epoch in range(20):
        # Training
        loss = trainer.train_epoch(train_loader, optimizer, scheduler)
        trainer.train_loss.append(loss)
        
        # Evaluation
        metrics = trainer.evaluate(test_loader)
        for k in trainer.val_metrics:
            trainer.val_metrics[k].append(metrics[k])
        
        # Save best model
        if metrics['roc_auc'] > best_roc:
            best_roc = metrics['roc_auc']
            torch.save({
                'extractor': feature_extractor.state_dict(),
                'model': model.state_dict(),
            }, 'best_model.pth')
        
        # Print progress
        print(f"Epoch {epoch+1}/20")
        print(f"Train Loss: {loss:.4f} | Val Acc: {metrics['accuracy']:.4f}")
        print(f"ROC AUC: {metrics['roc_auc']:.4f} | Precision: {metrics['precision']:.4f} | Recall: {metrics['recall']:.4f}")
        print("-----------------------------------")

    # Final evaluation
    print("\nBest Model Performance:")
    model.load_state_dict(torch.load('best_model.pth')['model'])
    final_metrics = trainer.evaluate(test_loader)
    print(classification_report(
        final_metrics['true_labels'],
        final_metrics['predictions'],
        target_names=['Non-Dyslexic', 'Dyslexic']
    ))

    # Plot metrics
    plt.figure(figsize=(15,5))
    plt.subplot(1,2,1)
    plt.plot(trainer.train_loss, label='Train Loss')
    plt.title('Training Loss')
    plt.legend()
    
    plt.subplot(1,2,2)
    plt.plot(trainer.val_metrics['accuracy'], label='Accuracy')
    plt.plot(trainer.val_metrics['roc_auc'], label='ROC AUC')
    plt.title('Validation Metrics')
    plt.legend()
    plt.show()

RuntimeError: DataLoader worker (pid(s) 20184, 20628, 1532, 24548) exited unexpectedly