In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
import matplotlib.pyplot as plt
import sklearn.metrics as metrics

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.distributions import Normal

  from .autonotebook import tqdm as notebook_tqdm


In [22]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.preprocessing import StandardScaler, LabelEncoder

In [3]:
from torch.utils.data import Dataset, DataLoader

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
from sklearn.preprocessing import StandardScaler, LabelEncoder

In [6]:
data_ori = pd.read_csv('E:/RESEARCH/Datasets/dissertation/heart_disease.csv') ## heart disease public dataset

In [7]:
print(f"The shape of the original dataset is: {data_ori.shape}")
data_ori.head()

The shape of the original dataset is: (1025, 14)


Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,52,1,0,125,212,0,1,168,0,1.0,2,2,3,0
1,53,1,0,140,203,1,0,155,1,3.1,0,0,3,0
2,70,1,0,145,174,0,1,125,1,2.6,0,0,3,0
3,61,1,0,148,203,0,1,161,0,0.0,2,1,3,0
4,62,0,0,138,294,1,1,106,0,1.9,1,3,2,0


In [8]:
data_ori.columns

Index(['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach',
       'exang', 'oldpeak', 'slope', 'ca', 'thal', 'target'],
      dtype='object')

In [9]:
data_B = data_ori[data_ori['target']==0]
data_M = data_ori[data_ori['target']==1]

In [10]:
sample_B = data_B.sample(n=260, random_state=710674)
sample_M = data_M.sample(n=526, random_state=710674)

# 두 샘플을 합쳐서 최종 데이터프레임 구성
data_sampled = pd.concat([sample_B, sample_M])

In [12]:
data_att = data_sampled.copy()

In [14]:
data_att.groupby('target')['sex'].value_counts()

target  sex
0       1      214
        0       46
1       1      300
        0      226
Name: sex, dtype: int64

In [15]:
# 최소값과 최대값을 확인
min_age = data_att['age'].min()
max_age = data_att['age'].max()

# 연령 구간을 10살씩 끊어서 정의
age_bins = list(range(min_age // 10 * 10, max_age + 10, 10))  # 10살 간격으로 구간 정의
age_labels = [f"({i}, {i+10}]" for i in age_bins[:-1]]  # 구간 레이블 생성

# pd.cut()을 사용하여 연령을 구간별로 그룹화
data_att['age_group'] = pd.cut(data_att['age'], bins=age_bins, labels=age_labels, right=False)

In [16]:
data_att.age_group.value_counts()

(50, 60]    320
(40, 50]    198
(60, 70]    189
(30, 40]     48
(70, 80]     27
(20, 30]      4
Name: age_group, dtype: int64

In [17]:
data_att['age_group_enc'] = pd.Categorical(data_att['age_group']).codes

In [19]:
data_att = data_att.drop(['age', 'age_group'], axis=1)

In [20]:
data_att.head()

Unnamed: 0,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target,age_group_enc
878,1,0,120,188,0,1,113,0,1.4,1,1,3,0,3
370,0,0,132,341,1,0,136,1,3.0,1,0,3,0,2
349,0,2,130,263,0,1,97,0,1.2,1,1,3,0,4
600,0,2,130,263,0,1,97,0,1.2,1,1,3,0,4
429,1,2,108,243,0,1,152,0,0.0,2,0,2,0,2


In [21]:
x = data_att.drop(['target', 'age_group_enc', 'sex'], axis=1)
x = x.fillna(x.mean()) ## filling na values with mean values (just drop the rows is also a possible option)
y = data_att.target
# c = data_cvae.loc[:,['age_group', 'sex']]

In [23]:
class MedicalDataset(Dataset):
    def __init__(self, dataframe, vital_signs_cols, condition_cols):
        """
        Parameters:
        -----------
        dataframe : pandas.DataFrame
            입력 데이터프레임
        vital_signs_cols : list
            15개 생체신호 변수명 리스트
        condition_cols : list
            조건 변수명 리스트 (질병, 성별, 연령)
        """
        self.data = dataframe.copy()
        
        # 생체신호 데이터 정규화
        self.scaler = StandardScaler()
        self.data[vital_signs_cols] = self.scaler.fit_transform(self.data[vital_signs_cols])
        
        # 조건 변수 전처리
        self.label_encoders = {}
        for col in condition_cols:
            if self.data[col].dtype == 'object' or self.data[col].dtype == 'category':
                le = LabelEncoder()
                self.data[col] = le.fit_transform(self.data[col])
                self.label_encoders[col] = le
        
        # 연령 정규화 (연령이 numerical 데이터인 경우)
        if 'age' in condition_cols:
            age_scaler = StandardScaler()
            self.data['age'] = age_scaler.fit_transform(self.data[['age']])
        
        # numpy 배열로 변환
        self.vital_signs = self.data[vital_signs_cols].values
        self.conditions = self.data[condition_cols].values
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        vital_signs = torch.FloatTensor(self.vital_signs[idx])
        conditions = torch.FloatTensor(self.conditions[idx])
        return vital_signs, conditions
    
    def get_scalers(self):
        """스케일러와 인코더 반환 (나중에 역변환을 위해 필요)"""
        return {
            'vital_signs_scaler': self.scaler,
            'label_encoders': self.label_encoders
        }

# 데이터 로더 생성 함수
def create_data_loader(data_att, batch_size=32, shuffle=True):
    """
    Parameters:
    -----------
    data_att : pandas.DataFrame
        입력 데이터프레임
    batch_size : int
        배치 크기
    shuffle : bool
        데이터 셔플 여부
    
    Returns:
    --------
    train_loader : DataLoader
        학습용 데이터 로더
    dataset : MedicalDataset
        데이터셋 객체 (스케일러와 인코더 접근용)
    """
    # 생체신호 변수와 조건 변수 컬럼명 지정
    vital_signs_cols = [col for col in data_att.columns if col.startswith('vital_')]  # 예시: vital_1, vital_2, ...
    condition_cols = ['disease', 'gender', 'age']  # 실제 컬럼명에 맞게 수정 필요
    
    # 데이터셋 생성
    dataset = MedicalDataset(data_att, vital_signs_cols, condition_cols)
    
    # 데이터 로더 생성
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True  # 마지막 배치가 배치 크기보다 작은 경우 제외
    )
    
    return train_loader, dataset

In [24]:
print("데이터프레임 컬럼:", data_att.columns)

# 데이터 로더 생성
train_loader, dataset = create_data_loader(
    data_att=data_att,
    batch_size=32,
    shuffle=True
)

데이터프레임 컬럼: Index(['sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang',
       'oldpeak', 'slope', 'ca', 'thal', 'target', 'age_group_enc'],
      dtype='object')


ValueError: at least one array or dtype is required