## DataLoader in SpeechBrain

1. The SpeechBrain data loading pipeline follows the PyTorch data loading pipeline.
2. The PyTorch data loading pipeline consists of the following components:
  * **Dataset**: loads one data point at a time.
  * **Collation function**: converts a list of data points into a batch of PyTorch tensors.
  * **Sampler**: decides the order in which the dataset is iterated.
  * **Data Loader**: takes the components above, plus other arguments such as `batch_size`, and produces the batches that are iterated over during training.
  This link contains

3. You can also **pass the data directly** to the `brain.fit()` function and specify the data loader options (such as batch size) in the `train_loader_kwargs` argument, taken from the YAML file.

> For example:
> `brain.fit(range(hparams["N_epochs"]), data, train_loader_kwargs=hparams["dataloader_options"])`
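
The four components listed above can be sketched in plain PyTorch as follows. This is a minimal illustration, not SpeechBrain's own implementation: `ToyDataset`, `collate`, and the `dataloader_options` dict are hypothetical names introduced here, with the dict standing in for the options that would normally come from the YAML file.

```python
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler


class ToyDataset(Dataset):
    """Hypothetical dataset: returns one (input, target) pair per index."""

    def __init__(self, n_items=10):
        self.items = [(torch.rand(10), torch.rand(10)) for _ in range(n_items)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


def collate(samples):
    """Stack a list of (input, target) pairs into two batch tensors."""
    inputs = torch.stack([s[0] for s in samples])
    targets = torch.stack([s[1] for s in samples])
    return inputs, targets


dataset = ToyDataset()
sampler = SequentialSampler(dataset)  # visit indices 0, 1, 2, ... in order

# Loader options kept in a dict (assumed layout), similar in spirit to
# the "dataloader_options" entry of the hyperparameter file.
dataloader_options = {"batch_size": 4}

loader = DataLoader(
    dataset, sampler=sampler, collate_fn=collate, **dataloader_options
)

batch_shapes = [tuple(inputs.shape) for inputs, _ in loader]
print(batch_shapes)  # [(4, 10), (4, 10), (2, 10)]
```

Keeping the loader options in a dict and unpacking them with `**` mirrors the way `train_loader_kwargs` separates the data from the loading configuration.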

In [None]:
pip install speechbrain

In [2]:
import torch
import speechbrain as sb

In [3]:
class SimpleBrain(sb.Brain):

  # Takes a batch and computes the forward pass.
  def compute_forward(self, batch, stage):
    # batch[0] holds the input tensors
    return self.modules.model(batch[0])

  # Takes the predictions and the batch and returns the loss that
  # training minimizes; batch[1] holds the targets. The weight update
  # itself happens inside fit().
  def compute_objectives(self, predictions, batch, stage):
    return torch.nn.functional.l1_loss(predictions, batch[1])


In [12]:
# Toy data: 10 [input, target] pairs of random tensors, just for demonstration
data = []
for i in range(10):
  data.append([torch.rand(10, 10), torch.rand(10, 10)])



In [13]:
# Wrap the list in a DataLoader: 10 samples with batch_size=4 give 3 batches per epoch
data_new = torch.utils.data.DataLoader(data, batch_size=4, shuffle=False)
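
To see what such a loader yields, the batches can be inspected before training. The sketch below rebuilds the same toy data so that it runs on its own; `loader`, `num_batches`, and `shapes` are names chosen here for illustration. It relies on PyTorch's default collation, which stacks the first elements of all samples into one tensor and the second elements into another.

```python
import torch

# Same toy data as above: 10 [input, target] pairs of random 10x10 tensors
data = [[torch.rand(10, 10), torch.rand(10, 10)] for _ in range(10)]
loader = torch.utils.data.DataLoader(data, batch_size=4, shuffle=False)

num_batches = len(loader)  # ceil(10 / 4)
shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in loader]

print(num_batches)  # 3
for input_shape, target_shape in shapes:
    print(input_shape, target_shape)
# (4, 10, 10) (4, 10, 10)
# (4, 10, 10) (4, 10, 10)
# (2, 10, 10) (2, 10, 10)
```

The three batches per epoch correspond to the `3/3` shown in each progress bar of the training output.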

In [14]:
# Define a torch model consisting of a single linear layer
model = torch.nn.Linear(in_features=10, out_features=10)

# Instantiate the Brain class with the model and an optimizer class
brain = SimpleBrain({"model": model}, opt_class=lambda x: torch.optim.SGD(x, 0.1))

# Use the fit method to train the model.
# Alternatively, pass the raw list: brain.fit(range(10), data)
brain.fit(range(10), data_new)

100%|██████████| 3/3 [00:00<00:00, 135.48it/s, train_loss=0.607]
100%|██████████| 3/3 [00:00<00:00, 141.62it/s, train_loss=0.537]
100%|██████████| 3/3 [00:00<00:00, 194.83it/s, train_loss=0.475]
100%|██████████| 3/3 [00:00<00:00, 171.56it/s, train_loss=0.423]
100%|██████████| 3/3 [00:00<00:00, 164.90it/s, train_loss=0.381]
100%|██████████| 3/3 [00:00<00:00, 166.45it/s, train_loss=0.35]
100%|██████████| 3/3 [00:00<00:00, 188.20it/s, train_loss=0.327]
100%|██████████| 3/3 [00:00<00:00, 185.42it/s, train_loss=0.311]
100%|██████████| 3/3 [00:00<00:00, 207.32it/s, train_loss=0.299]
100%|██████████| 3/3 [00:00<00:00, 185.44it/s, train_loss=0.291]
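
For intuition, the training run above can be approximated by a hand-written PyTorch loop. This is a simplified sketch under the assumption that `fit()` iterates over the epochs and batches, calls the forward and objective methods, and steps the optimizer; the real method also handles things such as progress bars, stages, and checkpointing, which are omitted here. The seed and the `epoch_losses` list are additions for illustration.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed only to make reruns repeatable

data = [[torch.rand(10, 10), torch.rand(10, 10)] for _ in range(10)]
loader = torch.utils.data.DataLoader(data, batch_size=4, shuffle=False)

model = torch.nn.Linear(in_features=10, out_features=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_losses = []
for epoch in range(10):
    running = 0.0
    for inputs, targets in loader:
        # Plays the role of compute_forward
        predictions = model(inputs)
        # Plays the role of compute_objectives
        loss = torch.nn.functional.l1_loss(predictions, targets)
        # Weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running += loss.item()
    # Average training loss over the batches of this epoch
    epoch_losses.append(running / len(loader))

print(epoch_losses)
```

The exact loss values depend on the random data and initial weights, so they will not match the numbers in the progress bars above.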
