In [164]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset,DataLoader
from data import load_traindata
device = 'mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import random_split
from torch.optim.lr_scheduler import LambdaLR

mps


In [165]:
torch.manual_seed(42)

<torch._C.Generator at 0x10b203470>

In [166]:
num_subclasses = 100
epochs = 1000
seq_size = 1250
batch_size = 128
num_aug = 2
split_size=0.8

In [167]:
class Discriminator(nn.Module):
    def __init__(self, seq_size, num_filters=64):
        super(Discriminator, self).__init__()

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=num_filters, kernel_size=4, stride=2, padding=1)
        self.leakyRelu = nn.LeakyReLU(0.2)

        self.conv2 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters*2, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm1d(num_filters*2)

        self.conv3 = nn.Conv1d(in_channels=num_filters*2, out_channels=num_filters*4, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm1d(num_filters*4)

        self.conv4 = nn.Conv1d(in_channels=num_filters*4, out_channels=num_filters*8, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm1d(num_filters*8)

        self.conv5 = nn.Conv1d(in_channels=num_filters*8, out_channels=num_filters*8, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm1d(num_filters*8)

        self.conv6 = nn.Conv1d(in_channels=num_filters*8, out_channels=1, kernel_size=4, stride=2, padding=1)
        self.op = nn.Linear(19,1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.leakyRelu(x)

        x = self.conv2(x)
        x = self.bn1(x)
        x = self.leakyRelu(x)

        x = self.conv3(x)
        x = self.bn2(x)
        x = self.leakyRelu(x)

        x = self.conv4(x)
        x = self.bn3(x)
        x = self.leakyRelu(x)

        x = self.conv5(x)
        x = self.bn4(x)
        x = self.leakyRelu(x)

        x = self.conv6(x)
        x = self.leakyRelu(x)
        x = self.op(x)
        x = self.sigmoid(x)
        return x.squeeze(2)
# seq_size = 1250
# batch_size = 32
# model = Discriminator(seq_size)

# sample_input = torch.randn(batch_size, 1, seq_size)
# output = model(sample_input)

# print(output.shape)


In [168]:
class Generator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, num_filters = 32):
        super(Generator, self).__init__()
        self.enc1 = nn.Sequential(
            nn.Conv1d(in_channels, num_filters, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2)
        )
        self.enc2 = nn.Sequential(
            nn.Conv1d(num_filters, num_filters*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(num_filters*2),
            nn.LeakyReLU(0.2)
        )
        self.enc3 = nn.Sequential(
            nn.Conv1d(num_filters*2, num_filters*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(num_filters*4),
            nn.LeakyReLU(0.2)
        )
        self.enc4 = nn.Sequential(
            nn.Conv1d(num_filters*4, num_filters*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(num_filters*8),
            nn.LeakyReLU(0.2)
        )
        self.enc5 = nn.Sequential(
            nn.Conv1d(num_filters*8, num_filters*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(num_filters*8),
            nn.LeakyReLU(0.2)
        )
        self.enc6 = nn.Sequential(
            nn.Conv1d(num_filters*8, num_filters*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(num_filters*8),
            nn.LeakyReLU(0.2)
        )
        self.enc7 = nn.Sequential(
            nn.Conv1d(num_filters*8, num_filters*16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(num_filters*16),
            nn.LeakyReLU(0.2)
        )

        # Decoder 
        self.dec1 = nn.Sequential(
            nn.ConvTranspose1d(num_filters*16, num_filters*8, kernel_size=4, stride=2, padding=1, output_padding=1),
            nn.BatchNorm1d(num_filters*8),
            nn.Dropout1d(0.2),
            nn.ReLU(),

        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose1d(num_filters*8, num_filters*8, kernel_size=4, stride=2, padding=1, output_padding=1),  # Skip connection
            nn.BatchNorm1d(num_filters*8),
            nn.Dropout1d(0.2),
            nn.ReLU()
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose1d(num_filters*8, num_filters*8, kernel_size=4, stride=2, padding=1),  # Skip connection
            nn.BatchNorm1d(num_filters*8),
            nn.Dropout1d(0.2),
            nn.ReLU()
        )
        self.dec4 = nn.Sequential(
            nn.ConvTranspose1d(num_filters*8, num_filters*4, kernel_size=4, stride=2, padding=1),  # Skip connection
            nn.BatchNorm1d(num_filters*4),
            nn.ReLU()
        )
        self.dec5 = nn.Sequential(
            nn.ConvTranspose1d(num_filters*4, num_filters*2, kernel_size=4, stride=2, padding=1),  # Skip connection
            nn.BatchNorm1d(num_filters*2),
            nn.ReLU()
        )
        self.dec6 = nn.Sequential(
            nn.ConvTranspose1d(num_filters*2, num_filters, kernel_size=4, stride=2, padding=1, output_padding=1),  # Skip connection
            nn.BatchNorm1d(num_filters),
            nn.ReLU()
        )
        self.dec7 = nn.ConvTranspose1d(num_filters, out_channels, kernel_size=4, stride=2, padding=1)  # Skip connection, no BN
        self.final_activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        # Encoder
        x = x.unsqueeze(1)
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        e6 = self.enc6(e5)
        e7 = self.enc7(e6)
        #print(e7.shape, "END OF ENCODER")

        # Decoder with skip connections
        d1 = e6 + self.dec1(e7)
        d2 = e5 + self.dec2(d1)  
        d3 = e4 + self.dec3(d2)
        d4 = e3 + self.dec4(d3)
        d5 = e2 + self.dec5(d4)
        d6 = e1 + self.dec6(d5)
        d7 = self.dec7(d6)
        return self.final_activation(d7)

In [169]:
def time_shifting(ecg_data, w=0.08):
    if not (0 <= w <= 1):
        raise ValueError("w must be between 0 and 1.")
    T = ecg_data.shape[0]
    shift_length = int(w * T)
    # Perform a circular shift (rotation)
    ecg_data = np.roll(ecg_data, shift=shift_length, axis=0)
    return ecg_data

In [170]:
def time_masking(ecg_data, w=0.05):
    if not (0 <= w <= 1):
        raise ValueError("w must be between 0 and 1.")
    T = ecg_data.shape[0]
    mask_length = int(w * T)
    ts = np.random.randint(0, T - mask_length + 1)
    ecg_data[ts:ts + mask_length, :] = 0
    return ecg_data

In [171]:
def augment(X, num_masks=num_aug):
    augmented_X = []
    for i in range(len(X)):
        ecg_data = X[i].clone()  # To avoid modifying the original data
        
        # Append the original data
        augmented_X.append(ecg_data)
        # Apply rotation num_masks times and append to augmented dataset
        for _ in range(num_masks):
            augmented_data = time_shifting(ecg_data.numpy())  # Apply time masking to numpy array
            masked_data = time_masking(ecg_data.numpy())
            augmented_X.append(torch.from_numpy(augmented_data))  # Convert back to torch tensor
            augmented_X.append(torch.from_numpy(masked_data))  # Convert back to torch tensor
    
    # Convert augmented lists back to tensors
    augmented_X = torch.stack(augmented_X)
    
    return augmented_X

torch.Size([80, 1250, 12])


In [190]:
X, _ = load_traindata(num_subclasses)
X = torch.tensor(X, dtype=torch.float32)
X = X.reshape(int(num_subclasses), int(5000 / seq_size), seq_size, 12)
X = X.reshape(int(num_subclasses * int(5000 / seq_size)), seq_size, 12)
train_size = int(split_size*len(X))
test_size = len(X) - train_size
X,X_test = random_split(X,[train_size, test_size])
X = augment(X)
X_input = X[:, :, 0]  # First channel
Y_target = X[:, :, 1]  # Second channel


X_test_tensors = [X_test.dataset[idx] for idx in range(len(X_test))] 
X_test_tensor = torch.stack(X_test_tensors) 
X_t = X_test_tensor[:, :, 0]  
Y_t = X_test_tensor[:, :, 1]

(9514, 28)


In [94]:
dataset = TensorDataset(X_input, Y_target)
test_dataset = TensorDataset(X_t, Y_t)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
# Initialize models
generator = Generator().to(device)
discriminator = Discriminator(seq_size).to(device)

# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0005, betas=(0.5,0.9))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5,0.9))
print(sum(p.numel() for p in generator.parameters())/1e6, 'M parameters for Generator')
print(sum(p.numel() for p in discriminator.parameters())/1e6, 'M parameters for Discriminator')

2.448897 M parameters for Generator
1.743317 M parameters for Discriminator


In [95]:
def plotWave(X, Y):
    plt.figure(figsize=(20, 6))
    plt.plot(X.squeeze(0).detach().cpu().numpy(), color='blue', label='X')
    plt.plot(Y.squeeze(0).detach().cpu().numpy(), color='red', label='Y')
    plt.legend()
    plt.show()

In [96]:
g_losses, d_losses = [], []

In [97]:
for epoch in range(epochs):
    for real_1, real_2 in dataloader:
        real_1, real_2 = real_1.to(torch.float32),real_2.to(torch.float32)
        real_1, real_2 = real_1.to(device), real_2.to(device)
        batch_size = real_1.size(0)
        
        # Train Discriminator
        optimizer_d.zero_grad()
        real_labels = torch.full((batch_size, 1), 0.9).to(device)
        fake_labels = torch.full((batch_size, 1), 0.1).to(device)
        
        fake_2 = generator(real_1)
        fake_output = discriminator(fake_2.detach())
        fake_loss = criterion(fake_output, fake_labels)


        real_2  = real_2.unsqueeze(1)
        real_output = discriminator(real_2)
        real_loss = criterion(real_output, real_labels)
        
        d_loss = (real_loss + fake_loss)/2.0
        d_loss.backward()
        optimizer_d.step()
        
        # Train Generator
        for _ in range(2):
            optimizer_g.zero_grad()
            fake_output = discriminator(fake_2.detach())
            g_loss = criterion(fake_output, real_labels)
            g_loss.backward(retain_graph=True)
            optimizer_g.step()
        
    g_losses.append(g_loss.item())
    d_losses.append(d_loss.item())
    if (epoch+1)%10==0:
        print(epoch)
        for real_1, real_2 in dataloader:
            real_1, real_2 = real_1.to(torch.float32),real_2.to(torch.float32)
            real_1, real_2 = real_1.to(device), real_2.to(device)
            fake_2 = generator(real_1)
            plotWave(fake_2[0],real_2[0])
            plotWave(fake_2[12],real_2[12]) #random output
            break
    print(f"Epoch {epoch+1}/{epochs}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")


Epoch 1/100, D Loss: 0.3261021375656128, G Loss: 2.069563388824463
Epoch 2/100, D Loss: 0.325212299823761, G Loss: 2.122236728668213
Epoch 3/100, D Loss: 0.32515043020248413, G Loss: 2.0601320266723633


KeyboardInterrupt: 