In [1]:
import numpy as np
import torch
import os
from torch.utils.data import Dataset
from pathlib import Path
from matplotlib import pyplot as plt
from torchvision import transforms
import tqdm

In [2]:
class CANDataset(Dataset):
    """Map-style dataset backed by one ``.npz`` archive per sample.

    Samples are read from ``<root_dir>/train/<idx>.npz`` (or
    ``<root_dir>/val/<idx>.npz``). Each archive must provide a feature array
    under the key ``'X'`` and a label under the key ``'y'``.

    Args:
        root_dir: Directory that contains the ``train`` and ``val`` folders.
        is_train: Read from ``train`` when true, otherwise from ``val``.
        transform: Optional callable applied to the feature tensor after the
            leading size-1 dimension has been added.
    """

    def __init__(self, root_dir, is_train=True, transform=None):
        self.root_dir = Path(root_dir) / ('train' if is_train else 'val')
        self.is_train = is_train
        self.transform = transform
        # Count only sample archives: stray entries in the folder (for example
        # a ``.ipynb_checkpoints`` directory) must not inflate the length,
        # otherwise a sampler would ask for an index with no matching file.
        self.total_size = sum(
            1 for name in os.listdir(self.root_dir) if name.endswith('.npz')
        )

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(X_tensor, y_tensor)``.

        ``X`` is converted to ``float32`` and gets a new leading dimension of
        size 1; ``y`` is converted to ``long``.
        """
        filename = self.root_dir / f'{idx}.npz'
        # ``np.load`` on an .npz returns a lazy archive holding an open file
        # handle; the context manager closes it deterministically. The arrays
        # are materialised inside the block, before the archive is closed.
        with np.load(filename) as data:
            X, y = data['X'], data['y']
        X_tensor = torch.tensor(X, dtype=torch.float32)
        X_tensor = torch.unsqueeze(X_tensor, dim=0)
        y_tensor = torch.tensor(y, dtype=torch.long)
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. an empty container module) is still applied.
        if self.transform is not None:
            X_tensor = self.transform(X_tensor)
        return X_tensor, y_tensor

    def __len__(self):
        """Number of ``.npz`` sample files found in the split directory."""
        return self.total_size

In [3]:
def get_mean_and_std(dataloader):
    """Compute per-channel mean and (population) standard deviation.

    Statistics are taken over every dimension except dimension 1, which is
    treated as the channel dimension (e.g. batch, height and width for a
    4-D ``(N, C, H, W)`` batch).

    Args:
        dataloader: Iterable yielding ``(data, target)`` pairs where ``data``
            is a tensor with at least two dimensions. The targets are ignored.
            A ``torch.utils.data.DataLoader`` works, as does any plain
            iterable of such pairs.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel, in
        the dtype of the input data (``float32`` for non-floating input).

    Raises:
        ValueError: If the iterable yields no elements at all.
    """
    channels_sum = None
    channels_squared_sum = None
    elements_per_channel = 0
    out_dtype = None

    for data, _ in tqdm.tqdm(dataloader):
        if data.numel() == 0:
            continue
        if out_dtype is None:
            out_dtype = data.dtype if data.is_floating_point() else torch.float32

        # Reduce over everything but the channel dimension.
        reduce_dims = [d for d in range(data.dim()) if d != 1]
        # Accumulate raw sums in float64: E[X^2] - E[X]^2 suffers from
        # catastrophic cancellation in float32 when |mean| >> std.
        data64 = data.to(torch.float64)
        batch_sum = data64.sum(dim=reduce_dims)
        batch_squared_sum = (data64 ** 2).sum(dim=reduce_dims)

        if channels_sum is None:
            channels_sum = batch_sum
            channels_squared_sum = batch_squared_sum
        else:
            channels_sum = channels_sum + batch_sum
            channels_squared_sum = channels_squared_sum + batch_squared_sum
        # Counting real elements weights a short final batch correctly and
        # does not depend on ``dataloader.batch_size`` (which may be None).
        elements_per_channel += data.numel() // data.size(1)

    if elements_per_channel == 0:
        raise ValueError('get_mean_and_std: dataloader yielded no data')

    mean = channels_sum / elements_per_channel

    # std = sqrt(E[X^2] - (E[X])^2); clamp guards against a tiny negative
    # variance from rounding, which would otherwise produce NaN.
    variance = channels_squared_sum / elements_per_channel - mean ** 2
    std = torch.clamp(variance, min=0).sqrt()

    return mean.to(out_dtype), std.to(out_dtype)

In [4]:
# Wavelet family name; only referenced by the commented-out path below.
wavelet_name = 'mexh'
# data_dir = f'../Data/CHD_w29_s14_ID_Data/wavelet/{wavelet_name}/1/'
data_dir = '../../../Data/LISA/Federated_Data/Preprocessed_Data/Kia/1/'

# Training split without any transform, so statistics computed from this
# loader describe the raw stored values.
train_dataset = CANDataset(data_dir, is_train=True)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    sampler=None,
    pin_memory=True,
)

In [5]:
# Draw a single mini-batch and display the feature / label tensor sizes.
X, y = next(iter(train_dataloader))
(X.size(), y.size())

(torch.Size([128, 1, 29, 29]), torch.Size([128]))

In [26]:
# Show the distinct label values in whatever `y` currently holds in the kernel.
# NOTE(review): `y` is rebound by more than one cell (a DataLoader batch and a
# raw npz load), so the result depends on which of them ran last.
np.unique(y)

array([0, 1, 2])

In [5]:
# Compute per-channel mean/std over the whole training loader and print them.
# This iterates over every batch once (the recorded run took about 2.5 minutes).
# NOTE(review): the loader must be the un-transformed one; if the normalised
# loader has been created since, these numbers describe normalised data.
means, stds = get_mean_and_std(train_dataloader)
print('Mean: ', means)
print('Std: ', stds)

100%|██████████| 1619/1619 [02:34<00:00, 10.49it/s]

Mean:  tensor([1712.4263,  141.0008,  110.5848,   97.2896,  188.9378,  124.5931,
         148.5182,   63.2015,  130.0784])
Std:  tensor([1107.1625,  105.1148,   78.9460,   79.3443,  147.8017,   93.7031,
         105.8413,   63.6263,  118.1454])





In [16]:
# Rebuild the training dataset / loader with per-channel normalisation.
# NOTE(review): the constants are hard-coded and differ from the statistics
# printed earlier in this notebook; presumably they come from a run on the
# CHD data — confirm before reuse.
data_dir = '../Data/CHD_w29_s14_ID_Data/1/'
transform = transforms.Normalize(
    mean=(
        126.8058, 10.4403, 8.1874, 7.2068, 13.9896,
        9.2265, 10.9938, 4.6789, 9.6320,
    ),
    std=(
        510.3837, 67.7702, 43.0419, 53.2845, 79.1804,
        60.3768, 60.1881, 48.7489, 70.4148,
    ),
)
train_dataset = CANDataset(data_dir, is_train=True, transform=transform)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    sampler=None,
    num_workers=8,
    pin_memory=True,
)

In [6]:
# Read one stored training sample straight from disk for manual inspection.
filename = '../Data/CHD_w29_s14_ID_Data/1/train/1.npz'
data = np.load(filename)
X = data['X']
y = data['y']