In [1]:
from preprocess.preprocess import load_dataset, compute_label_agg, select_data, sample_class, undersample, no_to_augment
from preprocess.datasets import PTBDataset

from torch.utils.data import DataLoader, Dataset
import torch

from Augmentation.random_mask import generate_samples_rm
from Augmentation.random_noise import generate_samples_noising

In [2]:
# Select the torch device: CUDA only when it is both requested and available,
# otherwise fall back to the CPU.
use_cuda = True
print("CUDA is available:", torch.cuda.is_available())
if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

CUDA is available: True


# Set options

In [3]:
# When True, the cached loader at DATA_PATH is loaded instead of being rebuilt
LOAD_DATASET = False

# Stores path to save/load data to augment (only contains the class of interest)
DATA_PATH = './trainloader.pt'

# Class to augment
CLS = 'HYP'

# Batch size passed to the DataLoaders and to the noising generator
BATCH_SIZE = 64

# Load data to be augmented

In [4]:
# Build (or reload) the DataLoader that holds only the class to augment,
# together with the two values returned by no_to_augment, which the
# generation cells below need as max_samples / min_samples.
#
# Those two values are cached in a sidecar file next to DATA_PATH so that the
# LOAD_DATASET = True path defines them too (previously it left them
# undefined and the generation cells failed with a NameError).
COUNTS_PATH = DATA_PATH + '.counts'

if LOAD_DATASET:
    # NOTE(review): torch.load unpickles arbitrary Python objects -- only load
    # files you created yourself. weights_only=False is passed explicitly
    # because a pickled DataLoader cannot be read in weights-only mode, which
    # is the default from torch 2.6 onwards.
    train_loader = torch.load(DATA_PATH, weights_only=False)
    try:
        counts = torch.load(COUNTS_PATH, weights_only=False)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"{COUNTS_PATH} not found; run this cell once with "
            "LOAD_DATASET = False to regenerate the cached data."
        ) from err
    max_samples, min_samples = counts['max_samples'], counts['min_samples']
else:
    path = './data/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3/'

    data, raw_labels = load_dataset(path)

    labels = compute_label_agg(raw_labels, path)

    data, labels, Y = select_data(data, labels)

    data, labels, Y = undersample(data, labels, Y)

    # Must be computed before sample_class narrows the data down to CLS.
    max_samples, min_samples = no_to_augment(labels, CLS)

    data, labels, Y = sample_class(data, labels, Y, CLS)

    ds = PTBDataset(data, labels, Y)

    train_loader = DataLoader(dataset=ds, batch_size=BATCH_SIZE)

    # Persist both the loader and the sample bounds so a later run with
    # LOAD_DATASET = True can restore everything the next cells depend on.
    torch.save(train_loader, DATA_PATH)
    torch.save({'max_samples': max_samples, 'min_samples': min_samples}, COUNTS_PATH)

## Generate augmented samples by random masking and save data loaders

In [5]:
# Augment the class-of-interest loader by random masking, bounded by the
# two values obtained from no_to_augment.
ds_rm = generate_samples_rm(
    train_loader,
    max_samples=max_samples,
    min_samples=min_samples,
)

In [6]:
# Split the random-masking dataset 80/10/10 into train/valid/test subsets,
# wrap each subset in a DataLoader and save the loaders to disk.
# The split uses a seeded generator so the saved partitions are the same on
# every run (an unseeded random_split reshuffles on each execution).
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

train_ds, valid_ds, test_ds = torch.utils.data.random_split(
    ds_rm, [0.8, 0.1, 0.1], generator=split_generator
)

train_loader_rm = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
valid_loader_rm = torch.utils.data.DataLoader(valid_ds, batch_size=BATCH_SIZE)
test_loader_rm = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

torch.save(train_loader_rm, './trainloader_augmented_rm.pt')
torch.save(valid_loader_rm, './validloader_augmented_rm.pt')
torch.save(test_loader_rm, './testloader_augmented_rm.pt')

## Generating augmented samples by noising and saving data loaders

In [8]:
# Augment the class-of-interest loader by noising, bounded by the two values
# obtained from no_to_augment.
ds_rn = generate_samples_noising(
    train_loader,
    batch_size=BATCH_SIZE,
    max_samples=max_samples,
    min_samples=min_samples,
)

In [9]:

# Split the noising dataset 80/10/10 into train/valid/test subsets, wrap each
# subset in a DataLoader and save the loaders to disk.
# The split uses a seeded generator so the saved partitions are the same on
# every run (an unseeded random_split reshuffles on each execution).
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

train_ds, valid_ds, test_ds = torch.utils.data.random_split(
    ds_rn, [0.8, 0.1, 0.1], generator=split_generator
)

train_loader_rn = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
valid_loader_rn = torch.utils.data.DataLoader(valid_ds, batch_size=BATCH_SIZE)
test_loader_rn = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

torch.save(train_loader_rn, './trainloader_augmented_rn.pt')
torch.save(valid_loader_rn, './validloader_augmented_rn.pt')
torch.save(test_loader_rn, './testloader_augmented_rn.pt')