In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| default_exp datasets/mics_datasets

# Implementation of PyTorch Datasets based on the RIR databases
> PyTorch Datasets that provide signals from `DB_microphones` classes containing room impulse response (RIR) signals.

In [None]:
#| export
from typing import List
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset

from DataScience_exploration.datasets.mics_databases import ZeaRIR, MeshRIR
from DataScience_exploration.datasets.mics_databases import DB_microphones

## PyTorch Datasets
> Define which information from the databases is provided to the model

### Dataset with a fixed environment

+ The Dataset fixes the set of known (environment) microphones. For example, 4 microphones are placed at specific locations to measure the acoustics of an environment.
+ Each item is a dictionary with the information of 1 target microphone (position, signal and time samples); the environment is returned separately by `get_env()`.

In [None]:
class DSRirFixedEnv(Dataset):
    """
    Dataset with a fixed environment.

    A fixed set of microphones (``ids_env``) forms the environment. It is read
    once at construction, converted to float32 tensors and returned by
    ``get_env``. Each item of the dataset is a single target microphone. In this
    version the target microphone (the one to be predicted) can be any
    microphone of the database, including those that belong to the environment.

    Parameters
    ----------
    mic_database : DB_microphones
        Database exposing ``get_mic``, ``get_time``, ``get_pos`` and ``n_mics``.
    ids_env : List[int]
        Indices of the microphones that form the environment. Must be non-empty.

    Raises
    ------
    ValueError
        If ``ids_env`` is empty.
    """
    def __init__(self,
                 mic_database: DB_microphones,
                 ids_env: List[int],
                 ):
        super().__init__()
        if len(ids_env) == 0:
            # np.stack would otherwise fail with "need at least one array to stack".
            raise ValueError("ids_env must contain at least one microphone id")
        self.db = mic_database
        self.ids_env = ids_env

        # Environment microphones: one entry per id in ids_env, stacked along a
        # new leading axis and converted to float32 tensors.
        self.env = {
            'signal': self._stack_as_tensor([self.db.get_mic(i) for i in ids_env]),
            'time': self._stack_as_tensor([self.db.get_time(i) for i in ids_env]),
            'position': self._stack_as_tensor([self.db.get_pos(i) for i in ids_env]),
        }

    @staticmethod
    def _stack_as_tensor(arrays):
        """Stack array-likes along a new first axis into a new float32 tensor."""
        return torch.from_numpy(np.stack(arrays).astype(np.float32))

    def __len__(self):
        """Number of microphones in the database (every one can be a target)."""
        return self.db.n_mics

    def __getitem__(self, idx):
        """
        Return the target microphone ``idx``.

        The environment is fixed (see ``get_env``), so only the target is
        returned: a dict with keys ``signal``, ``time`` and ``position`` holding
        the values returned by the database.

        Raises
        ------
        IndexError
            If ``idx >= len(self)``. This makes a plain ``for`` loop over the
            dataset terminate (Python's sequence-iteration protocol stops on
            IndexError). Negative indices are passed to the database unchanged.
        """
        if idx >= len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")
        # NOTE(review): unlike the environment, these values are not converted to
        # float32 tensors; their type/dtype is whatever the database returns.
        return dict(signal=self.db.get_mic(idx),
                    time=self.db.get_time(idx),
                    position=self.db.get_pos(idx))

    def get_env(self):
        """
        Return the environment dict (keys ``signal``, ``time``, ``position``),
        whose values are float32 tensors stacked over ``ids_env``.
        """
        return self.env

    def __str__(self):
        return (
            f"Pytorch Dataset: {self.__class__.__name__}\n"
            f"With length: {len(self)} \n"
            f"Environment (mics ids): {self.ids_env}\n"
            + str(self.db)
        )
    

In [None]:
# Fixed-environment dataset over the ZeaRIR "Balder" room (signal_start=0,
# signal_size=128); microphones 10, 30 and 50 act as the environment.
ds_Zea = DSRirFixedEnv(
    mic_database=ZeaRIR("./data", dataname="Balder", signal_start=0, signal_size=128),
    ids_env=[10, 30, 50],
)
# Empty line separates the loader's download/loading log from the summary.
print()
print(ds_Zea)


Matched resources to download:
- BalderRIR.mat
Loading the resource ./data/ZeaRIR/raw/BalderRIR.mat ...

Pytorch Dataset: DSRirFixedEnv
With length: 100 
Environment (mics ids): [10, 30, 50]
Database: ZeaRIR
Download: ['BalderRIR.mat']
Load room: BalderRIR.mat
Path to raw resource: ./data/ZeaRIR/raw/BalderRIR.mat
Path to unpacked data folder: ./data/ZeaRIR/raw
Sampling frequency: 11250 Hz
Number of microphones: 100
Number of total time samples: 3623
Number of time samples selected: 128
Number of sources: 1
Signal start: 0
Signal size: 128
Source ID: 0


In [None]:

# Accesing an element
print(f"Length of dataset: {len(ds_Zea)}")
print("Position of Target (index 1). ")
print(f"using list indexing:   {ds_Zea[1]['position']} ") 
print(f"and using __getitem__: {ds_Zea.__getitem__(1)['position']} ")

# Print the environment
print()
print("Environment \nPositions:")
print(ds_Zea.get_env()['position'])

Length of dataset: 100
Position of Target (index 1). 
using list indexing:   [0.03 0.   0.  ] 
and using __getitem__: [0.03 0.   0.  ] 

Environment 
Positions:
tensor([[0.3000, 0.0000, 0.0000],
        [0.9000, 0.0000, 0.0000],
        [1.5000, 0.0000, 0.0000]])


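### Dataset with random environment

Each item draws `n_ref_mics + 1` distinct microphones at random: the first one is the target and the remaining `n_ref_mics` form the environment. Repeated access to the same index therefore generally returns different samples (see the example below).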

In [None]:
#| export
import math 
import random

In [None]:
#| export
class DS_random_pick(torch.utils.data.Dataset):
    """
    Dataset whose environment is drawn at random for every item.

    Each item picks ``n_ref_mics + 1`` distinct microphones from the database:
    the first one is the target and the remaining ``n_ref_mics`` form the
    environment used to interpolate it.

    Parameters
    ----------
    mic_database : DB_microphones
        Database exposing ``get_mic``, ``get_time``, ``get_pos`` and ``n_mics``.
    n_ref_mics : int
        Number of microphones picked as the environment (must be >= 1).
    max_combinations : int
        Upper bound on the dataset length.
    seed : Optional[int]
        ``None`` (default): every access draws fresh microphones from the
        global ``random`` state, so ``ds[i]`` differs between calls.
        An int: the draw for index ``i`` depends only on ``(seed, i)``, so it
        is reproducible across calls, epochs and processes.

    Raises
    ------
    ValueError
        If ``n_ref_mics < 1`` or the database has fewer than ``n_ref_mics + 1``
        microphones.
    """
    def __init__(
        self,
        mic_database: DB_microphones,
        n_ref_mics: int = 4,  # number of mics I will pick as my environment to interpolate
        max_combinations: int = 1000,  # number of maximum combinations
        seed: Optional[int] = None,  # None -> non-reproducible draws (original behavior)
    ):
        super().__init__()
        self.db = mic_database
        self.n_ref_mics = n_ref_mics
        self.max_combinations = max_combinations
        self.seed = seed

        n = self.db.n_mics
        r = self.n_ref_mics
        if r < 1:
            raise ValueError(f"n_ref_mics must be >= 1, got {r}")
        if r + 1 > n:
            # __getitem__ samples r + 1 distinct microphones (r environment + 1 target).
            raise ValueError(
                f"n_ref_mics + 1 = {r + 1} exceeds the number of microphones ({n})"
            )

        # Number of combinations without replacement of n elements in groups of r:
        # n! / (r! * (n - r)!). math.comb is exact and avoids building huge factorials.
        n_comb = math.comb(n, r)
        self.len_comb_dataset = min(n_comb, self.max_combinations)

    def __len__(self):
        return self.len_comb_dataset

    @staticmethod
    def _float32_tensor(array):
        """Copy ``array`` into a new float32 tensor (no memory shared with the source)."""
        return torch.from_numpy(np.asarray(array).astype(np.float32))

    def __getitem__(self, idx):
        """
        Draw an (environment, target) pair.

        Returns
        -------
        tuple(dict, dict)
            ``env``: keys ``signal``, ``time``, ``position``, each stacked over
            the ``n_ref_mics`` environment microphones.
            ``target``: same keys for the single target microphone.
            All values are float32 tensors.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``, so that plain
            iteration over the dataset terminates.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"index out of range for dataset of length {length}")

        # Without a seed use the global `random` state; with a seed derive a
        # private generator from (seed, idx) so the draw is reproducible.
        rng = random if self.seed is None else random.Random(f"{self.seed}:{idx}")
        ids = rng.sample(range(self.db.n_mics), self.n_ref_mics + 1)

        signals = [self.db.get_mic(i) for i in ids]
        positions = [self.db.get_pos(i) for i in ids]
        times = [self.db.get_time(i) for i in ids]

        # ids[0] is the target; ids[1:] are the environment microphones.
        fields = dict(signal=signals, time=times, position=positions)
        env = {key: self._float32_tensor(np.stack(values[1:])) for key, values in fields.items()}
        target = {key: self._float32_tensor(values[0]) for key, values in fields.items()}

        return env, target

In [None]:
# Random-environment dataset over MeshRIR "S1" (source 0, signal_start=0,
# signal_size=128): each item holds 4 environment mics + 1 target, 20 items total.
ds_Mesh = DS_random_pick(
    mic_database=MeshRIR(root="./data", dataname="S1", signal_start=0,
                         signal_size=128, source_id=0),
    n_ref_mics=4,
    max_combinations=20,
)

# Each access to DS_random_pick draws new random microphones, so these two
# accesses to index 1 generally return different (env, target) pairs.
env, target = ds_Mesh[1]
env_p, target_p = ds_Mesh.__getitem__(1)

print()
# Accessing an element
print(f"Length of dataset: {len(ds_Mesh)}")
print("Position of Target (index 1). ")
print(f"using list indexing:   {target['position']} ")
print(f"and using __getitem__: {target_p['position']} ")

# Print the environment of the first draw
print()
print("Environment \nPositions:")
print(env['position'])



Matched resources to download:
- S1-M3969_npy.zip
Unpacked folder ./data/MeshRIR/raw/S1-M3969_npy already exists. Skipping unpacking.

Length of dataset: 20
Position of Target (index 1). 
using list indexing:   tensor([0.1500, 0.3000, 0.2000]) 
and using __getitem__: tensor([0.5000, 0.1500, 0.0500]) 

Environment 
Positions:
tensor([[-0.2000,  0.2500,  0.1500],
        [ 0.4500,  0.1500,  0.0000],
        [-0.0500, -0.4500,  0.2000],
        [-0.1000,  0.4000, -0.1500]])


## Pytorch lightning Datamodules
> The PyTorch Lightning `LightningDataModule` organizes the torch `Dataset`s together with the operations that have to be performed during the "fit" and "test" stages. It also defines the `DataLoader`s used for training, validation and testing.

In [None]:
#| export

import torch
import lightning.pytorch as L
from torch.utils.data import random_split, ConcatDataset, DataLoader
from typing import List


In [None]:

def ensure_list(x):
    """
    Normalize a Dataset, a list/tuple of Datasets, or None into a new list.

    Parameters
    ----------
    x : Dataset | list | tuple | None
        A single Dataset is wrapped into a one-element list. A list or tuple is
        shallow-copied into a new list. ``None`` gives an empty list.

    Returns
    -------
    list
        Always a new list object, so later mutation of the caller's list (or of
        a shared default argument) does not leak into the result.

    Raises
    ------
    TypeError
        For any other input type.
    """
    if isinstance(x, Dataset):
        return [x]
    if isinstance(x, (list, tuple)):
        return list(x)
    if x is None:
        return []
    raise TypeError(f"Expected Dataset or list of Datasets, got {type(x)}")
    
class DM_PL_DataModule(L.LightningDataModule):
    """
    Lightning DataModule built from lists of map-style torch Datasets.

    For the "fit" and "validate" stages the training datasets are concatenated
    and randomly split into train/validation subsets. For "test" the test
    datasets are concatenated.

    Parameters
    ----------
    ls_datasets_train, ls_datasets_test : Dataset | list of Dataset | None
        Datasets used for fit/validate and for test. A single Dataset is wrapped
        into a list; ``None`` (default) means no datasets.
    batch_size : int
        Batch size used by every DataLoader.
    num_workers : int
        Number of worker processes used by every DataLoader.
    val_fraction : float
        Fraction of the concatenated training data held out for validation.
    split_seed : Optional[int]
        Seed for the train/validation split. ``None`` (default) uses torch's
        global RNG, so the split changes between runs.
    """
    def __init__(self,
                 ls_datasets_train: Optional[List[torch.utils.data.Dataset]] = None,
                 ls_datasets_test: Optional[List[torch.utils.data.Dataset]] = None,
                 batch_size: int = 64, num_workers: int = 0,
                 val_fraction: float = 0.2,
                 split_seed: Optional[int] = None,
                 ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_fraction = val_fraction
        self.split_seed = split_seed
        # None defaults avoid sharing one mutable list between instances.
        self.ls_datasets_train = ensure_list(ls_datasets_train)
        self.ls_datasets_test = ensure_list(ls_datasets_test)

    def setup(self, stage):
        """Build the datasets needed for ``stage`` ("fit", "validate", "test", ...)."""
        # "validate" also needs ds_val (used by trainer.validate()).
        if stage in ("fit", "validate"):
            if not self.ls_datasets_train:
                raise ValueError("No training datasets were provided (ls_datasets_train is empty)")
            split_kwargs = {}
            if self.split_seed is not None:
                split_kwargs["generator"] = torch.Generator().manual_seed(self.split_seed)
            self.ds_train, self.ds_val = random_split(
                ConcatDataset(self.ls_datasets_train),
                [1.0 - self.val_fraction, self.val_fraction],
                **split_kwargs,
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            if not self.ls_datasets_test:
                raise ValueError("No test datasets were provided (ls_datasets_test is empty)")
            self.ds_test = ConcatDataset(self.ls_datasets_test)

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=False, collate_fn=None)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers)