In [1]:
import sys
import os

# Ask for unbuffered output and push out anything already buffered.
# NOTE(review): PYTHONUNBUFFERED is presumably read by the interpreter at
# startup, so setting it here likely only affects child processes - confirm.
os.environ['PYTHONUNBUFFERED'] = '1'
sys.stdout.flush()

In [1]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, precision_score, recall_score, f1_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Ask for unbuffered output and push out anything already buffered.
# NOTE(review): PYTHONUNBUFFERED is presumably read by the interpreter at
# startup, so setting it here likely only affects child processes - confirm.
os.environ['PYTHONUNBUFFERED'] = '1'
sys.stdout.flush()

# Startup banner with the interpreter and PyTorch versions in use
print("=== STARTING DEBUG VERSION ===")
print("Python version:", sys.version)
print("PyTorch version:", torch.__version__)

# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Dataset class for loading HHT plot images
class EEGDataset(Dataset):
    """Map-style dataset pairing image files on disk with their labels.

    Args:
        image_paths: Indexable collection of image file paths.
        labels: Indexable collection of labels, aligned with ``image_paths``.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        print(f"Dataset created with {len(image_paths)} images")

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB, apply the transform, return ``(image, label)``.

        Raises:
            IndexError: (or whatever the underlying containers raise) when
                ``idx`` is not a valid index.
            Exception: any error raised while opening or transforming the
                image is logged with its path and re-raised unchanged.
        """
        # Resolve path and label before the try block: if the lookup itself
        # fails, the handler below must not reference an unbound image_path
        # (which would replace the real error with UnboundLocalError).
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # The context manager closes the file handle deterministically;
            # convert() returns an independent, fully loaded image.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')

            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            raise

        return image, label

# Simplified Vision Transformer for debugging
class SimpleViT(nn.Module):
    """Tiny debugging classifier: one convolution, global pooling, MLP head.

    Despite the name, the layers built here are a single ``Conv2d`` followed
    by adaptive average pooling and a two-layer fully connected head; no
    attention or patch embedding is involved.

    Args:
        img_size: Accepted for interface compatibility; not referenced when
            building the layers.
        n_classes: Number of output logits.
    """

    def __init__(self, img_size=224, n_classes=2):
        super().__init__()
        # Feature extractor: 3-channel input -> 64 feature maps
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64,
                               kernel_size=7, stride=2, padding=3)
        # Collapse each feature map to a single value
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Classification head operating on the 64 pooled features
        head_layers = [
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)
        print("Simple ViT model created")

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)`` for image batch ``x``."""
        pooled = self.pool(torch.relu(self.conv1(x)))
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

# Data loading function with debug prints
def load_data(schizophrenia_folder, healthy_folder):
    """Collect image paths and binary labels from the two class folders.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are
    picked up. Images from ``schizophrenia_folder`` get label 1 and images
    from ``healthy_folder`` get label 0. File names are sorted within each
    folder so the returned order is the same on every run and platform.

    Args:
        schizophrenia_folder: Directory holding the label-1 images.
        healthy_folder: Directory holding the label-0 images.

    Returns:
        ``(image_paths, labels)`` as two parallel lists. Both lists are
        empty when either folder is missing or is not a directory.
    """
    print("Loading data from:")
    print(f"  Schizophrenia folder: {schizophrenia_folder}")
    print(f"  Healthy folder: {healthy_folder}")

    # isdir (rather than exists) also rejects a path that points at a file,
    # which would otherwise make os.listdir raise below.
    if not os.path.isdir(schizophrenia_folder):
        print(f"ERROR: Schizophrenia folder does not exist: {schizophrenia_folder}")
        return [], []

    if not os.path.isdir(healthy_folder):
        print(f"ERROR: Healthy folder does not exist: {healthy_folder}")
        return [], []

    image_paths = []
    labels = []

    # (folder, label, name used in the log line) - schizophrenia first
    class_sources = (
        (schizophrenia_folder, 1, "schizophrenia"),
        (healthy_folder, 0, "healthy"),
    )
    for folder, label, class_name in class_sources:
        file_names = _list_image_files(folder)
        print(f"Found {len(file_names)} {class_name} images")
        image_paths.extend(os.path.join(folder, name) for name in file_names)
        labels.extend([label] * len(file_names))

    print(f"Total loaded: {len(image_paths)} images")
    return image_paths, labels


def _list_image_files(folder):
    """Return the sorted names of the image files directly inside ``folder``."""
    # os.listdir returns entries in arbitrary order; sorting makes the
    # dataset order (and any seeded split built on it) reproducible.
    return sorted(
        name for name in os.listdir(folder)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

# Training function with debug prints
def train_model(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module to optimise; switched to training mode here.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss values
        and the accuracy in percent over all samples seen.

    Raises:
        ValueError: if ``train_loader`` yields no samples, instead of the
            bare ``ZeroDivisionError`` the averaging would otherwise hit.
    """
    print("  Starting training phase...")
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    batch_count = 0
    for images, labels in train_loader:
        batch_count += 1
        if batch_count == 1:
            # Log only the first batch to confirm data is flowing
            print(f"    Processing batch {batch_count}, batch size: {images.size(0)}")

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Index of the largest logit is the predicted class; the indices
        # carry no gradient, so the deprecated ``.data`` access is not needed.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if batch_count == 0 or total == 0:
        raise ValueError("train_loader yielded no samples; cannot compute epoch metrics")

    epoch_loss = running_loss / batch_count
    epoch_acc = 100 * correct / total
    print(f"  Training completed: {batch_count} batches processed")
    return epoch_loss, epoch_acc

# Validation function with debug prints
def validate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Module to evaluate; switched to eval mode here.
        val_loader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc, precision, recall, f1, all_predictions,
        all_labels)``. ``epoch_loss`` is the mean of the per-batch losses;
        accuracy, precision, recall and F1 are percentages, the last three
        computed with weighted averaging. The two lists hold the predicted
        and true labels in loader order.

    Raises:
        ValueError: if ``val_loader`` yields no samples, instead of the
            bare ``ZeroDivisionError`` the averaging would otherwise hit.
    """
    print("  Starting validation phase...")
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []

    batch_count = 0
    with torch.no_grad():
        for images, labels in val_loader:
            batch_count += 1
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            # Index of the largest logit is the predicted class
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if batch_count == 0 or total == 0:
        raise ValueError("val_loader yielded no samples; cannot compute validation metrics")

    epoch_loss = running_loss / batch_count
    epoch_acc = 100 * correct / total

    # Calculate additional metrics (weighted by class support, as percentages)
    precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0) * 100
    recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0) * 100
    f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0) * 100

    print(f"  Validation completed: {batch_count} batches processed")
    return epoch_loss, epoch_acc, precision, recall, f1, all_predictions, all_labels

# Main function with extensive debugging
def main():
    """Run the debug training pipeline on the configured image folders.

    Steps: check that both class folders exist, load image paths and labels,
    load one image as a smoke test, then train and validate a fresh
    ``SimpleViT`` for ``EPOCHS`` epochs on the first stratified fold only.
    Progress and per-epoch metrics are printed; nothing is returned.

    The function returns early (after printing a message) when the folders
    are missing, when no images are found, when the smoke test fails, or
    when an epoch raises.
    """
    import traceback  # only used to report a failing epoch

    print("\n=== MAIN FUNCTION STARTED ===")

    # Configure these paths according to your data structure
    SCHIZOPHRENIA_FOLDER = "D:/HHT/S"  # Update this path
    HEALTHY_FOLDER = "D:/HHT/H"  # Update this path

    print("Please update the folder paths in the script:")
    print(f"SCHIZOPHRENIA_FOLDER = '{SCHIZOPHRENIA_FOLDER}'")
    print(f"HEALTHY_FOLDER = '{HEALTHY_FOLDER}'")

    # Stop early when either folder is missing; no dummy data is generated
    if not os.path.exists(SCHIZOPHRENIA_FOLDER) or not os.path.exists(HEALTHY_FOLDER):
        print("\n⚠️  FOLDERS NOT FOUND - CANNOT PROCEED")
        print("Please update the folder paths and run again")
        return

    # Hyperparameters
    IMG_SIZE = 224
    BATCH_SIZE = 8  # Smaller batch size for debugging
    LEARNING_RATE = 1e-3
    EPOCHS = 5  # Fewer epochs for debugging
    N_FOLDS = 2  # Fewer folds for debugging

    print("\nHyperparameters:")
    print(f"  IMG_SIZE: {IMG_SIZE}")
    print(f"  BATCH_SIZE: {BATCH_SIZE}")
    print(f"  LEARNING_RATE: {LEARNING_RATE}")
    print(f"  EPOCHS: {EPOCHS}")
    print(f"  N_FOLDS: {N_FOLDS}")

    # Data transforms: resize, convert to tensor, normalise per channel
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    print("Data transforms created")

    # Load data
    print("\n=== LOADING DATA ===")
    image_paths, labels = load_data(SCHIZOPHRENIA_FOLDER, HEALTHY_FOLDER)

    if not image_paths:
        print("ERROR: No images loaded. Check your folder paths!")
        return

    print("Data loaded successfully:")
    print(f"  Total images: {len(image_paths)}")
    print(f"  Schizophrenia: {sum(labels)} images")
    print(f"  Healthy: {len(labels) - sum(labels)} images")

    # Convert to numpy arrays so fold indices can be used for fancy indexing.
    # Labels are pinned to int64: NumPy's default integer dtype is
    # platform-dependent (may be int32), while CrossEntropyLoss expects
    # int64 class-index targets.
    image_paths = np.array(image_paths)
    labels = np.array(labels, dtype=np.int64)

    # Test loading one image
    print("\n=== TESTING IMAGE LOADING ===")
    try:
        test_dataset = EEGDataset([image_paths[0]], [labels[0]], transform=transform)
        test_img, test_label = test_dataset[0]
        print(f"Test image shape: {test_img.shape}, label: {test_label}")
        print("Image loading test passed ✓")
    except Exception as e:
        print(f"ERROR: Image loading test failed: {e}")
        return

    # Stratified cross validation (only the first fold is run, see break below)
    print(f"\n=== STARTING {N_FOLDS}-FOLD CROSS VALIDATION ===")
    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

    for fold, (train_idx, val_idx) in enumerate(skf.split(image_paths, labels)):
        print(f"\n{'='*60}")
        print(f"FOLD {fold + 1}/{N_FOLDS}")
        print(f"{'='*60}")

        # Split data
        train_paths, val_paths = image_paths[train_idx], image_paths[val_idx]
        train_labels, val_labels = labels[train_idx], labels[val_idx]

        print(f"Train set: {len(train_paths)} images")
        print(f"Val set: {len(val_paths)} images")

        # Create datasets
        print("Creating datasets...")
        train_dataset = EEGDataset(train_paths, train_labels, transform=transform)
        val_dataset = EEGDataset(val_paths, val_labels, transform=transform)

        # Create data loaders (num_workers=0 keeps loading in-process for debugging)
        print("Creating data loaders...")
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

        print(f"Train loader: {len(train_loader)} batches")
        print(f"Val loader: {len(val_loader)} batches")

        # Initialize a fresh model for this fold
        print("Initializing model...")
        model = SimpleViT(img_size=IMG_SIZE, n_classes=2).to(device)

        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
        print("Loss function and optimizer created")

        # Training loop
        print(f"\nStarting training for {EPOCHS} epochs...")
        print("-" * 80)

        for epoch in range(EPOCHS):
            print(f"\nEPOCH {epoch+1}/{EPOCHS} STARTING...")

            try:
                # Training phase
                train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, device)
                print(f"Training completed - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")

                # Validation phase
                val_loss, val_acc, val_precision, val_recall, val_f1, val_predictions, val_true = validate_model(model, val_loader, criterion, device)
                print(f"Validation completed - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")

                # THIS IS THE MAIN RESULT LINE - SHOULD ALWAYS PRINT
                print(f'✓ EPOCH [{epoch+1:2d}/{EPOCHS}] RESULTS:')
                print(f'  Train: Loss={train_loss:.4f}, Acc={train_acc:6.2f}%')
                print(f'  Val:   Loss={val_loss:.4f}, Acc={val_acc:6.2f}%, Prec={val_precision:6.2f}%, Rec={val_recall:6.2f}%, F1={val_f1:6.2f}%')

                # Force output flush
                sys.stdout.flush()

            except Exception as e:
                # Report the failure with a traceback and abandon the run
                print(f"ERROR in epoch {epoch+1}: {e}")
                traceback.print_exc()
                return

        print(f"\nFold {fold + 1} completed successfully!")
        break  # Only do first fold for debugging

    print("\n=== DEBUG VERSION COMPLETED ===")

if __name__ == "__main__":
    # Script entry point: run the pipeline and report any uncaught error
    # with its traceback instead of letting it propagate.
    import traceback

    try:
        main()
    except Exception as exc:
        print(f"CRITICAL ERROR: {exc}")
        traceback.print_exc()

=== STARTING DEBUG VERSION ===
Python version: 3.11.4 (tags/v3.11.4:d2340ef, Jun  7 2023, 05:45:37) [MSC v.1934 64 bit (AMD64)]
PyTorch version: 2.7.1+cpu
Using device: cpu

=== MAIN FUNCTION STARTED ===
Please update the folder paths in the script:
SCHIZOPHRENIA_FOLDER = 'D:/HHT/S'
HEALTHY_FOLDER = 'D:/HHT/H'

Hyperparameters:
  IMG_SIZE: 224
  BATCH_SIZE: 8
  LEARNING_RATE: 0.001
  EPOCHS: 5
  N_FOLDS: 2
Data transforms created

=== LOADING DATA ===
Loading data from:
  Schizophrenia folder: D:/HHT/S
  Healthy folder: D:/HHT/H
Found 5146 schizophrenia images
Found 4235 healthy images
Total loaded: 9381 images
Data loaded successfully:
  Total images: 9381
  Schizophrenia: 5146 images
  Healthy: 4235 images

=== TESTING IMAGE LOADING ===
Dataset created with 1 images
Test image shape: torch.Size([3, 224, 224]), label: 1
Image loading test passed ✓

=== STARTING 2-FOLD CROSS VALIDATION ===

FOLD 1/2
Train set: 4690 images
Val set: 4691 images
Creating datasets...
Dataset created with 4