In [None]:
import torch 
from torch.utils.data import Dataset 
import os 
import dill as pickle


In [None]:
# Raw input file; the dataset classes below read it with torch.load.
# NOTE(review): the original note says this must be an absolute path, yet the
# value here is relative — confirm which is intended.
raw_data_path = 'Data/Raw/heart.pt'    # THIS NEEDS TO BE AN ABSOLUTE PATH (original author's note)
train_percent = 0.7  # fraction of rows that goes into the training split
test_percent = 1 - train_percent  # complementary fraction; not referenced in the visible cells
path_to_save_dir = 'Data/Datasets/Heart'  # directory the pickled datasets are saved to


In [None]:
class TrainDataSet(Dataset):
    """Random training subset of the (X, y) pair stored at ``raw_data_path``.

    A random permutation of the row indices is drawn and its first
    ``round(train_percent * len(X))`` entries are used as the training rows.
    The chosen indices are kept in ``self.rows`` so that a complementary test
    split can be built from the rows that were not selected.

    Args:
        raw_data_path: Path to a file that ``torch.load`` unpacks into ``(X, y)``.
            Assumes X and y support ``len`` and indexing by an index tensor
            along their first dimension (e.g. tensors) — TODO confirm.
        train_percent: Fraction of rows to use for training, in ``[0, 1]``.
        seed: Optional integer seed. When given, the permutation is drawn from
            a dedicated ``torch.Generator`` seeded with it, so the split is
            reproducible. When ``None`` (default) the global torch RNG is used,
            exactly as before.

    Raises:
        ValueError: If ``train_percent`` is outside ``[0, 1]``.
    """

    def __init__(self, raw_data_path, train_percent, seed=None):
        # A negative fraction would make the slice below count from the end of
        # the permutation and silently return the wrong number of rows.
        if not 0 <= train_percent <= 1:
            raise ValueError(
                f"train_percent must be between 0 and 1, got {train_percent!r}"
            )

        X, y = torch.load(raw_data_path)

        num_rows = round(train_percent * len(X))

        # A private generator keeps a seeded split from disturbing the global RNG.
        generator = None
        if seed is not None:
            generator = torch.Generator().manual_seed(seed)
        self.rows = torch.randperm(len(X), generator=generator)[:num_rows]

        self.x = X[self.rows]
        self.y = y[self.rows]
        self.n_samples = len(self.x)

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at position ``index`` of the split."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of rows in the training split."""
        return self.n_samples
    
class TestDataSet(Dataset):
    """Held-out split: every row of the raw data that is not in ``train_rows``.

    Args:
        raw_data_path: Path to a file that ``torch.load`` unpacks into ``(X, y)``.
        train_rows: Row indices used by the training split; these rows are
            excluded here.
    """

    def __init__(self, raw_data_path, train_rows):
        features, labels = torch.load(raw_data_path)

        # Start with every row selected, then deselect the training rows.
        held_out = torch.ones(len(features), dtype=torch.bool)
        held_out[train_rows] = False

        self.x, self.y = features[held_out], labels[held_out]
        self.n_samples = len(self.x)

    def __len__(self):
        """Return the number of held-out rows."""
        return self.n_samples

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at position ``index`` of the split."""
        return self.x[index], self.y[index]


In [None]:
def save_dataset(dataset, path_to_save_dir, file_name):
    """Pickle ``dataset`` to the file ``file_name`` inside ``path_to_save_dir``.

    The directory must already exist; the file is opened in binary write mode,
    so an existing file with the same name is replaced. Serialisation goes
    through the module-level ``pickle`` name (``dill`` in this notebook).
    """
    destination = os.path.join(path_to_save_dir, file_name)
    with open(destination, 'wb') as handle:
        pickle.dump(dataset, handle)


In [None]:
# Build the train/test splits from the raw data and save them as pickled datasets.
train_set = TrainDataSet(raw_data_path, train_percent)
# TestDataSet takes (raw_data_path, train_rows) only; passing train_percent as
# an extra positional argument raised a TypeError.
test_set = TestDataSet(raw_data_path, train_set.rows)

# exist_ok=False makes this raise FileExistsError when the directory already exists.
# NOTE(review): presumably a guard against overwriting a saved split with a new
# random one on re-run — confirm before relaxing it.
os.makedirs(path_to_save_dir, exist_ok=False)

save_dataset(train_set, path_to_save_dir, 'train.pkl')
save_dataset(test_set, path_to_save_dir, 'test.pkl')