# In [9]:
# Location of the input CSV; the main pipeline below loads it with pd.read_csv(PATH).
PATH = "E:/dataset/Hypertension/MIMIC-III/MIMIC3.csv"

# In [10]:
# -------------------- Fix random seeds --------------------
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score

def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's `random`, NumPy's global RNG, PyTorch's CPU RNG and the
    RNG of every CUDA device with the same `seed`.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


# Seed once at import time with the default value.
set_seed()

# -------------------- 1. Augmentation --------------------
def jitter(x, sigma=0.01):
    """Return `x` with element-wise Gaussian noise of standard deviation `sigma` added.

    The noise comes from NumPy's global RNG and has the same shape as `x`;
    the input array itself is not modified.
    """
    noise = np.random.randn(*x.shape)
    return x + sigma * noise

def scaling(x, sigma=0.05):
    """Multiply the whole input by a single random gain drawn from N(1.0, sigma).

    One scalar is drawn from NumPy's global RNG per call, so every element of
    `x` is scaled by the same factor.
    """
    gain = np.random.normal(loc=1.0, scale=sigma)
    return x * gain

class PPGAugmentation:
    """Callable augmentation pipeline: jitter first, then scaling, both with default strengths."""

    def __call__(self, x):
        # jitter() is evaluated before scaling(), so the RNG is consumed in
        # the order noise -> gain.
        return scaling(jitter(x))

# -------------------- 2. Dual Encoder (CNN + Transformer in parallel) --------------------
class CNNBranch(nn.Module):
    """Three Conv1d/BatchNorm/ReLU stages followed by global average pooling.

    The first convolution takes a single input channel, so the input must be
    (batch, 1, length); the output is (batch, out_dim).
    """

    def __init__(self, out_dim=128):
        super().__init__()
        # (in_channels, out_channels, kernel_size) per stage; padding of
        # kernel_size // 2 keeps the sequence length unchanged.
        stage_specs = [(1, 32, 3), (32, 64, 5), (64, out_dim, 7)]
        layers = []
        for c_in, c_out, k in stage_specs:
            layers += [
                nn.Conv1d(c_in, c_out, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]
        # Collapse the time axis to one value per channel.
        layers.append(nn.AdaptiveAvgPool1d(1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        pooled = self.net(x)
        # Drop the trailing length-1 time axis left by the pooling layer.
        return pooled.squeeze(-1)

class TransformerBranch(nn.Module):
    """Embed each time step with a 1x1 convolution, run a Transformer encoder, mean-pool over time.

    Input is (batch, 1, length); output is (batch, d_model).
    NOTE(review): no positional encoding is added before the encoder, so
    sample order does not appear to be visible to the model - confirm this
    is intended.
    """

    def __init__(self, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        # Kernel size 1: every time step is projected independently to d_model.
        self.input_proj = nn.Conv1d(1, d_model, kernel_size=1)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 2,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x):
        # (B, d_model, L) -> (L, B, d_model): the sequence-first layout the
        # encoder layer uses by default.
        tokens = self.input_proj(x).permute(2, 0, 1)
        encoded = self.transformer(tokens)
        # Average over the sequence axis.
        return encoded.mean(dim=0)

class DualBranchEncoder(nn.Module):
    """Runs the CNN and Transformer branches on the same input and concatenates their features."""

    def __init__(self):
        super().__init__()
        self.cnn_branch = CNNBranch(out_dim=128)
        self.trans_branch = TransformerBranch(d_model=128)

    def forward(self, x_aug):
        # CNN branch first, then Transformer branch; each yields (B, 128).
        branch_feats = [branch(x_aug) for branch in (self.cnn_branch, self.trans_branch)]
        # Concatenate along the feature axis -> (B, 256).
        return torch.cat(branch_feats, dim=1)

# -------------------- 3. Projector Head --------------------
class ProjHead(nn.Module):
    """Three-layer MLP projection head whose output rows are L2-normalised.

    Two Linear -> BatchNorm1d -> ReLU stages are followed by a final Linear
    layer; input is (batch, input_dim), output is (batch, proj_dim).
    """

    def __init__(self, input_dim=256, proj_dim=256):
        super().__init__()
        layers = []
        in_features = input_dim
        for _ in range(2):
            layers.extend([
                nn.Linear(in_features, proj_dim),
                nn.BatchNorm1d(proj_dim),
                nn.ReLU(),
            ])
            in_features = proj_dim
        layers.append(nn.Linear(proj_dim, proj_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        embedding = self.net(x)
        # Unit-length rows along the feature axis.
        return F.normalize(embedding, dim=1)

# -------------------- 4. Full SSL Model --------------------
class DualEncoderSSL(nn.Module):
    """Self-supervised model: a shared encoder and projection head applied to two views."""

    def __init__(self):
        super().__init__()
        self.encoder = DualBranchEncoder()
        self.projector = ProjHead()

    def forward(self, v1, v2):
        # Encode both views first, then project both, each with shared weights.
        feats = [self.encoder(view) for view in (v1, v2)]
        z1, z2 = (self.projector(h) for h in feats)
        return z1, z2

# -------------------- 5. NT-Xent Loss --------------------
def nt_xent(z1, z2, temperature=0.5):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss.

    Args:
        z1: (B, D) embeddings of the first view.
        z2: (B, D) embeddings of the second view; row i pairs with row i of z1.
        temperature: softmax temperature dividing the cosine similarities.

    Returns:
        Scalar tensor: mean over all 2B anchors of the cross-entropy between
        the anchor's similarities to the other 2B - 1 samples and the index
        of its positive partner.

    The previous version put the positive similarity into the logits twice
    (once explicitly and once inside the "negatives", which only excluded the
    diagonal), so the positive was double-counted in the softmax denominator.
    Each anchor now sees its positive exactly once.
    """
    B = z1.size(0)
    # Normalise once so a plain matrix product yields cosine similarities,
    # avoiding the (2B, 2B, D) intermediate of the broadcast formulation.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
    sim = z @ z.t() / temperature

    # An anchor must never be compared with itself: -inf removes the entry
    # from the softmax.
    self_mask = torch.eye(2 * B, device=z.device, dtype=torch.bool)
    logits = sim.masked_fill(self_mask, float('-inf'))

    # The positive of sample i is i + B for the first view and i - B for the second.
    idx = torch.arange(B, device=z.device)
    targets = torch.cat([idx + B, idx])
    return F.cross_entropy(logits, targets)

# -------------------- 6. Dataset --------------------
class SSLDataset(Dataset):
    """Dataset yielding two independently augmented views of each segment.

    `segments` is indexed by integer position and `transform` is called twice
    per item; each view is returned as a float32 tensor with a leading
    channel axis of size 1.
    """

    def __init__(self, segments, transform):
        self.X = segments
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        views = []
        for _ in range(2):
            augmented = self.transform(sample)
            # Add the channel axis in front and cast to float32.
            views.append(torch.from_numpy(augmented).float().unsqueeze(0))
        return tuple(views)

class ClassifierDataset(Dataset):
    """Supervised dataset pairing each segment with its label.

    Items are `(signal, label)`: the signal as a float32 tensor with a
    leading channel axis of size 1, the label as an int64 scalar tensor.
    The dataset length is the number of labels.
    """

    def __init__(self, segments, labels):
        self.X = segments
        self.y = labels

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        signal = torch.from_numpy(self.X[idx]).float().unsqueeze(0)
        target = torch.tensor(self.y[idx]).long()
        return signal, target

# -------------------- 7. Fine-tune Classifier --------------------
class TCNBranch(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU on a (batch, channels, length) input.

    Padding is `(kernel_size - 1) // 2 * dilation`, which keeps the sequence
    length unchanged for odd kernel sizes.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=(kernel_size - 1) // 2 * dilation, dilation=dilation)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class InceptionTCNBlock(nn.Module):
    """Four parallel TCNBranch convolutions (kernel sizes 1, 3, 5, 7) concatenated on channels.

    Each branch produces `out_channels // 4` channels, so the block emits
    `4 * (out_channels // 4)` channels; this equals `out_channels` only when
    it is divisible by 4.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.branch1 = TCNBranch(in_channels, out_channels//4, kernel_size=1)
        self.branch2 = TCNBranch(in_channels, out_channels//4, kernel_size=3)
        self.branch3 = TCNBranch(in_channels, out_channels//4, kernel_size=5)
        self.branch4 = TCNBranch(in_channels, out_channels//4, kernel_size=7)

    def forward(self, x):
        out1 = self.branch1(x)
        out2 = self.branch2(x)
        out3 = self.branch3(x)
        out4 = self.branch4(x)
        return torch.cat([out1, out2, out3, out4], dim=1)


class FineTuneClassifier(nn.Module):
    """Classification head on top of a frozen, pretrained encoder.

    Args:
        encoder: module mapping an input batch to a (batch, feat_dim) feature
            matrix. It is only ever run without gradients and in eval mode.
        num_classes: number of output logits.

    The encoder's feature vector is treated as a one-channel sequence of
    length feat_dim: it goes through the multi-kernel inception block, one
    Transformer encoder layer, mean pooling over the sequence and a final
    linear layer, giving (batch, num_classes) logits.
    """

    def __init__(self, encoder, num_classes=2):
        super().__init__()
        self.encoder = encoder
        # Frozen feature extractor: keep BatchNorm statistics fixed and
        # dropout disabled from the start.
        self.encoder.eval()
        self.inception_tcn = InceptionTCNBlock(1, 256)
        encoder_layer = nn.TransformerEncoderLayer(d_model=256, nhead=4)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.fc = nn.Linear(256, num_classes)

    def train(self, mode=True):
        """Switch the head between train/eval mode while keeping the encoder in eval mode.

        `torch.no_grad()` in `forward` stops gradients only; in train mode the
        encoder's BatchNorm layers would still update their running
        statistics and its dropout would still be active.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x):
        with torch.no_grad():
            feat = self.encoder(x)           # (B, feat_dim)
        # The inception block is built with in_channels=1, so the channel
        # axis must be inserted at dim 1: (B, 1, feat_dim). Inserting it last
        # (as before) produced (B, feat_dim, 1) and a channel-mismatch error.
        feat = feat.unsqueeze(1)
        feat = self.inception_tcn(feat)      # (B, 256, feat_dim)
        feat = feat.permute(2, 0, 1)         # (feat_dim, B, 256), sequence first
        feat = self.transformer(feat)
        # Pool over the sequence axis; squeeze(0) only worked for length 1.
        feat = feat.mean(dim=0)              # (B, 256)
        out = self.fc(feat)
        return out

# -------------------- 8. Main Pipeline --------------------
if __name__ == "__main__":
    # Load data: the 'Label' column holds each row's class; every other
    # column is treated as one sample of that row's signal.
    df = pd.read_csv(PATH)
    signal_cols = [c for c in df.columns if c != 'Label']
    signals = df[signal_cols].values.astype(np.float32)
    labels = df['Label'].values.astype(np.int64)

    # Cut every row into overlapping windows. Each row gets its own group id
    # so that all windows of one row end up in the same cross-validation fold.
    window_size, stride = 500, 250
    segments, seg_labels, groups = [], [], []
    for subj_id, (sig, lab) in enumerate(zip(signals, labels), start=1):
        for start in range(0, len(sig)-window_size+1, stride):
            segments.append(sig[start:start+window_size])
            seg_labels.append(lab)
            groups.append(subj_id)
    if not segments:
        raise ValueError(
            f"No row in {PATH} has at least {window_size} samples; nothing to train on."
        )
    segments = np.stack(segments)
    seg_labels = np.array(seg_labels)
    groups = np.array(groups)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 64

    # Stage 1: self-supervised pretraining.
    # NOTE(review): pretraining uses the (unlabeled) windows of every row,
    # including rows that are later held out as test folds.
    ssl_dataset = SSLDataset(segments, PPGAugmentation())
    # Drop a trailing partial batch when at least one full batch exists: a
    # final batch of one sample makes BatchNorm1d in the projection head
    # raise in training mode.
    ssl_loader = DataLoader(ssl_dataset, batch_size=batch_size, shuffle=True,
                            drop_last=len(ssl_dataset) >= batch_size)

    model_ssl = DualEncoderSSL().to(device)
    optimizer_ssl = torch.optim.Adam(model_ssl.parameters(), lr=1e-3)

    for epoch in range(1, 21):
        model_ssl.train()
        total_loss = 0
        for v1, v2 in ssl_loader:
            v1, v2 = v1.to(device), v2.to(device)
            z1, z2 = model_ssl(v1, v2)
            loss = nt_xent(z1, z2)
            optimizer_ssl.zero_grad()
            loss.backward()
            optimizer_ssl.step()
            total_loss += loss.item()
        print(f"[SSL] Epoch {epoch:02d} — Loss: {total_loss/len(ssl_loader):.4f}")

    # Stage 2: fine-tuning on top of the frozen encoder.
    encoder = model_ssl.encoder
    for p in encoder.parameters():
        p.requires_grad = False

    # Group-wise 10-fold split: whole rows (subjects) are held out together.
    unique_subjects = np.unique(groups)
    np.random.shuffle(unique_subjects)
    folds = np.array_split(unique_subjects, 10)

    fold_accs = []
    for fold_idx, test_subjects in enumerate(folds, start=1):
        is_test = np.isin(groups, test_subjects)
        # With fewer than 10 subjects some folds are empty, and with a single
        # subject the training split is empty; neither can be evaluated.
        if not is_test.any() or is_test.all():
            print(f"Fold {fold_idx} skipped: empty train or test split")
            continue
        X_tr, y_tr = segments[~is_test], seg_labels[~is_test]
        X_te, y_te = segments[is_test], seg_labels[is_test]

        train_loader = DataLoader(ClassifierDataset(X_tr, y_tr), batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(ClassifierDataset(X_te, y_te), batch_size=batch_size, shuffle=False)

        model = FineTuneClassifier(encoder).to(device)
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable, lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(1, 11):
            model.train()
            # model.train() also flips the shared encoder into train mode;
            # put it back in eval so its BatchNorm running statistics are not
            # updated and its dropout stays off while it is frozen.
            encoder.eval()
            for xb, yb in train_loader:
                xb, yb = xb.to(device), yb.to(device)
                optimizer.zero_grad()
                loss = criterion(model(xb), yb)
                loss.backward()
                optimizer.step()

        model.eval()
        preds, truths = [], []
        with torch.no_grad():
            for xb, yb in test_loader:
                xb = xb.to(device)
                preds.extend(model(xb).argmax(dim=1).cpu().numpy())
                truths.extend(yb.numpy())

        acc = accuracy_score(truths, preds)
        fold_accs.append(acc)
        print(f"Fold {fold_idx} Accuracy: {acc:.4f}")

    if fold_accs:
        print(f"Mean Accuracy over {len(fold_accs)} folds: "
              f"{np.mean(fold_accs):.4f} ± {np.std(fold_accs):.4f}")




# Notebook output: KeyboardInterrupt