In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import json
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")

In [11]:
class HeartDataset(Dataset):
    """In-memory dataset pairing a feature matrix with its labels as float32 tensors.

    Args:
        X: Features, one row per sample. Accepts a pandas DataFrame/Series,
            a NumPy array, or a nested sequence of numbers.
        y: Labels, one entry per sample. Accepts the same container types.

    Raises:
        ValueError: If ``X`` and ``y`` do not hold the same number of samples,
            or if either cannot be converted to a numeric array.
    """

    def __init__(self, X, y):
        self.X = self._to_float_tensor(X)
        self.y = self._to_float_tensor(y)
        # Fail here rather than with an IndexError in the middle of an epoch.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same number of samples, "
                f"got {len(self.X)} and {len(self.y)}"
            )

    @staticmethod
    def _to_float_tensor(values):
        """Convert any array-like (pandas, NumPy, sequence) to a float32 tensor."""
        if isinstance(values, (pd.DataFrame, pd.Series)):
            # Positional values only: the pandas index (often shuffled after a
            # train/test split) must not influence sample order.
            values = values.to_numpy()
        # np.array(..., dtype=float32) always copies, so the tensor never
        # aliases the caller's data.
        return torch.from_numpy(np.array(values, dtype=np.float32))

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.X[idx], self.y[idx]

class DeepHeartNet(nn.Module):
    """Fully connected binary classifier that outputs a probability in [0, 1].

    Each hidden stage is ``Linear -> BatchNorm1d -> ReLU -> Dropout``; the head
    is ``Linear(…, 1) -> Sigmoid``. With the default arguments the layer order
    (and therefore the ``state_dict`` keys ``model.0`` … ``model.12``) is
    identical to the original hard-coded 256-128-64 network.

    Args:
        input_dim: Number of input features per sample.
        hidden_dims: Width of each hidden stage, in order.
        dropout_rates: Dropout probability applied after each hidden stage;
            must have the same length as ``hidden_dims``.

    Raises:
        ValueError: If ``input_dim`` is not positive or the two sequences
            differ in length.
    """

    def __init__(self, input_dim, hidden_dims=(256, 128, 64),
                 dropout_rates=(0.3, 0.2, 0.1)):
        super().__init__()
        if input_dim < 1:
            raise ValueError(f"input_dim must be positive, got {input_dim}")
        if len(hidden_dims) != len(dropout_rates):
            raise ValueError(
                "hidden_dims and dropout_rates must have the same length, "
                f"got {len(hidden_dims)} and {len(dropout_rates)}"
            )

        layers = []
        in_features = input_dim
        for width, p_drop in zip(hidden_dims, dropout_rates):
            # Modules are created in the same order as before so that weight
            # initialisation consumes the RNG identically for a given seed.
            layers.extend([
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(p_drop),
            ])
            in_features = width
        layers.extend([nn.Linear(in_features, 1), nn.Sigmoid()])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid output of the network for the batch ``x``."""
        return self.model(x)

class HeartDiseasePredictor:
    """Train and evaluate a ``DeepHeartNet`` classifier and persist all artifacts.

    Everything is written below ``save_path`` in three sub-directories:
    ``models`` (best checkpoint), ``plots`` (training curves, confusion
    matrix, ROC curve) and ``results`` (dataset summary, training history,
    classification report).

    NOTE(review): ``train_model`` uses the *test* split both for early
    stopping / checkpoint selection and for the final report, so the reported
    metrics are optimistic. Hold out a separate validation split if unbiased
    numbers are required.
    """

    def __init__(self, save_path):
        """Select the compute device and create the output directory tree.

        Args:
            save_path: Root directory for all generated artifacts.
        """
        self.save_path = save_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        os.makedirs(save_path, exist_ok=True)
        self.create_directories()

    def create_directories(self):
        """Create the ``models``/``plots``/``results`` folders and record them in ``self.dirs``."""
        self.dirs = {
            'models': os.path.join(self.save_path, 'models'),
            'plots': os.path.join(self.save_path, 'plots'),
            'results': os.path.join(self.save_path, 'results')
        }
        for dir_path in self.dirs.values():
            os.makedirs(dir_path, exist_ok=True)

    def prepare_data(self, data_path):
        """Load the CSV at ``data_path``, print a summary and save it to disk.

        The summary (shape, columns, ``DataFrame.info`` and descriptive
        statistics) is written to ``results/data_info.txt``.

        Args:
            data_path: Path of the CSV file to read.

        Returns:
            The loaded ``pandas.DataFrame``.
        """
        print(f"Loading data from {data_path}")
        data = pd.read_csv(data_path)

        print("\nDataset Info:")
        # DataFrame.info() writes to stdout itself and returns None; wrapping
        # it in print() used to emit a stray "None" line.
        data.info()
        print("\nSample Data:")
        print(data.head())

        with open(os.path.join(self.dirs['results'], 'data_info.txt'), 'w') as f:
            f.write("Dataset Shape: " + str(data.shape) + "\n")
            f.write("\nColumns: " + str(data.columns.tolist()) + "\n")
            data.info(buf=f)
            f.write("\nDescriptive Statistics:\n")
            data.describe().to_string(buf=f)

        return data

    def train_model(self, X_train, X_test, y_train, y_test,
                    n_epochs=200, batch_size=32, lr=0.001, patience=20):
        """Train a ``DeepHeartNet`` with early stopping and evaluate the best checkpoint.

        Args:
            X_train, X_test: Feature matrices (array-like, one row per sample).
            y_train, y_test: Binary labels aligned with the feature matrices.
            n_epochs: Maximum number of training epochs.
            batch_size: Mini-batch size for both loaders.
            lr: Initial Adam learning rate.
            patience: Number of consecutive epochs without validation-loss
                improvement after which training stops.

        Returns:
            The trained model, restored to the weights with the lowest
            validation loss when such a checkpoint was written in this run.

        Raises:
            ValueError: If the training or test split is empty.
        """
        print("Preparing training data...")
        print(f"X_train shape: {X_train.shape}")
        print(f"X_test shape: {X_test.shape}")
        print(f"y_train shape: {y_train.shape}")
        print(f"y_test shape: {y_test.shape}")

        train_dataset = HeartDataset(X_train, y_train)
        test_dataset = HeartDataset(X_test, y_test)
        if len(train_dataset) == 0 or len(test_dataset) == 0:
            raise ValueError("Both the training and the test split must be non-empty")

        # BatchNorm1d cannot normalise a single-sample batch in training mode,
        # so drop the trailing batch when it would contain exactly one sample.
        drop_last = (len(train_dataset) > batch_size
                     and len(train_dataset) % batch_size == 1)
        train_loader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=True, drop_last=drop_last)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)

        input_dim = X_train.shape[1]
        model = DeepHeartNet(input_dim).to(self.device)
        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
        # The deprecated ``verbose`` flag is not passed; learning-rate changes
        # are reported explicitly in the loop below instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5
        )

        best_model_path = os.path.join(self.dirs['models'], 'best_model.pth')
        best_loss = float('inf')
        no_improve = 0
        checkpoint_saved = False
        history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

        print("\nStarting training...")
        for epoch in range(n_epochs):
            avg_train_loss = self._train_one_epoch(
                model, train_loader, criterion, optimizer
            )
            avg_val_loss, accuracy = self._validate(model, test_loader, criterion)

            history['train_loss'].append(avg_train_loss)
            history['val_loss'].append(avg_val_loss)
            history['val_acc'].append(accuracy)

            if epoch % 10 == 0:
                print(f'Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, '
                      f'Val Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.2f}%')

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(avg_val_loss)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                print(f'Epoch {epoch}: reducing learning rate to {lr_after:.4e}')

            if avg_val_loss < best_loss:
                best_loss = avg_val_loss
                torch.save(model.state_dict(), best_model_path)
                checkpoint_saved = True
                no_improve = 0
            else:
                no_improve += 1
                if no_improve >= patience:
                    print("Early stopping triggered")
                    break

        self._save_training_history(history)

        if checkpoint_saved:
            # map_location keeps the load working when the checkpoint was
            # written on a different device than the one in use now.
            model.load_state_dict(
                torch.load(best_model_path, map_location=self.device)
            )
        else:
            # Happens when no epoch ran or the loss never improved (e.g. NaN);
            # avoid silently loading a stale checkpoint from an earlier run.
            print("No checkpoint was written in this run; evaluating final weights")
        self._evaluate_model(model, test_loader, y_test)

        return model

    def _train_one_epoch(self, model, loader, criterion, optimizer):
        """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
        model.train()
        loss_sum = 0.0
        n_seen = 0
        for batch_X, batch_y in loader:
            batch_X = batch_X.to(self.device)
            batch_y = batch_y.to(self.device).view(-1, 1)

            optimizer.zero_grad()
            loss = criterion(model(batch_X), batch_y)
            loss.backward()
            optimizer.step()

            # The criterion returns a batch mean; weight it by the batch size
            # so a short final batch does not skew the epoch average.
            loss_sum += loss.item() * batch_X.size(0)
            n_seen += batch_X.size(0)
        return loss_sum / n_seen

    def _validate(self, model, loader, criterion):
        """Evaluate ``model`` on ``loader``; return ``(mean loss, accuracy in percent)``."""
        model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_X, batch_y in loader:
                batch_X = batch_X.to(self.device)
                batch_y = batch_y.to(self.device).view(-1, 1)

                outputs = model(batch_X)
                loss_sum += criterion(outputs, batch_y).item() * batch_y.size(0)

                predicted = (outputs > 0.5).float()
                total += batch_y.size(0)
                correct += (predicted == batch_y).sum().item()
        return loss_sum / total, 100 * correct / total

    def _save_training_history(self, history):
        """Plot the loss/accuracy curves and dump ``history`` to CSV.

        Writes ``plots/training_history.png`` and
        ``results/training_history.csv``.
        """
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 2, 1)
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['val_loss'], label='Validation Loss')
        plt.title('Model Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(history['val_acc'], label='Validation Accuracy')
        plt.title('Model Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(self.dirs['plots'], 'training_history.png'))
        plt.close()

        pd.DataFrame(history).to_csv(
            os.path.join(self.dirs['results'], 'training_history.csv'),
            index=False
        )

    def _evaluate_model(self, model, test_loader, y_test):
        """Score ``model`` on ``test_loader`` and save the report and plots.

        Writes ``results/classification_report.txt``,
        ``plots/confusion_matrix.png`` and ``plots/roc_curve.png``.
        ``y_test`` must be in the same order as the samples in
        ``test_loader`` (the loader is iterated without shuffling).
        """
        model.eval()
        batch_probs = []

        with torch.no_grad():
            for batch_X, _ in test_loader:
                outputs = model(batch_X.to(self.device))
                batch_probs.append(outputs.cpu().numpy().reshape(-1))

        all_probs = np.concatenate(batch_probs)
        # Same 0.5 decision threshold as used during validation.
        all_preds = (all_probs > 0.5).astype(np.float32)

        report = classification_report(y_test, all_preds)
        with open(os.path.join(self.dirs['results'], 'classification_report.txt'), 'w') as f:
            f.write(report)

        cm = confusion_matrix(y_test, all_preds)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig(os.path.join(self.dirs['plots'], 'confusion_matrix.png'))
        plt.close()

        fpr, tpr, _ = roc_curve(y_test, all_probs)
        roc_auc = auc(fpr, tpr)

        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2,
                 label=f'ROC curve (AUC = {roc_auc:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC)')
        plt.legend(loc="lower right")
        plt.savefig(os.path.join(self.dirs['plots'], 'roc_curve.png'))
        plt.close()

In [12]:
def main(data_path='/project/zhiwei/hf78/ecg/data/heart/heart.csv',
         save_path='/project/zhiwei/hf78/ecg/output/heart',
         seed=42):
    """Run the full pipeline: load the CSV, split, scale, train and evaluate.

    Args:
        data_path: CSV file containing the features and a ``target`` column.
        save_path: Directory that receives models, plots and reports.
        seed: Seed for NumPy, PyTorch and the train/test split, so that weight
            initialisation and batch shuffling are repeatable between runs.
    """
    # Without seeding, every run starts from different weights and batch
    # order, so the saved metrics cannot be reproduced.
    np.random.seed(seed)
    torch.manual_seed(seed)

    predictor = HeartDiseasePredictor(save_path)

    data = predictor.prepare_data(data_path)
    X = data.drop('target', axis=1)
    y = data['target']
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=seed, stratify=y
    )

    # Fit the scaler on the training split only so that no test-set
    # statistics leak into training.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    predictor.train_model(X_train_scaled, X_test_scaled, y_train, y_test)

    print(f"\nAll results have been saved to {save_path}")

if __name__ == "__main__":
    main()

Using device: cuda
Loading data from /project/zhiwei/hf78/ecg/data/heart/heart.csv

Dataset Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1025 entries, 0 to 1024
Data columns (total 14 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   age       1025 non-null   int64  
 1   sex       1025 non-null   int64  
 2   cp        1025 non-null   int64  
 3   trestbps  1025 non-null   int64  
 4   chol      1025 non-null   int64  
 5   fbs       1025 non-null   int64  
 6   restecg   1025 non-null   int64  
 7   thalach   1025 non-null   int64  
 8   exang     1025 non-null   int64  
 9   oldpeak   1025 non-null   float64
 10  slope     1025 non-null   int64  
 11  ca        1025 non-null   int64  
 12  thal      1025 non-null   int64  
 13  target    1025 non-null   int64  
dtypes: float64(1), int64(13)
memory usage: 112.2 KB
None

Sample Data:
   age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  \
0   52    1   0   