## What is the Brain Class?

 

1. The Brain class is the most important part of SpeechBrain.
2. It runs the training loop (iterating through the dataset and updating the model parameters) through its fit() method.
3. It abstracts away the details of the data loops.
4. To use the fit() method, the following two methods need to be defined:
> def compute_forward(self, batch, stage): computes the forward pass and generates the model predictions
>
> def compute_objectives(self, predictions, batch, stage): computes the loss from the predictions and the batch; the gradient of this loss is used to update the parameters
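
The loop that fit() abstracts away can be sketched in plain PyTorch. This is a simplified stand-in, not SpeechBrain's implementation: `simple_fit` is a hypothetical name, the two helper functions only mirror the roles of the two methods above, and validation, checkpointing and device handling are left out.

```python
import torch


def compute_forward(model, batch):
    # Forward pass: produce predictions from the batch inputs.
    return model(batch["input"])


def compute_objectives(predictions, batch):
    # Loss between predictions and targets (L1 loss is an arbitrary choice).
    return torch.nn.functional.l1_loss(predictions, batch["target"])


def simple_fit(model, optimizer, epochs, train_set):
    """Hypothetical stand-in for the kind of loop fit() runs."""
    epoch_losses = []
    for _ in epochs:
        model.train()
        total = 0.0
        for batch in train_set:
            optimizer.zero_grad()
            loss = compute_objectives(compute_forward(model, batch), batch)
            loss.backward()   # compute gradients of the loss
            optimizer.step()  # update the model parameters
            total += loss.item()
        epoch_losses.append(total / len(train_set))
    return epoch_losses


torch.manual_seed(0)
model = torch.nn.Linear(in_features=10, out_features=10)
data = [{"input": torch.rand(10, 10), "target": torch.rand(10, 10)}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = simple_fit(model, optimizer, range(10), data)
print(losses[0], losses[-1])
```

With the Brain class, only the forward pass and the loss have to be written by hand; the surrounding loop comes from fit().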

## Parameters Required to Define the Brain Class

The Brain class takes five arguments:
1. **modules**: A dictionary of torch modules (for example, the model). It is converted to a torch.nn.ModuleDict, which makes it easy to move all parameters to the same device and to call train() and eval() on them.
2. **opt_class**: The PyTorch optimizer constructor to use; it is called with the model parameters. It can be defined in the HyperPyYAML file and passed in as an argument.
3. **hparams**: The set of hyperparameters, which are defined separately.
4. **run_opts**: The execution details of training, such as the training device and distributed execution.
5. **checkpointer**: Used to save the state of training, such as the model parameters and the training progress.

Example of creating a Brain object:
> brain = SimpleBrain({"model": model}, hparams["opt_class"], hparams, run_opts={"device": device})
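
To illustrate what the modules argument amounts to, here is a minimal sketch that builds a torch.nn.ModuleDict by hand. It only mimics the behavior described above and is not SpeechBrain's internal code; the `device` selection is an assumption made for this example.

```python
import torch

model = torch.nn.Linear(in_features=10, out_features=10)

# Wrap the dictionary of modules, as described for the modules argument.
modules = torch.nn.ModuleDict({"model": model})

# Move every parameter of every module to the same device in one call.
device = "cuda" if torch.cuda.is_available() else "cpu"
modules.to(device)

# Switch all modules between evaluation and training mode together.
modules.eval()
print(modules["model"].training)  # False
modules.train()
print(modules.model.training)  # True
```

Attribute access (`modules.model`) is the form used inside compute_forward in the notebook cell further down, where it appears as `self.modules.model`.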


## Using the fit() method

The fit() method performs the training. It takes arguments such as the number of epochs, the training data, the validation data, and data loader parameters such as batch_size.
The following is an example of how to call the fit() method:

> brain.fit(range(hparams["N_epochs"]), data, train_loader_kwargs=hparams["dataloader_options"])
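
As a rough illustration of what data loader options such as batch_size control, the sketch below applies a dictionary of options to a plain PyTorch DataLoader. It assumes the options are ordinary DataLoader keyword arguments; the toy dataset and the `dataloader_options` values are hypothetical stand-ins for hparams["dataloader_options"].

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical toy dataset: 8 examples, each a dict with "input" and "target".
dataset = [{"input": torch.rand(10), "target": torch.rand(10)} for _ in range(8)]

# Stand-in for hparams["dataloader_options"].
dataloader_options = {"batch_size": 4, "shuffle": True}

loader = DataLoader(dataset, **dataloader_options)
batches = list(loader)

print(len(batches))               # 2
print(batches[0]["input"].shape)  # torch.Size([4, 10])
```

Each batch is again a dictionary with the same keys, so a compute_forward that reads batch["input"] works unchanged when the batch size changes.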


## Running an experiment
To run an experiment, activate the SpeechBrain environment and run the command below. The train file contains the Brain class (with compute_forward and compute_objectives) and the data loader, while the YAML file contains all the hyperparameters required for training the model:

> python train.py hyp.yaml



Install SpeechBrain

In [None]:
!pip install speechbrain

Import Libraries

In [None]:
import torch
import speechbrain as sb

Define the SimpleBrain class. You need to define two methods here:
1. compute_forward()
2. compute_objectives()

In [None]:
class SimpleBrain(sb.Brain):

  # This method takes the batch and computes the forward pass,
  # returning the model predictions.
  def compute_forward(self, batch, stage):
    return self.modules.model(batch["input"])

  # This method takes the predictions and the batch and returns the loss.
  # It does not update the weights itself; fit() uses the loss to do that.
  def compute_objectives(self, predictions, batch, stage):
    return torch.nn.functional.l1_loss(predictions, batch["target"])


Define the model and call brain.fit()

In [None]:
# Define a torch model consisting of a single linear layer
model = torch.nn.Linear(in_features=10, out_features=10)

# A single batch of random tensors, just for demonstration
data = [{"input": torch.rand(10, 10), "target": torch.rand(10, 10)}]

# Create the Brain object from the model and the optimizer class
# (SGD with a learning rate of 0.1)
brain = SimpleBrain({"model": model}, opt_class=lambda x: torch.optim.SGD(x, 0.1))

# Use the fit method to train the model for 10 epochs
brain.fit(range(10), data)

100%|██████████| 1/1 [00:00<00:00,  6.24it/s, train_loss=0.658]
100%|██████████| 1/1 [00:00<00:00, 253.95it/s, train_loss=0.63]
100%|██████████| 1/1 [00:00<00:00, 380.68it/s, train_loss=0.602]
100%|██████████| 1/1 [00:00<00:00, 395.95it/s, train_loss=0.576]
100%|██████████| 1/1 [00:00<00:00, 382.66it/s, train_loss=0.55]
100%|██████████| 1/1 [00:00<00:00, 168.22it/s, train_loss=0.528]
100%|██████████| 1/1 [00:00<00:00, 468.17it/s, train_loss=0.506]
100%|██████████| 1/1 [00:00<00:00, 314.75it/s, train_loss=0.485]
100%|██████████| 1/1 [00:00<00:00, 294.81it/s, train_loss=0.465]
100%|██████████| 1/1 [00:00<00:00, 346.06it/s, train_loss=0.447]


You can find a detailed description [here](https://colab.research.google.com/drive/1fdqTk4CTXNcrcSVFvaOKzRfLmj4fJfwa?usp=sharing#scrollTo=zRHI45kUzKul)