In [1]:
import functools
import random

import numpy as np
import torch
from torch import nn
from wfdb import rdrecord

In [2]:
# ---- Hyper-parameters ----
# The DataLoader cell reassigns batch_size = 64 before building the loader; keep the two in sync.
batch_size = 64
d_learning_rate = 0.0001  # Adam learning rate for the discriminator
g_learning_rate = 0.0001  # Adam learning rate for the generator
g_lambda = 100  # weight of the L1 reconstruction term in the generator loss
num_epochs = 100
ngpu = 1  # ngpu > 0 selects cuda:0 when CUDA is available; 0 forces CPU
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# ---- Reproducibility ----
# Seed every RNG the notebook draws from: `random` (noise injection, soft labels),
# NumPy, and torch (weight init, latent z, DataLoader shuffling and worker seeds).
# GPU kernels may still be non-deterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
class Discriminator(nn.Module):
    """D: scores a batch of 1-D signals with one real value per signal.

    Eleven Conv1d layers (kernel 31, stride 2, padding 15) each map a length L to
    ceil(L / 2), with BatchNorm + LeakyReLU after each (dropout after conv3, conv6
    and conv9). A 1x1 convolution reduces the 2048 channels to one, and a linear
    layer maps the remaining length to a single score. `self.sigmoid` is defined
    but not applied, so the returned score is unbounded.
    """

    def __init__(self, dropout_drop=0.5, input_length=None):
        """
        Args:
            dropout_drop: dropout probability used after conv3, conv6 and conv9.
            input_length: length of the signals that will be fed in. When given,
                `fully_connected` is sized to the length left after the strided
                convolutions (ceil(input_length / 2**11)). When None, it keeps its
                historical 16 input features, which only fits inputs of
                30721-32768 samples (a 2000-sample input leaves length 1).

        Raises:
            ValueError: if input_length is given and is < 1.
        """
        super().__init__()
        if input_length is not None and input_length < 1:
            raise ValueError(f"input_length must be >= 1, got {input_length}")
        fc_in_features = 16 if input_length is None else self._conv_output_length(input_length)

        # Shape comments are [channels x length] for an input of length L.
        negative_slope = 0.03
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=31, stride=2, padding=15)  # 32 x ceil(L/2)
        self.vbn1 = nn.BatchNorm1d(32)
        self.lrelu1 = nn.LeakyReLU(negative_slope)
        self.conv2 = nn.Conv1d(32, 64, 31, 2, 15)  # 64 x ceil(L/4)
        self.vbn2 = nn.BatchNorm1d(64)
        self.lrelu2 = nn.LeakyReLU(negative_slope)
        self.conv3 = nn.Conv1d(64, 64, 31, 2, 15)  # 64 x ceil(L/8)
        self.dropout1 = nn.Dropout(dropout_drop)
        self.vbn3 = nn.BatchNorm1d(64)
        self.lrelu3 = nn.LeakyReLU(negative_slope)
        self.conv4 = nn.Conv1d(64, 128, 31, 2, 15)  # 128 x ceil(L/16)
        self.vbn4 = nn.BatchNorm1d(128)
        self.lrelu4 = nn.LeakyReLU(negative_slope)
        self.conv5 = nn.Conv1d(128, 128, 31, 2, 15)  # 128 x ceil(L/32)
        self.vbn5 = nn.BatchNorm1d(128)
        self.lrelu5 = nn.LeakyReLU(negative_slope)
        self.conv6 = nn.Conv1d(128, 256, 31, 2, 15)  # 256 x ceil(L/64)
        self.dropout2 = nn.Dropout(dropout_drop)
        self.vbn6 = nn.BatchNorm1d(256)
        self.lrelu6 = nn.LeakyReLU(negative_slope)
        self.conv7 = nn.Conv1d(256, 256, 31, 2, 15)  # 256 x ceil(L/128)
        self.vbn7 = nn.BatchNorm1d(256)
        self.lrelu7 = nn.LeakyReLU(negative_slope)
        self.conv8 = nn.Conv1d(256, 512, 31, 2, 15)  # 512 x ceil(L/256)
        self.vbn8 = nn.BatchNorm1d(512)
        self.lrelu8 = nn.LeakyReLU(negative_slope)
        self.conv9 = nn.Conv1d(512, 512, 31, 2, 15)  # 512 x ceil(L/512)
        self.dropout3 = nn.Dropout(dropout_drop)
        self.vbn9 = nn.BatchNorm1d(512)
        self.lrelu9 = nn.LeakyReLU(negative_slope)
        self.conv10 = nn.Conv1d(512, 1024, 31, 2, 15)  # 1024 x ceil(L/1024)
        self.vbn10 = nn.BatchNorm1d(1024)
        self.lrelu10 = nn.LeakyReLU(negative_slope)
        self.conv11 = nn.Conv1d(1024, 2048, 31, 2, 15)  # 2048 x ceil(L/2048)
        self.vbn11 = nn.BatchNorm1d(2048)
        self.lrelu11 = nn.LeakyReLU(negative_slope)
        # 1x1 kernel: collapse the 2048 channels to one, keeping the length
        self.conv_final = nn.Conv1d(2048, 1, kernel_size=1, stride=1)  # 1 x ceil(L/2048)
        self.lrelu_final = nn.LeakyReLU(negative_slope)
        self.fully_connected = nn.Linear(in_features=fc_in_features, out_features=1)  # 1
        self.sigmoid = nn.Sigmoid()  # kept for compatibility; not used in forward

        # initialize weights
        self.init_weights()

    @staticmethod
    def _conv_output_length(input_length, num_convs=11):
        """Length left after `num_convs` Conv1d(kernel=31, stride=2, padding=15) layers.

        Each layer maps L to floor((L + 2*15 - 31) / 2) + 1, i.e. ceil(L / 2).
        """
        length = input_length
        for _ in range(num_convs):
            length = (length + 2 * 15 - 31) // 2 + 1
        return length

    def init_weights(self):
        """
        Initialize weights for convolution layers using Xavier initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.xavier_normal_(m.weight.data)

    def forward(self, x):
        """
        Forward pass of discriminator.
        Args:
            x: batch of signals, shape (batch, length); a channel axis is added here.
        Returns:
            Tensor of shape (batch, 1): one unbounded score per signal.
        """
        x = x.unsqueeze(1)
        x = self.lrelu1(self.vbn1(self.conv1(x)))
        x = self.lrelu2(self.vbn2(self.conv2(x)))
        x = self.lrelu3(self.vbn3(self.dropout1(self.conv3(x))))
        x = self.lrelu4(self.vbn4(self.conv4(x)))
        x = self.lrelu5(self.vbn5(self.conv5(x)))
        x = self.lrelu6(self.vbn6(self.dropout2(self.conv6(x))))
        x = self.lrelu7(self.vbn7(self.conv7(x)))
        x = self.lrelu8(self.vbn8(self.conv8(x)))
        x = self.lrelu9(self.vbn9(self.dropout3(self.conv9(x))))
        x = self.lrelu10(self.vbn10(self.conv10(x)))
        x = self.lrelu11(self.vbn11(self.conv11(x)))
        x = self.lrelu_final(self.conv_final(x))
        # flatten(1) rather than squeeze(): squeeze() would also drop the batch axis
        # when batch == 1, and the length axis when it is 1. Either way the wrong
        # axis would reach fully_connected.
        x = x.flatten(1)
        return self.fully_connected(x)



In [4]:
class Generator(nn.Module):
    """G: 1-D convolutional encoder-decoder with skip connections.

    Nine stride-2 convolutions encode the input signal. A standard-normal tensor
    with the bottleneck's shape is concatenated to it. Nine transposed
    convolutions then decode back to the input length; after each one the
    matching encoder output (pre-activation) is concatenated before the PReLU. A
    final tanh bounds the output to (-1, 1).

    Shape comments are [B x channels x length] for a 2000-sample input. The mixed
    kernel sizes (31/32/33) make decoder lengths line up with encoder lengths for
    that input; other input lengths may fail at a torch.cat.
    """

    def __init__(self):
        super().__init__()
        # encoder: takes the noisy signal
        self.enc1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=32, stride=2, padding=15)  # [B x 16 x 1000]
        self.enc1_nl = nn.PReLU()  # non-linear transformation after encoder layer 1
        self.enc2 = nn.Conv1d(16, 32, 32, 2, 15)  # [B x 32 x 500]
        self.enc2_nl = nn.PReLU()
        self.enc3 = nn.Conv1d(32, 32, 32, 2, 15)  # [B x 32 x 250]
        self.enc3_nl = nn.PReLU()
        self.enc4 = nn.Conv1d(32, 64, 32, 2, 15)  # [B x 64 x 125]
        self.enc4_nl = nn.PReLU()
        self.enc5 = nn.Conv1d(64, 64, 32, 2, 15)  # [B x 64 x 62]
        self.enc5_nl = nn.PReLU()
        self.enc6 = nn.Conv1d(64, 128, 31, 2, 15)  # [B x 128 x 31]
        self.enc6_nl = nn.PReLU()
        self.enc7 = nn.Conv1d(128, 128, 31, 2, 15)  # [B x 128 x 16]
        self.enc7_nl = nn.PReLU()
        self.enc8 = nn.Conv1d(128, 256, 32, 2, 15)  # [B x 256 x 8]
        self.enc8_nl = nn.PReLU()
        self.enc9 = nn.Conv1d(256, 256, 32, 2, 15)  # [B x 256 x 4]
        self.enc9_nl = nn.PReLU()

        # decoder: each output is concatenated with the homologous encoder output,
        # so the next layer's input channel count is doubled
        self.dec8 = nn.ConvTranspose1d(512, 256, 32, 2, 15)  # [B x 256 x 8]   (input: bottleneck + z)
        self.dec8_nl = nn.PReLU()
        self.dec7 = nn.ConvTranspose1d(512, 128, 32, 2, 15)  # [B x 128 x 16]
        self.dec7_nl = nn.PReLU()
        self.dec6 = nn.ConvTranspose1d(256, 128, 31, 2, 15)  # [B x 128 x 31]
        self.dec6_nl = nn.PReLU()
        self.dec5 = nn.ConvTranspose1d(256, 64, 32, 2, 15)  # [B x 64 x 62]
        self.dec5_nl = nn.PReLU()
        self.dec4 = nn.ConvTranspose1d(128, 64, 33, 2, 15)  # [B x 64 x 125]
        self.dec4_nl = nn.PReLU()
        self.dec3 = nn.ConvTranspose1d(128, 32, 32, 2, 15)  # [B x 32 x 250]
        self.dec3_nl = nn.PReLU()
        self.dec2 = nn.ConvTranspose1d(64, 32, 32, 2, 15)  # [B x 32 x 500]
        self.dec2_nl = nn.PReLU()
        self.dec1 = nn.ConvTranspose1d(64, 16, 32, 2, 15)  # [B x 16 x 1000]
        self.dec1_nl = nn.PReLU()
        self.dec_final = nn.ConvTranspose1d(32, 1, 32, 2, 15)  # [B x 1 x 2000]
        self.dec_tanh = nn.Tanh()

        # initialize weights
        self.init_weights()

    def init_weights(self):
        """
        Initialize weights for convolution layers using Xavier initialization.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                nn.init.xavier_normal_(m.weight.data)

    def forward(self, x):
        """
        Forward pass of generator.
        Args:
            x: batch of input signals, shape (batch, length).
        Returns:
            Tensor of shape (batch, length) with values in (-1, 1).

        A fresh standard-normal latent tensor z is drawn on every call.
        """
        x = x.unsqueeze(1)
        ### encoding step (e* are pre-activation outputs, reused as skip connections)
        e1 = self.enc1(x)
        e2 = self.enc2(self.enc1_nl(e1))
        e3 = self.enc3(self.enc2_nl(e2))
        e4 = self.enc4(self.enc3_nl(e3))
        e5 = self.enc5(self.enc4_nl(e4))
        e6 = self.enc6(self.enc5_nl(e5))
        e7 = self.enc7(self.enc6_nl(e6))
        e8 = self.enc8(self.enc7_nl(e7))
        e9 = self.enc9(self.enc8_nl(e8))

        # c = compressed feature, the 'thought vector'
        c = self.enc9_nl(e9)
        # randn_like places z on c's device and dtype. Using the global `device`
        # would put every DataParallel replica's z on one GPU and break the cat.
        z = torch.randn_like(c)
        encoded = torch.cat((c, z), dim=1)

        ### decoding step
        d8 = self.dec8(encoded)
        d8_c = self.dec8_nl(torch.cat((d8, e8), dim=1))
        d7 = self.dec7(d8_c)
        d7_c = self.dec7_nl(torch.cat((d7, e7), dim=1))
        d6 = self.dec6(d7_c)
        d6_c = self.dec6_nl(torch.cat((d6, e6), dim=1))
        d5 = self.dec5(d6_c)
        d5_c = self.dec5_nl(torch.cat((d5, e5), dim=1))
        d4 = self.dec4(d5_c)
        d4_c = self.dec4_nl(torch.cat((d4, e4), dim=1))
        d3 = self.dec3(d4_c)
        d3_c = self.dec3_nl(torch.cat((d3, e3), dim=1))
        d2 = self.dec2(d3_c)
        d2_c = self.dec2_nl(torch.cat((d2, e2), dim=1))
        d1 = self.dec1(d2_c)
        d1_c = self.dec1_nl(torch.cat((d1, e1), dim=1))
        out = self.dec_tanh(self.dec_final(d1_c))
        # (B, 1, L) -> (B, L)
        return out.squeeze(1)

In [5]:
import torch.optim as optim
# nn.DataParallel replicates onto *every* visible GPU when device_ids is omitted,
# regardless of `ngpu`. Pin it to the first `ngpu` GPUs instead. On a CPU-only
# machine DataParallel just wraps the module, so None is fine there.
# NOTE(review): with ngpu = 0 on a CUDA machine, DataParallel still expects the
# module on cuda:0 (unchanged from before).
dp_device_ids = (
    list(range(min(ngpu, torch.cuda.device_count()))) if device.type == "cuda" else None
)

discriminator = torch.nn.DataParallel(Discriminator().to(device), device_ids=dp_device_ids)  # use GPU
print(discriminator)
print('Discriminator created')

generator = torch.nn.DataParallel(Generator().to(device), device_ids=dp_device_ids)
print(generator)
print('Generator created')

# Adam with beta1 = 0.5, as is common for GAN training
g_optimizer = optim.Adam(generator.parameters(), lr=g_learning_rate, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=d_learning_rate, betas=(0.5, 0.999))

DataParallel(
  (module): Discriminator(
    (conv1): Conv1d(1, 32, kernel_size=(31,), stride=(2,), padding=(15,))
    (vbn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (lrelu1): LeakyReLU(negative_slope=0.03)
    (conv2): Conv1d(32, 64, kernel_size=(31,), stride=(2,), padding=(15,))
    (vbn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (lrelu2): LeakyReLU(negative_slope=0.03)
    (conv3): Conv1d(64, 64, kernel_size=(31,), stride=(2,), padding=(15,))
    (dropout1): Dropout(p=0.5)
    (vbn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (lrelu3): LeakyReLU(negative_slope=0.03)
    (conv4): Conv1d(64, 128, kernel_size=(31,), stride=(2,), padding=(15,))
    (vbn4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (lrelu4): LeakyReLU(negative_slope=0.03)
    (conv5): Conv1d(128, 128, kernel_size=(31,), stride=(2,), padding=(15,))
  

In [6]:
import random
@functools.lru_cache(maxsize=None)
def _load_noise_record(path):
    """Read a WFDB record with rdrecord once per process per path and reuse it."""
    # Each DataLoader worker process keeps its own cache.
    return rdrecord(path)


def choose_noise(signal, noise_dir="/aii/sophiap/ECG_detection/Noise_Samples/", load_record=None):
    """Return a copy of `signal` with a random stretch of recorded noise added.

    Draws (with Python's `random`) the following:
    - a noise record ('bw', 'em' or 'ma');
    - one of its first two channels;
    - a window [noise_start, noise_end) of `signal`, possibly empty;
    - an offset into the record;
    - a scale alpha ~ U(0.5, 1).

    The matching record slice, times alpha, is then added to that window.

    Args:
        signal: 1-D array-like segment. It is NOT modified.
        noise_dir: path prefix (including trailing separator) of the noise records.
        load_record: callable mapping a record path to an object with a 2-D
            `p_signal` array; defaults to a cached `rdrecord`.

    Returns:
        A new array with the same dtype and shape as `signal`.

    Raises:
        ValueError: if the noise record has fewer samples than `signal`.
    """
    # Copy first: callers pass a view into the dataset array, and adding noise in
    # place would permanently corrupt the stored data.
    noisy = np.array(signal, copy=True)
    length = noisy.shape[0]
    if length == 0:
        return noisy
    if load_record is None:
        load_record = _load_noise_record

    noise_type = random.choice(['bw', 'em', 'ma'])
    noise_channel = random.choice([0, 1])
    noise_start = random.randrange(length)
    # length + 1 so the window can reach the final sample
    noise_end = random.randrange(noise_start, length + 1)

    record = load_record(noise_dir + noise_type)
    max_offset = record.p_signal.shape[0] - length
    if max_offset < 0:
        raise ValueError(
            f"noise record {noise_type!r} has {record.p_signal.shape[0]} samples, "
            f"fewer than the signal length {length}"
        )
    noise_from = random.randrange(max_offset + 1)
    alpha = random.uniform(0.5, 1)

    noisy[noise_start:noise_end] += alpha * record.p_signal[
        noise_from + noise_start: noise_from + noise_end, noise_channel
    ]
    return noisy

class ecg_data():
    """Map-style dataset pairing each clean row with a freshly noised copy of the source row.

    Item i is `(clean_row, noisy_row)`, with the clean row FIRST. The noisy row is
    `choose_noise(original[i, :])`, so it is re-drawn every time the item is read.
    """

    def __init__(self, original, clean, transform=None):
        # `original` rows receive the noise; `clean` rows are returned as-is.
        self.original = original
        self.clean = clean
        self.transform = transform

    def __len__(self):
        # Dataset size follows the clean array's first axis.
        return self.clean.shape[0]

    def __getitem__(self, idx):
        noisy_row = choose_noise(self.original[idx, :])
        clean_row = self.clean[idx, :]
        if not self.transform:
            return clean_row, noisy_row
        # Transform the noisy row first, then the clean one.
        noisy_out = self.transform(noisy_row)
        return self.transform(clean_row), noisy_out

In [7]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
# NOTE(review): `transform` is not passed to ecg_data below. transforms.ToTensor
# expects 2-D/3-D image arrays and would likely reject these 1-D rows; confirm
# before wiring it in.
transform = transforms.Compose([transforms.ToTensor()])
original = np.load("original_data.npy")
clean = np.load("clean_data.npy")
# ecg_data pairs row i of `original` (noised on the fly) with row i of `clean`.
# The training loop feeds one to the generator and compares its output
# element-wise with the other, so the arrays must have identical shapes.
assert original.shape == clean.shape, f"original {original.shape} vs clean {clean.shape}"
# create the dataset
train_data = ecg_data(original, clean)
# dataloaders
batch_size = 64
# Shuffle every epoch. drop_last avoids a tiny trailing batch: a batch of 1 can
# make the discriminator's BatchNorm layers raise in training mode (when the
# feature length is also 1) and gives noisy batch statistics otherwise.
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, num_workers=4, drop_last=True)

In [None]:
print('Starting Training...')
plot_gen = []   # running mean of the generator loss, appended after every batch
plot_disc = []  # running mean of the discriminator loss, appended after every batch
for epoch in range(num_epochs):
    temp_gen = 0.0
    temp_disc = 0.0
    n = 0
    # ecg_data yields (clean, noisy): the clean row comes FIRST. G maps noisy -> clean.
    for clean_batch, noisy_batch in train_loader:
        # Soft targets, redrawn every batch: D should score clean signals low
        # (in [0, 0.5]) and noisy / generated signals high (in [0.95, 1]).
        clean_target = random.uniform(0, 0.5)
        noisy_target = random.uniform(0.95, 1)
        clean_batch = clean_batch.to(device, dtype=torch.float32)
        noisy_batch = noisy_batch.to(device, dtype=torch.float32)

        ##### TRAIN D #####
        generated_outputs = generator(noisy_batch)
        # D should recognize real clean signals as clean (least-squares loss)
        clean_loss = torch.mean((discriminator(clean_batch) - clean_target) ** 2)
        # ... and both the noisy inputs and G's outputs as noisy. detach() keeps
        # this step from backpropagating into G.
        noisy_loss1 = torch.mean((discriminator(noisy_batch) - noisy_target) ** 2)
        noisy_loss2 = torch.mean((discriminator(generated_outputs.detach()) - noisy_target) ** 2)
        d_loss = (clean_loss + noisy_loss1 + noisy_loss2) / 3

        # back-propagate and update
        discriminator.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        ##### TRAIN G #####
        # Re-score the same generated batch with the updated D; G wants it scored as clean.
        outputs = discriminator(generated_outputs)
        g_loss_ = 0.5 * torch.mean((outputs - clean_target) ** 2)
        # conditional L1 term: G's output should match the clean signal
        g_cond_loss = g_lambda * torch.mean(torch.abs(generated_outputs - clean_batch))
        g_loss = g_loss_ + g_cond_loss

        # back-propagate and update
        generator.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        temp_gen += g_loss.item()
        temp_disc += d_loss.item()
        n += 1
        plot_gen.append(temp_gen / n)
        plot_disc.append(temp_disc / n)
    # temp_gen / temp_disc are float sums, not lists: report the epoch means.
    print("GLOSS=" + str(temp_gen / max(n, 1)) + ", DLOSS=" + str(temp_disc / max(n, 1)))

Starting Training...
