In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import wfdb
import seaborn as sns
import ast

In [None]:
def load_raw_data(df, sampling_rate, path):
    """Read the waveform of every record listed in ``df`` into one array.

    Args:
        df: Table with ``filename_lr`` and ``filename_hr`` columns holding
            record names relative to ``path``.
        sampling_rate: ``100`` selects the ``filename_lr`` column; any other
            value selects ``filename_hr``.
        path: Prefix that is string-concatenated in front of each record name.

    Returns:
        ``np.ndarray`` built from the signal part of each ``wfdb.rdsamp``
        result, in the order the records appear in ``df``.
    """
    record_names = df.filename_lr if sampling_rate == 100 else df.filename_hr

    signals = []
    for record_name in record_names:
        # rdsamp yields a (signal, metadata) pair; only the signal is kept.
        signal, _metadata = wfdb.rdsamp(path + record_name)
        signals.append(signal)

    return np.array(signals)

# Dataset root and the rate selector handed to load_raw_data
# (100 makes it read the filename_lr records).
path = '/home/naman21266/ptbxl_dataset/'
sampling_rate = 100

# Read the annotation table and turn each scp_codes string back into the
# Python literal it encodes.
Y = pd.read_csv(path + 'ptbxl_database.csv', index_col='ecg_id')
Y.scp_codes = Y.scp_codes.apply(ast.literal_eval)

# Read the waveform for every row of the annotation table.
X = load_raw_data(Y, sampling_rate, path)

# Read scp_statements.csv and keep only the rows flagged diagnostic == 1;
# aggregate_diagnostic looks codes up in this table.
agg_df = pd.read_csv(path + 'scp_statements.csv', index_col=0)
agg_df = agg_df.loc[agg_df['diagnostic'] == 1]

def aggregate_diagnostic(y_dic):
    """Collect the distinct diagnostic classes of the codes in ``y_dic``.

    Args:
        y_dic: Mapping whose keys are statement codes.

    Returns:
        List of the unique ``diagnostic_class`` values of those keys that are
        present in the index of the module-level ``agg_df``; keys that are not
        in the index are ignored. The list order is that of a ``set``.
    """
    classes = {
        agg_df.loc[code].diagnostic_class
        for code in y_dic.keys()
        if code in agg_df.index
    }
    return list(classes)

# Attach the aggregated diagnostic classes of every record as a new column.
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

# Records whose strat_fold equals this value are held out for testing.
test_fold = 10

# Train: every record outside the held-out fold.
X_train = X[(Y.strat_fold != test_fold).values]
y_train = Y.loc[Y.strat_fold != test_fold, 'diagnostic_superclass']

# Test: the held-out fold only.
X_test = X[(Y.strat_fold == test_fold).values]
y_test = Y.loc[Y.strat_fold == test_fold, 'diagnostic_superclass']

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalar positions / timesteps.

    Every input value is multiplied by ``dim // 2`` geometrically spaced
    frequencies (from 1 down to 1/10000) and the sines and cosines of those
    products are concatenated along the last axis.

    Args:
        dim: Width of the returned embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` of shape ``(batch,)`` into a ``(batch, dim)`` tensor."""
        device = x.device
        half_dim = self.dim // 2
        # For dim 2 or 3 there is a single frequency and ``half_dim - 1`` is
        # zero; clamp the denominator so those widths no longer raise
        # ZeroDivisionError (the lone frequency is then exp(0) == 1).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        emb = x[:, None] * freqs[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        if self.dim % 2 == 1:
            # The sin/cos halves only cover ``dim - 1`` features; append one
            # zero column so the output width always equals ``dim``.
            emb = torch.cat((emb, emb.new_zeros(emb.shape[0], 1)), dim=-1)
        return emb

class Simple1DUnet(nn.Module):
    """Small time-conditioned 1-D U-Net.

    The encoder has one stage per entry of ``dim_mults``; each stage applies
    two convolutions and then halves the length with a strided convolution.
    The decoder mirrors every encoder stage: it doubles the length with a
    transposed convolution, concatenates the matching encoder feature map and
    applies two convolutions. A 1x1 convolution maps back to ``channels``.

    The time embedding has ``dim`` features, so every stage owns a linear
    layer that projects it to that stage's channel count before it is added
    to the feature maps.

    Args:
        dim: Base feature width; also the width of the time embedding.
        dim_mults: Per-stage multipliers of ``dim``.
        channels: Number of input and output channels.
    """

    def __init__(self, dim, dim_mults=(1, 2, 4, 8), channels=12):
        super().__init__()
        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        for dim_in, dim_out in in_out:
            self.downs.append(nn.ModuleList([
                nn.Linear(dim, dim_out),  # time embedding -> stage width
                nn.Conv1d(dim_in, dim_out, 3, padding=1),
                nn.GELU(),
                nn.Conv1d(dim_out, dim_out, 3, padding=1),
                nn.GELU(),
                nn.Conv1d(dim_out, dim_out, 3, padding=1, stride=2)
            ]))

        # One decoder stage per encoder stage, deepest first, so the output
        # regains the input length and every stored skip is consumed.
        for level, (dim_in, dim_out) in reversed(list(enumerate(in_out))):
            # The shallowest encoder stage starts from the raw ``channels``;
            # its decoder twin keeps ``dim_out`` features for ``final_conv``.
            up_out = dim_out if level == 0 else dim_in
            self.ups.append(nn.ModuleList([
                nn.ConvTranspose1d(dim_out, dim_out, 4, stride=2, padding=1),
                nn.Linear(dim, up_out),  # time embedding -> stage width
                nn.Conv1d(dim_out * 2, up_out, 3, padding=1),
                nn.GELU(),
                nn.Conv1d(up_out, up_out, 3, padding=1),
                nn.GELU()
            ]))

        # Equals ``dim`` whenever ``dim_mults[0] == 1`` (the default).
        final_in = dims[1] if len(dims) > 1 else dim
        self.final_conv = nn.Conv1d(final_in, channels, 1)

    def forward(self, x, time):
        """Run the network.

        Args:
            x: Input of shape ``(batch, channels, length)``.
            time: One timestep per batch element, shape ``(batch,)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        t = self.time_mlp(time)

        h = []
        for time_proj, conv1, gelu1, conv2, gelu2, downsample in self.downs:
            # (batch, C) -> (batch, C, 1) so it broadcasts along the length.
            t_emb = time_proj(t)[:, :, None]
            x = gelu1(conv1(x) + t_emb)
            x = gelu2(conv2(x) + t_emb)
            h.append(x)
            x = downsample(x)

        for upsample, time_proj, conv1, gelu1, conv2, gelu2 in self.ups:
            skip = h.pop()
            # The strided conv maps length L to ceil(L / 2) and the transposed
            # conv doubles it, so for odd L the result is one sample too long;
            # crop to the skip's length before concatenating.
            x = upsample(x)[..., :skip.shape[-1]]
            x = torch.cat((x, skip), dim=1)
            t_emb = time_proj(t)[:, :, None]
            x = gelu1(conv1(x) + t_emb)
            x = gelu2(conv2(x) + t_emb)

        return self.final_conv(x)

In [None]:
class GaussianDiffusion1D(nn.Module):
    """Forward (noising) process of a diffusion model around a noise predictor.

    A linear beta schedule from 1e-4 to 0.02 over ``timesteps`` steps defines
    ``alphas_cumprod`` (the running product of ``1 - beta``). ``q_sample``
    mixes a clean sample with noise according to that schedule, and
    ``forward`` feeds the noisy sample to ``model``.

    Args:
        model: Callable module invoked as ``model(x_noisy, t)``.
        timesteps: Number of diffusion steps in the schedule.
    """

    def __init__(self, model, timesteps=1000):
        super().__init__()
        self.model = model
        self.timesteps = timesteps

        betas = torch.linspace(0.0001, 0.02, timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer('alphas_cumprod', alphas_cumprod)
        # No longer used by q_sample; kept so existing attribute access and
        # state_dict keys remain valid.
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod))
        # Coefficients of the forward process. Non-persistent so the set of
        # state_dict keys is unchanged.
        self.register_buffer('sqrt_alphas_cumprod',
                             torch.sqrt(alphas_cumprod), persistent=False)
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1.0 - alphas_cumprod), persistent=False)

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to step ``t``.

        Computes ``sqrt(a_t) * x_start + sqrt(1 - a_t) * noise`` with
        ``a_t = alphas_cumprod[t]``, so the squared coefficients sum to one.
        The previous code used ``sqrt(1 / a_t)`` and ``1 - sqrt(1 / a_t)``,
        which amplifies the signal and gives the noise a negative weight.

        Args:
            x_start: Clean samples; the coefficients are broadcast over two
                trailing axes, i.e. ``(batch, channels, length)``.
            t: Integer step index per batch element, shape ``(batch,)``.
            noise: Noise to mix in; drawn from a standard normal when ``None``.

        Returns:
            Noisy samples with the shape of ``x_start``.
        """
        noise = torch.randn_like(x_start) if noise is None else noise
        signal_scale = self.sqrt_alphas_cumprod[t][:, None, None]
        noise_scale = self.sqrt_one_minus_alphas_cumprod[t][:, None, None]
        return signal_scale * x_start + noise_scale * noise

    def forward(self, x_start, t, noise=None):
        """Noise ``x_start`` and return ``(model(x_noisy, t), x_noisy)``."""
        x_noisy = self.q_sample(x_start, t, noise)
        predicted_noise = self.model(x_noisy, t)
        return predicted_noise, x_noisy

def p_losses(denoise_model, x_start, t, noise=None):
    """Mean-squared error between the injected noise and the model's prediction.

    Args:
        denoise_model: Callable invoked as ``denoise_model(x_start, t, noise)``
            that returns a ``(predicted_noise, x_noisy)`` pair.
        x_start: Clean samples.
        t: Timestep per sample.
        noise: Noise to inject; drawn from a standard normal shaped like
            ``x_start`` when ``None``.

    Returns:
        Scalar tensor: ``F.mse_loss(predicted_noise, noise)``.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    # Only the prediction matters for the loss; the noisy sample is discarded.
    predicted_noise, _x_noisy = denoise_model(x_start, t, noise)
    return F.mse_loss(predicted_noise, noise)

def train_diffusion_model(model, data_loader, epochs=50):
    """Train ``model`` with Adam (lr 1e-4) on the noise-prediction loss.

    For every batch a random timestep in ``[0, model.timesteps)`` is drawn per
    sample and ``p_losses`` is minimised. After each epoch the loss of the
    last batch is printed.

    Args:
        model: Module exposing ``timesteps`` and accepted by ``p_losses``.
        data_loader: Iterable of batches. A batch is either a tensor or a
            list/tuple whose first element is the tensor (as produced by a
            ``DataLoader`` over a ``TensorDataset``).
        epochs: Number of passes over ``data_loader``.

    Raises:
        ValueError: If ``data_loader`` yields no batches during an epoch.
    """
    # ``optim`` was never imported in this notebook; reach Adam through the
    # ``torch`` package, which is imported.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()

    for epoch in range(epochs):
        loss = None
        for x_batch in data_loader:
            if isinstance(x_batch, (list, tuple)):
                x_batch = x_batch[0]
            t = torch.randint(0, model.timesteps, (x_batch.shape[0],)).long().to(x_batch.device)
            loss = p_losses(model, x_batch, t)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if loss is None:
            # Previously this surfaced as an UnboundLocalError at the print.
            raise ValueError("data_loader yielded no batches")
        print(f"Epoch {epoch + 1}: Loss: {loss.item()}")