In [None]:
import haiku as hk
import matplotlib.pyplot as plt
import numpy as np
from importlib import reload

In [None]:
import torch 
import torchvision
from torchvision import transforms
import torchaudio
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

In [None]:
class PhaseMNIST(Dataset):
    """Map-style dataset backed by two NumPy files, ``data.npy`` and ``labels.npy``.

    Args:
        path: Location of the two files. May be a string prefix that is
            concatenated with the file names (e.g. ``"some/dir/"``), or a
            directory given as a ``str`` without trailing separator or as a
            ``pathlib.Path``.
        transform: Optional callable applied to each ``(data, label)`` tuple
            before it is returned by ``__getitem__``.

    Raises:
        ValueError: If the data and label arrays have different lengths.
    """

    def __init__(self, path, transform=None):

        with open(self._locate(path, 'data.npy'), 'rb') as f:
            self.data = np.load(f)

        with open(self._locate(path, 'labels.npy'), 'rb') as f:
            self.labels = np.load(f)

        # The dataset is indexed by label position, so both arrays must agree,
        # otherwise samples are silently ignored or indexing fails later on.
        if len(self.data) != len(self.labels):
            raise ValueError(
                "data.npy and labels.npy have different lengths: "
                f"{len(self.data)} != {len(self.labels)}"
            )

        self.length = len(self.labels)
        self.transform = transform

    @staticmethod
    def _locate(path, filename):
        """Return the file to open for ``filename`` given the user supplied ``path``."""
        # Plain concatenation is tried first so that every call that worked
        # before (trailing separator, or a file-name prefix) resolves to the
        # very same file.
        prefixed = Path(str(path) + filename)
        if prefixed.exists():
            return prefixed
        # Otherwise treat ``path`` as a directory (str without trailing
        # separator, or a pathlib.Path).
        return Path(path) / filename

    def __len__(self):
        """Number of samples, i.e. the number of labels."""
        return self.length

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair at ``index``, transformed if requested."""

        sample = (self.data[index], self.labels[index])

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

In [None]:
class ToTensor_onehot(object):
    """Convert a ``(x, y)`` sample to tensors, one-hot encoding the label ``y``.

    Args:
        n_classes: Length of the one-hot vector.

    Calling the transform returns ``(x_tensor, y_tensor)`` where ``x_tensor``
    is ``torch.from_numpy(x)`` (shares memory with ``x``) and ``y_tensor`` is
    the one-hot row(s) selected by ``y`` with an extra leading axis of size 1.
    """
    def __init__(self, n_classes=2):
        self.n_classes = n_classes

    def __call__(self, sample):
        x, y = sample
        # One-hot encoding: pick the row(s) of the identity matrix indexed by y
        onehot = np.eye(self.n_classes)[y]
        # Add the leading axis in NumPy and convert a single ndarray.
        # torch.tensor([ndarray]) yields the same tensor but goes through
        # PyTorch's slow "list of ndarrays" path and emits a UserWarning.
        # torch.tensor still copies, so the result does not alias the
        # temporary identity matrix.
        y = torch.tensor(np.asarray(onehot)[np.newaxis])
        return ( torch.from_numpy(x), y )
    
class Squeeze(object):
    """Strip size-1 dimensions from the tensors of an ``(x, y)`` sample.

    Args:
        squeeze_x: Whether ``torch.squeeze`` is applied to the first element.
        squeeze_y: Whether ``torch.squeeze`` is applied to the second element.
    """
    def __init__(self, squeeze_x=True, squeeze_y=True):
        self.squeeze_x, self.squeeze_y = squeeze_x, squeeze_y

    def __call__(self, sample):
        inputs, target = sample
        # Each element is squeezed independently, only when its flag is set.
        inputs = torch.squeeze(inputs) if self.squeeze_x else inputs
        target = torch.squeeze(target) if self.squeeze_y else target
        return ( inputs, target )

In [None]:
def build_dataloaders(data_dir, composed_transform, drop_length=0, batch_size=32, seed=None):
    """Build shuffled train and test ``DataLoader`` objects from a ``PhaseMNIST`` dataset.

    ``drop_length`` samples are discarded; the remaining ones are split
    randomly in two halves (the test split receives the extra sample when the
    remainder is odd).

    Args:
        data_dir: Forwarded to ``PhaseMNIST`` as ``path``.
        composed_transform: Forwarded to ``PhaseMNIST`` as ``transform``.
        drop_length: Number of samples left out of both splits.
        batch_size: Batch size used by both loaders.
        seed: Optional integer seeding the random split so that the
            train/test partition is reproducible. ``None`` (default) uses
            PyTorch's global random state, as before. The seed only affects
            the split, not the per-epoch shuffling of the loaders.

    Returns:
        Tuple ``(train_dataloader, test_dataloader)``.

    Raises:
        ValueError: If ``drop_length`` is negative or larger than the dataset.
    """

    full_dataset = PhaseMNIST(path=data_dir, transform=composed_transform)
    n_total = len(full_dataset)

    # Out-of-range values would produce negative split sizes and hence
    # meaningless subsets instead of a clear error.
    if not 0 <= drop_length <= n_total:
        raise ValueError(
            f"drop_length must be between 0 and {n_total}, got {drop_length}"
        )

    train_len = int(0.5 * (n_total - drop_length))
    test_len = n_total - train_len - drop_length
    lengths = [train_len, test_len, drop_length]

    if seed is None:
        train_ds, test_ds, _ = torch.utils.data.random_split(full_dataset, lengths)
    else:
        # A dedicated generator keeps the split independent of the global RNG state.
        generator = torch.Generator().manual_seed(seed)
        train_ds, test_ds, _ = torch.utils.data.random_split(
            full_dataset, lengths, generator=generator
        )

    train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    # The test loader is shuffled as well, matching the previous behaviour.
    test_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=True)

    return train_dataloader, test_dataloader