## 1D ECGNET training
모델 참고 코드
- https://www.kaggle.com/code/nischaydnk/lightning-1d-eegnet-training-pipeline-hbs/notebook

## Imports

In [21]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import Dataset, WeightedRandomSampler
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score, BinaryAUROC
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from tqdm import tqdm
from scipy.signal import butter, filtfilt, iirnotch
import pywt
import random

## CFG

In [24]:
class CFG:
    """Static settings shared by the dataset, training and inference cells."""
    fs = 500  # sampling rate handed to the filter helpers as `fs` (presumably Hz -- confirm)
    csv_path = './train_fib.csv'  # training table read with pd.read_csv; `study_id` and `report_0` columns are used
    wave_dir = './wave'  # directory searched for `<study_id>.npy` waveform files
    output_dir = './output'  # checkpoints and prediction CSVs are written here
    batch_size = 32  # batch size of the train/validation DataLoaders
    epochs = 30  # number of epochs run per fold
    lr = 1e-3  # learning rate passed to Adam
    folds = [0, 1, 2, 3, 4]  # fold ids that get trained; the split itself uses n_splits=5
    seed = 42  # random_state of the StratifiedKFold split
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # model and batches are moved here

## Dataset

In [32]:
def butter_highpass_filter(data, cutoff=0.5, fs=500, order=4):
    """Apply a forward-backward Butterworth high-pass filter along axis 0.

    Args:
        data: Array-like signal; filtering runs along its first axis.
        cutoff: Cutoff frequency, in the same units as ``fs``.
        fs: Sampling frequency of ``data``.
        order: Order of the Butterworth design.

    Returns:
        The filtered array produced by ``scipy.signal.filtfilt``.
    """
    nyquist = 0.5 * fs
    # butter() is called without `fs`, so the cutoff must be given as a
    # fraction of the Nyquist frequency.
    numerator, denominator = butter(
        order, cutoff / nyquist, btype='high', analog=False
    )
    return filtfilt(numerator, denominator, data, axis=0)

def butter_lowpass_filter(data, cutoff=150, fs=500, order=4):
    """Apply a forward-backward Butterworth low-pass filter along axis 0.

    Args:
        data: Array-like signal; filtering runs along its first axis.
        cutoff: Cutoff frequency, in the same units as ``fs``.
        fs: Sampling frequency of ``data``.
        order: Order of the Butterworth design.

    Returns:
        The filtered array produced by ``scipy.signal.filtfilt``.
    """
    nyquist = 0.5 * fs
    # butter() is called without `fs`, so the cutoff must be given as a
    # fraction of the Nyquist frequency.
    numerator, denominator = butter(
        order, cutoff / nyquist, btype='low', analog=False
    )
    return filtfilt(numerator, denominator, data, axis=0)

def notch_filter(data, freq=60.0, fs=500.0, Q=30.0):
    """Remove a narrow band around ``freq`` with a forward-backward IIR notch.

    Args:
        data: Array-like signal; filtering runs along its first axis.
        freq: Centre frequency to reject, in the same units as ``fs``.
        fs: Sampling frequency of ``data``.
        Q: Quality factor handed to ``scipy.signal.iirnotch``.

    Returns:
        The filtered array produced by ``scipy.signal.filtfilt``.
    """
    nyquist = 0.5 * fs
    # iirnotch() is called without `fs`, so the centre frequency is expressed
    # as a fraction of the Nyquist frequency.
    normalized_freq = freq / nyquist
    numerator, denominator = iirnotch(normalized_freq, Q)
    return filtfilt(numerator, denominator, data, axis=0)


class ECGDataset(Dataset):
    """Dataset that loads one ``<study_id>.npy`` waveform per row and filters it.

    Each item is read from ``ecg_dir``, cleaned of NaNs, clipped to
    [-1024, 1024], divided by 32, and passed through the high-pass, low-pass
    and notch filters using ``CFG.fs`` as the sampling rate.

    Args:
        df: Table with a ``study_id`` column, plus ``report_0`` unless
            ``test`` is true. Its index is reset so rows are addressed by
            position.
        ecg_dir: Directory containing the ``.npy`` files.
        test: When true, items are returned without a label.
    """

    def __init__(self, df, ecg_dir, test=False):
        self.df = df.reset_index(drop=True)
        self.ecg_dir = ecg_dir
        self.test = test

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``x`` (test mode) or ``(x, y)`` for the row at ``index``."""
        record = self.df.iloc[index]
        wave_path = os.path.join(self.ecg_dir, f"{record.study_id}.npy")

        wave = np.nan_to_num(np.load(wave_path))
        wave = np.clip(wave, -1024, 1024) / 32.0

        wave = butter_highpass_filter(wave, cutoff=0.5, fs=CFG.fs)
        wave = butter_lowpass_filter(wave, cutoff=150, fs=CFG.fs)
        wave = notch_filter(wave, freq=60.0, fs=CFG.fs)
        # Take a fresh copy of the filter output before handing it to torch.
        wave = wave.copy()

        # Swap the two axes; the original author notes the result is (12, 5000).
        x = torch.tensor(wave, dtype=torch.float32).permute(1, 0)
        if not self.test:
            y = torch.tensor(record.report_0, dtype=torch.float32)
            return x, y
        return x
        


## Model

In [22]:
class ResNet_1D_Block(nn.Module):
    """1-D residual block: (BN -> ReLU -> dropout -> conv) twice, then max-pool.

    The main branch ends in ``MaxPool1d(2)``; ``downsampling`` is applied to
    the block input to build the shortcut that is added to it, so it must
    yield a tensor of the same shape as the pooled main branch.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by both convolutions.
        kernel_size: Kernel size shared by both convolutions.
        stride: Stride shared by both convolutions.
        padding: Padding shared by both convolutions.
        downsampling: Module applied to the input to form the shortcut.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, downsampling):
        super().__init__()
        self.bn1 = nn.BatchNorm1d(num_features=in_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.0)
        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.maxpool = nn.MaxPool1d(kernel_size=2)
        self.downsampling = downsampling

    def forward(self, x):
        """Return the pooled main branch plus the downsampled shortcut."""
        shortcut = self.downsampling(x)

        main = self.dropout(self.relu(self.bn1(x)))
        main = self.conv1(main)
        main = self.dropout(self.relu(self.bn2(main)))
        main = self.conv2(main)
        main = self.maxpool(main)

        main += shortcut
        return main

class ECGNet(nn.Module):
    """Multi-kernel 1-D CNN with a bidirectional GRU branch.

    The convolutional branch runs one unpadded convolution per entry of
    ``kernels``, concatenates the results along the time axis, and feeds them
    through a strided convolution, nine residual blocks and average pooling.
    The recurrent branch runs a bidirectional GRU over the raw input and keeps
    its output at the last time step. Both feature vectors are concatenated
    and mapped to ``num_classes`` logits.

    Args:
        kernels: Kernel sizes of the parallel input convolutions.
        in_channels: Channels of the input signal.
        fixed_kernel_size: Kernel size of the strided convolution and of the
            residual blocks.
        num_classes: Number of output logits.
    """

    def __init__(self, kernels, in_channels=12, fixed_kernel_size=17, num_classes=1):
        super().__init__()
        self.planes = 24
        self.parallel_conv = nn.ModuleList(
            nn.Conv1d(in_channels, self.planes, kernel_size=k, stride=1, padding=0, bias=False)
            for k in kernels
        )
        self.bn1 = nn.BatchNorm1d(num_features=self.planes)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv1d(
            self.planes,
            self.planes,
            kernel_size=fixed_kernel_size,
            stride=2,
            padding=2,
            bias=False,
        )
        self.block = self._make_resnet_layer(fixed_kernel_size, 1, 9, fixed_kernel_size // 2)
        self.bn2 = nn.BatchNorm1d(num_features=self.planes)
        self.avgpool = nn.AvgPool1d(kernel_size=6, stride=6, padding=2)
        self.rnn = nn.GRU(
            input_size=in_channels,
            hidden_size=128,
            bidirectional=True,
            batch_first=True,
        )
        # NOTE(review): 328 is fixed, so it only fits one input length / kernel
        # set. It appears to be 24*3 pooled conv features + 2*128 GRU features
        # for (12, 5000) inputs with kernels [3, 5, 7, 9] -- confirm before
        # changing either.
        self.fc = nn.Linear(in_features=328, out_features=num_classes)

    def _make_resnet_layer(self, kernel_size, stride, blocks, padding):
        """Stack ``blocks`` residual blocks, each with a MaxPool1d(2) shortcut."""
        stacked = [
            ResNet_1D_Block(
                self.planes,
                self.planes,
                kernel_size,
                stride,
                padding,
                nn.Sequential(nn.MaxPool1d(2)),
            )
            for _ in range(blocks)
        ]
        return nn.Sequential(*stacked)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for input ``x``."""
        # Convolutional branch: per-kernel feature maps are joined along time (dim=2).
        branch_maps = [conv(x) for conv in self.parallel_conv]
        conv_feat = torch.cat(branch_maps, dim=2)
        conv_feat = self.conv1(self.relu(self.bn1(conv_feat)))
        conv_feat = self.block(conv_feat)
        conv_feat = self.avgpool(self.relu(self.bn2(conv_feat)))
        conv_feat = conv_feat.reshape(conv_feat.size(0), -1)

        # Recurrent branch: the GRU is batch_first, so move channels to the last axis.
        seq_out, _ = self.rnn(x.permute(0, 2, 1))
        last_step = seq_out[:, -1, :]

        fused = torch.cat([conv_feat, last_step], dim=1)
        return self.fc(fused)

## Train & Validate loop

In [None]:
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; its output is squeezed on dim 1 before the loss.
        loader: Iterable of ``(x, y)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, y)``.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for inputs, labels in tqdm(loader, desc="Train"):
        inputs = inputs.to(CFG.device)
        labels = labels.to(CFG.device)

        optimizer.zero_grad()
        logits = model(inputs).squeeze(1)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

def valid_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; its output is squeezed on dim 1.
        loader: Iterable of ``(x, y)`` batches.
        criterion: Loss called as ``criterion(output, y)``.

    Returns:
        A tuple ``(mean_loss, preds, targets)`` where ``mean_loss`` is the sum
        of per-batch losses divided by ``len(loader)``, ``preds`` is a list of
        sigmoid probabilities and ``targets`` the matching list of labels.
    """
    model.eval()
    loss_sum = 0
    probabilities = []
    labels = []
    with torch.no_grad():
        for inputs, batch_labels in tqdm(loader, desc="Valid"):
            inputs = inputs.to(CFG.device)
            batch_labels = batch_labels.to(CFG.device)

            logits = model(inputs).squeeze(1)
            loss_sum += criterion(logits, batch_labels).item()

            probabilities.extend(torch.sigmoid(logits).cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())
    mean_loss = loss_sum / len(loader)
    return mean_loss, probabilities, labels


def train_loop(fold_id, df):
    """Train ECGNet on one cross-validation fold and save its best checkpoint.

    Rows with ``df.fold != fold_id`` are used for training and rows with
    ``df.fold == fold_id`` for validation. After every epoch the validation
    accuracy, F1 (both at a 0.5 threshold) and AUROC are printed. Whenever the
    AUROC improves, the weights are written to
    ``<CFG.output_dir>/ecgnet_fold<fold_id>.pt`` and that epoch's validation
    probabilities are remembered.

    Args:
        fold_id: Value of the ``fold`` column that selects the validation rows.
        df: Table with ``fold``, ``study_id`` and ``report_0`` columns.

    Returns:
        The validation rows (index reset) with a ``pred`` column holding the
        probabilities from the best-AUROC epoch. The same table is written to
        ``<CFG.output_dir>/pred_df_f<fold_id>.csv``.
    """
    print(f"===== Fold {fold_id} =====")
    train_df = df[df.fold != fold_id].copy()
    val_df = df[df.fold == fold_id].copy()

    train_ds = ECGDataset(train_df, CFG.wave_dir)
    val_ds = ECGDataset(val_df, CFG.wave_dir)

    train_loader = DataLoader(train_ds, batch_size=CFG.batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=CFG.batch_size, shuffle=False, num_workers=0)

    model = ECGNet(kernels=[3, 5, 7, 9]).to(CFG.device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)

    # Do not rely on the caller having created the output directory.
    os.makedirs(CFG.output_dir, exist_ok=True)
    checkpoint_path = os.path.join(CFG.output_dir, f"ecgnet_fold{fold_id}.pt")

    # Start below any reachable AUROC so the first epoch always yields a checkpoint.
    best_auc = float("-inf")
    best_preds = None
    for epoch in range(CFG.epochs):
        print(f"\nEpoch {epoch+1}/{CFG.epochs}")
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        val_loss, preds, targets = valid_epoch(model, val_loader, criterion)

        bin_preds = [int(p > 0.5) for p in preds]

        # metric
        acc = accuracy_score(targets, bin_preds)
        f1 = f1_score(targets, bin_preds)
        auc = roc_auc_score(targets, preds)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Val Acc: {acc:.4f} | F1: {f1:.4f} | AUROC: {auc:.4f}")

        if auc > best_auc:
            best_auc = auc
            # Keep the predictions that belong to the weights being saved.
            best_preds = preds
            torch.save(model.state_dict(), checkpoint_path)
            print("Best model saved.")

    # Save validation predictions. They come from the best-AUROC epoch so the
    # out-of-fold table matches the checkpoint on disk rather than whatever
    # the final epoch happened to produce.
    val_df = val_df.reset_index(drop=True)
    val_df['pred'] = best_preds
    val_df.to_csv(os.path.join(CFG.output_dir, f"pred_df_f{fold_id}.csv"), index=False)
    return val_df



## Train

In [None]:

os.makedirs(CFG.output_dir, exist_ok=True)

# Load the training table and drop rows that lack a label or a study id.
df = pd.read_csv(CFG.csv_path)
# reset_index is required: StratifiedKFold.split yields positional indices,
# while the `df.loc[val_idx, ...]` assignment below is label-based. Without a
# fresh 0..n-1 index, any row removed by dropna shifts the two apart and rows
# end up in the wrong fold (or the assignment raises KeyError).
df = df.dropna(subset=['report_0', 'study_id']).reset_index(drop=True)

# Assign each row to one of 5 folds, stratified on the label.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=CFG.seed)
df['fold'] = -1
for i, (_, val_idx) in enumerate(skf.split(df, df['report_0'])):
    df.loc[val_idx, 'fold'] = i

# Train every configured fold and collect its validation predictions.
all_oof = []
for fold in CFG.folds:
    val_result = train_loop(fold, df)
    all_oof.append(val_result)

# Stack the per-fold validation tables into one out-of-fold prediction file.
oof_df = pd.concat(all_oof).reset_index(drop=True)
oof_df.to_csv(os.path.join(CFG.output_dir, "oof_preds.csv"), index=False)
print("\n All folds complete.")

## Inference

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

# Rebuild the network with the same kernels used in training, then restore
# the weights saved for fold 4.
model = ECGNet(kernels=[3, 5, 7, 9]).to(CFG.device)
# Read from CFG.output_dir, the directory train_loop saves checkpoints to.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU still loads on a CPU-only machine.
model_dict = torch.load(
    os.path.join(CFG.output_dir, 'ecgnet_fold4.pt'),
    map_location=CFG.device,
)
model.load_state_dict(model_dict)

<All keys matched successfully>

In [35]:
def predict(model, df, wave_dir, threshold=0.5, batch_size=32):
    """Score every row of ``df`` with ``model`` and attach the predictions.

    Args:
        model: Network returning one logit per sample.
        df: Table whose ``study_id`` column names the waveform files.
        wave_dir: Directory containing the ``.npy`` waveforms.
        threshold: Probabilities at or above this value get label 1.
        batch_size: Batch size of the inference DataLoader.

    Returns:
        A copy of ``df`` with ``pred_proba`` (sigmoid probability) and
        ``pred_label`` (0/1 after thresholding) columns added.
    """
    dataset = ECGDataset(df, wave_dir, test=True)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    model.eval()
    probabilities = []
    with torch.no_grad():
        for batch in loader:
            logits = model(batch.to(CFG.device)).squeeze()
            # squeeze() turns a single-sample batch into a 0-d tensor;
            # atleast_1d keeps it iterable for extend().
            batch_probs = np.atleast_1d(torch.sigmoid(logits).cpu().numpy())
            probabilities.extend(batch_probs)

    scored = df.copy()
    scored['pred_proba'] = probabilities
    scored['pred_label'] = (scored['pred_proba'] >= threshold).astype(int)
    return scored

In [36]:
# Score the test table with the loaded model and save the predictions.
test_df = pd.read_csv('./test_fib.csv')  # your test set
# Take the waveform and output directories from CFG, as the training cells
# do, so changing a path in CFG is enough to relocate inference as well.
pred_df = predict(model, test_df, CFG.wave_dir, threshold=0.5)
pred_df.to_csv(os.path.join(CFG.output_dir, 'test_preds.csv'), index=False)

In [37]:
# Count how often each `report_0` value occurs in the scored test table.
pred_df['report_0'].value_counts()

report_0
0    3747
1     446
Name: count, dtype: int64

In [None]:
# Fraction of rows where the thresholded prediction equals `report_0`
# (author's note: result with filtering applied).
print("정확도:", pred_df['report_0'].eq(pred_df['pred_label']).mean())

정확도: 0.9666110183639399


In [None]:
# Same accuracy check on a previously saved prediction file
# (author's note: before filtering was applied).
p = pd.read_csv('./output/test_preds_wo_pre.csv')
print("정확도:", p['report_0'].eq(p['pred_label']).mean())

정확도: 0.8936322442165514
