In [1]:
#  key concept - sygnal normalization
#         curr_lead_data = data_dict[patient][lead_n, :]/np.abs(data_dict[patient][lead_n, :]).max()

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence, PackedSequence
# import torch.autograd as autograd
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import random
import tqdm
import os
import pickle
# mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

In [3]:
unn_data_dict = pickle.load(open("./unn_data.pickle",'rb'))
icbeb_data_dict = pickle.load(open("./icbeb_data.pickle",'rb'))

In [4]:
if torch.cuda.device_count() > 0:
    torch.cuda.manual_seed_all(123)
torch.manual_seed(123)
pass

In [5]:
class ECG_Dataset(Dataset):
    def __init__(self, patients_dict: dict, labels_filepath: str, transform_strategy='cut'):
        self.transform_strategy = transform_strategy
        self.dataset = self.get_pairset(patients_dict, labels_filepath)
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        if self.transform_strategy:
            sample = self.transform(sample)
        return sample
    
    def get_pairset(self, patients_dict: dict, labels_filepath: str):
        labels_df = pd.read_csv(labels_filepath, header=0).set_index('patient')
        return [(torch.Tensor(patients_dict[f"{patient}"].T),
                 torch.Tensor(labels_df.loc[patient].values.astype(bool).astype(int).astype(float))) 
                for patient in labels_df.index]
    def transform(self, sample):
        # TODO 1st make cut transform_strategy
        return sample
    
    def parse_labels_file(self, labels_filepath):
        return pd.read_csv(labels_filepath)

In [6]:
unn_dataset = ECG_Dataset(unn_data_dict, "./unn_labels.csv")
icbeb_dataset = ECG_Dataset(icbeb_data_dict, "./icbeb_labels.csv")

In [7]:
class Generator(nn.Module):
    def __init__(self, labels_dim, hidden_dim, latent_dim, 
                 device='cpu',
                 decoder_output_dim=12, 
                 num_layers=2, 
                 dropout=0.2, 
                 batch_first=True,
                 bidirectional=False,
                 batch_size=1,
                 unique_diagnosis_labels=[torch.from_numpy(np.array(_i, dtype=np.float32)) 
                           for _i in set([tuple(it[1].numpy().tolist()) for it in icbeb_dataset + unn_dataset])],
                 sample_len_collection=[5000, 7500, 10000, 12500, 15000]
                ):
        """
        Create a Generator object for time-series generation
        """
        super(Generator, self).__init__()
        self.labels_dim = labels_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.device = device
        self.decoder_output_dim = decoder_output_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_first = batch_first
        self.bidirectional= bidirectional
        self.batch_size = batch_size
        self.unique_diagnosis_labels = unique_diagnosis_labels
        self.sample_len_collection = sample_len_collection
        self.lstm_output_mult = 1 + int(self.bidirectional)
        # Define our lstm layer
        self.lstm = nn.LSTM(input_size=self.labels_dim + self.latent_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=self.batch_first,
                            dropout=self.dropout,
                            bidirectional=self.bidirectional,
                           ).to(self.device)
        # Define our decoder layer
        self.decoder = nn.Linear(self.hidden_dim * self.lstm_output_mult, self.decoder_output_dim).to(self.device)
        
    def init_hidden(self):
        return tuple([torch.zeros((self.num_layers, self.batch_size, self.hidden_dim), 
                                  dtype=torch.float32).to(self.device)
                      for _ in range(self.num_layers)])
    
    def forward(self, _input, _hidden):
        """
            if _category == None then _input already combined
        """
        input_combined = _input.to(self.device)
        
        raw_output, _hidden = self.lstm(input_combined, _hidden)
        if isinstance(raw_output, PackedSequence):
            raw_output, lengths = pad_packed_sequence(raw_output, batch_first=self.batch_first)
        decoded_output = self.decoder(raw_output)
        return decoded_output, _hidden
    
    def get_nllt_list(self, size):
        return [(torch.rand((1, self.latent_dim), dtype=torch.float32).expand(_len, -1).to(self.device), 
                 random.choice(self.unique_diagnosis_labels).unsqueeze(0).float().expand(_len, -1).to(self.device),
                 _len)
                for _len in random.choices(self.sample_len_collection, k=size)]
    
    def generate_seq_batch(self):
        noise_label_len_tuple_list = self.get_nllt_list(self.batch_size)
        sorted_nllt_list = sorted(noise_label_len_tuple_list, key=lambda x: x[2], reverse=True)

        noises = [x[0] for x in sorted_nllt_list]
        labels = [x[1] for x in sorted_nllt_list]
        lenghts = [x[2] for x in sorted_nllt_list]
        
        noises_padded = pad_sequence(noises, batch_first=self.batch_first)
        labels_padded = pad_sequence(labels, batch_first=self.batch_first)
#         print(f"{noises_padded.shape}, {labels_padded.shape}")
        labels_noises_padded = torch.cat([labels_padded, noises_padded], 2)
        
        labels_noises_packed = pack_padded_sequence(labels_noises_padded, lenghts, batch_first=self.batch_first)
        
        generated_seq, _ = self.forward(labels_noises_packed, self.init_hidden())
        
        return generated_seq, lenghts, labels_padded
    

In [8]:
# save_ecg_example(seq_batch[0].detach().numpy(), 'test_test')

# Discriminator

In [9]:
class Discriminator(nn.Module):
    def __init__(self, 
                 input_dim,
                 labels_dim, 
                 hidden_dim,
                 encoder_dim,
                 decoder_dim,
                 batch_size, 
                 device='cpu',
                 num_layers=2, 
                 dropout=0.2, 
                 batch_first=True,
                 bidirectional=False):
        """
        Create a Discriminator object for time-series clasification -- real or fake
        """
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self.labels_dim = labels_dim
        self.hidden_dim = hidden_dim
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.batch_size = batch_size
        self.device = device
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional= bidirectional
        self.batch_first = batch_first
        self.lstm_output_mult = 1 + int(self.bidirectional)
        
        self.encoder = nn.Sequential(
                                    nn.Linear(self.labels_dim + self.input_dim, self.encoder_dim),
                                    nn.ReLU()
                                    ).to(self.device)
        # Define our lstm layer
        self.lstm = nn.LSTM(input_size=self.encoder_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=self.batch_first,
                            dropout=self.dropout,
                            bidirectional=self.bidirectional,
                           ).to(self.device)
        # Define our decoder layer
        self.decoder = nn.Sequential(
                                    nn.Linear(self.hidden_dim * self.lstm_output_mult, self.decoder_dim),
                                    nn.Sigmoid()
                                    ).to(self.device)
        
    def init_hidden(self):
        return tuple([torch.zeros((self.num_layers, self.batch_size, self.hidden_dim), 
                                  dtype=torch.float32).to(self.device) 
                      for _ in range(self.num_layers)])
    
    def forward(self, _input, _hidden=None):
        """
            if _category == None then _input already combined
        """
        if _hidden is None:
            _hidden = self.init_hidden()
            
#         if _category:
#             input_combined = torch.cat((_category, _input), 2)
#         else:
#             input_combined = _input

        encoder_input, lengths = pad_packed_sequence(_input, batch_first=self.batch_first)

        encoded_output = self.encoder(encoder_input)

        packed_encoded_output = pack_padded_sequence(encoded_output, lengths, batch_first=self.batch_first)

        rnn_output, _hidden = self.lstm(packed_encoded_output, _hidden)
        if isinstance(rnn_output, PackedSequence):
            rnn_output, lengths = pad_packed_sequence(rnn_output, batch_first=self.batch_first)

        decoded_output = self.decoder(rnn_output)
        return decoded_output, lengths


# Training process

In [10]:
def save_ecg_example(gen_data: np.array, epoch_number, _title='12-lead ECG'):
    fig = plt.figure(figsize=(12,14),)
    for lead_n in range(gen_data.shape[1]):
    #             key concept - sygnal normalization
        curr_lead_data = gen_data[:, lead_n]
        ax = plt.subplot(4, 3, lead_n+1)
        plt.plot(curr_lead_data, label=f'lead_{lead_n+1}')
        plt.title(f'lead_{lead_n+1}')
    fig.suptitle(_title)
    plt.savefig(f'out/{epoch_number}.png', bbox_inches='tight')
    plt.close(fig)

In [11]:
def pad_batch_sequence(batch):
    sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
    sequences = [x[0] for x in sorted_batch]
    sequences_padded = pad_sequence(sequences, batch_first=True)
    lengths = [len(x) for x in sequences]
    # Don't forget to grab the labels of the *sorted* batch
    labels = [x[1].unsqueeze(0).expand(len(x[0]), -1) for x in sorted_batch]
    labels_padded = pad_sequence(labels, batch_first=True)
    return sequences_padded, lengths, labels_padded

In [12]:
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
# device = 'cpu'

In [13]:
# parameters for training
## general params
n_epochs = 10
batch_size = 24
lr = 1e-3
labels_dim = 7
lead_n = 12
## generator params
gen_h_dim = 128
gen_l_dim = 100
## discriminator params
dis_h_dim = 128
dis_encoder_h_dim = 128
# loss and dataloader setup
criterion = nn.BCELoss()
real_data_loader = DataLoader(icbeb_dataset, batch_size=batch_size, 
                              collate_fn=pad_batch_sequence, shuffle=True,)


In [14]:
G = Generator(labels_dim=labels_dim, 
              hidden_dim=gen_h_dim, 
              latent_dim=gen_l_dim, 
              batch_size=batch_size, 
              device=device, 
              decoder_output_dim=lead_n)
D = Discriminator(input_dim=lead_n, 
                  labels_dim=labels_dim, 
                  hidden_dim=dis_h_dim, 
                  encoder_dim=dis_encoder_h_dim, 
                  decoder_dim=1, 
                  batch_size=batch_size, 
                  device=device)

In [15]:
# if not os.path.exists('out/'):
os.makedirs('out/pictures', exist_ok=True)
os.makedirs('out/models', exist_ok=True)

In [22]:
! pip install torch==1.1.0

Collecting torch==1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/69/60/f685fb2cfb3088736bafbc9bdbb455327bdc8906b606da9c9a81bae1c81e/torch-1.1.0-cp36-cp36m-manylinux1_x86_64.whl (676.9MB)
[K    100% |████████████████████████████████| 676.9MB 84kB/s eta 0:00:011   11% |███▊                            | 79.3MB 2.1MB/s eta 0:04:40    45% |██████████████▋                 | 308.2MB 3.1MB/s eta 0:01:59    48% |███████████████▋                | 330.4MB 1.8MB/s eta 0:03:18    62% |████████████████████            | 423.9MB 5.1MB/s eta 0:00:50    65% |█████████████████████           | 446.1MB 1.2MB/s eta 0:03:13    69% |██████████████████████▎         | 471.0MB 2.5MB/s eta 0:01:22    70% |██████████████████████▌         | 475.6MB 1.4MB/s eta 0:02:26    70% |██████████████████████▊         | 480.2MB 1.5MB/s eta 0:02:09    71% |██████████████████████▉         | 482.4MB 1.4MB/s eta 0:02:21    75% |████████████████████████▏       | 512.2MB 984kB/s eta 0:02:48    77% |█████████████

In [16]:
""" ===================== TRAINING ======================== """


G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)

# ones_label = torch.ones(mb_size, 1)
# zeros_label = torch.zeros(mb_size, 1)

writer = SummaryWriter()
for epoch in tqdm.tqdm_notebook(range(n_epochs),position=0):
    # sequences in true_data_loader already padded thanks to pad_batch_sequence function
    for real_seqs, real_lenghts, real_labels in tqdm.tqdm_notebook([it for it in real_data_loader][:5], position=1):
        
# ------------------------------ Discriminator step --------------------------------------
        # Generate fake sample,
        fake_seq_padded_batch, fake_lengths, fake_labels_padded_batch = G.generate_seq_batch()
        # Let's prepare our fake and real samples 
        # concat labels + sequence alongside
        fake_label_seq_padded_batch = torch.cat([fake_labels_padded_batch, fake_seq_padded_batch], 2)
        fake_packed_batch = pack_padded_sequence(fake_label_seq_padded_batch, fake_lengths, batch_first=True)
        # and the same for real ones
        real_label_seq_padded_batch = torch.cat([real_labels, real_seqs], 2)
        real_packed_batch = pack_padded_sequence(real_label_seq_padded_batch, real_lenghts, batch_first=True)
        # After that we can make predictions for our fake examples
        d_fake_predictions, fake_pred_lengths = D(fake_packed_batch)
        d_fake_target = torch.zeros_like(d_fake_predictions)
        # ... and real ones
        d_real_predictions, real_pred_lengths = D(real_packed_batch.to(device))
        d_real_target = torch.ones_like(d_real_predictions)
        # Now we can calculate loss for discriminator
        # TODO Before calc loss we need to make sure that sequence with vary length is ok for that 
        d_fake_loss = criterion(d_fake_predictions, d_fake_target)
        d_real_loss = criterion(d_real_predictions, d_real_target)
        d_loss = d_real_loss + d_fake_loss
        # And make backpropagation according to calculated loss
        d_loss.backward()
        D_optimizer.step()
        # Housekeeping - reset gradient
        D_optimizer.zero_grad()
        
# ---------------------------- Generator step ----------------------------------------------
        # Generate fake sample
        fake_seq_padded_batch, fake_lengths, fake_labels_padded_batch = G.generate_seq_batch()
        # concat labels + sequence alongside
        fake_label_seq_padded_batch = torch.cat([fake_labels_padded_batch, fake_seq_padded_batch], 2)
        fake_packed_batch = pack_padded_sequence(fake_label_seq_padded_batch, fake_lengths, batch_first=True)
        # After that we can make predictions for our fake examples
        d_fake_predictions, fake_pred_lengths = D(fake_packed_batch)
        g_target = torch.ones_like(d_fake_predictions)
        # Now we can calculate loss for generator
        g_loss = criterion(d_fake_predictions, g_target)
        # And make backpropagation according to calculated loss
        g_loss.backward()
        G_optimizer.step()
        # Housekeeping - reset gradient
        G_optimizer.zero_grad()
    # plot example each 100 epochs
    if epoch % 2 == 0:
        print(f'Epoch-{epoch}; D_loss: {d_loss.data.cpu().numpy()}; G_loss: {g_loss.data.cpu().numpy()}')
        torch.save({
            'epoch': epoch,
            "d_model": D,
            "d_loss": d_loss,
            "d_optimizer": D_optimizer,
            "g_model": G,
            "g_loss": g_loss,
            "g_optimizer": G_optimizer,
            }, f"./out/models/epoch_{epoch}_checkpoint.pkl")        
        with torch.no_grad():
            pass
            
#         feasible_label = np.zeros(shape=[mb_size, y_dim], dtype='float32')
#         feasible_label[:, np.random.randint(0, 10)] = 1.
#         feasible_label = Variable(torch.from_numpy(feasible_label))
#         samples = G(z, feasible_label).data.numpy()[:16]
        
        # TODO take visualize func here



HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

Epoch-0; D_loss: 1.4149549007415771; G_loss: 0.6764982342720032


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

Epoch-2; D_loss: 1.418543815612793; G_loss: 0.5454565286636353


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

KeyboardInterrupt: 

# Classifier
- normalise unn signal and overwrite related pickle 