# Training a ResNet model for Channel Estimation
Now that we have a channel estimation dataset, we can train a model that takes the channel estimates at the pilot locations as input and predicts the channel estimates for the whole channel matrix.

We use the [Fireball](https://github.com/InterDigitalInc/Fireball) deep-learning package to train our models. The following diagram shows the neural network structure used for the channel estimation model.

![NN Structure](NN.png)

So let's get started by importing some Fireball modules.

In [1]:
import numpy as np
import scipy.io
import time

from fireball import Model, Block
from fireball.datasets.base import BaseDSet

## Loading dataset
We first load the dataset files generated in the [previous step](MLChestDataGen.ipynb). Then we create three [dataset objects](https://interdigitalinc.github.io/Fireball/html/source/datasets.html#fireball.datasets.base.BaseDSet) for training, validation, and testing.

In [2]:
# Read the data
trainSample, trainlabels = np.load("ChestTrain.npy")
validSample, validlabels = np.load("ChestValid.npy")
testSample, testlabels = np.load("ChestTest.npy")

trainDs = BaseDSet('train', None, trainSample, trainlabels, 32)
validDs = BaseDSet('valid', None, validSample, validlabels, 32)
testDs = BaseDSet('test', None, testSample, testlabels, 32)
BaseDSet.printDsInfo(trainDs,testDs,validDs)


BaseDSet Dataset Info:
    Number of Training Samples ..................... 16000
    Number of Test Samples ......................... 2400
    Number of Validation Samples ................... 2400
    Sample Shape ................................... (14, 612, 2)
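
The tuple unpacking in the loading cell works only if each `.npy` file holds a single array whose first axis has length 2 (samples first, then labels). The sketch below illustrates that layout with small dummy arrays. The file name `demo.npy` and the use of `np.stack` are assumptions for illustration, not necessarily how the data-generation notebook wrote its files.

```python
import os
import tempfile
import numpy as np

# Dummy data with the per-sample shape reported above: (14, 612, 2).
num_samples = 4
samples = np.zeros((num_samples, 14, 612, 2), dtype=np.float32)
labels = np.ones((num_samples, 14, 612, 2), dtype=np.float32)

# Assumption: samples and labels are stacked along a new leading axis of length 2.
path = os.path.join(tempfile.mkdtemp(), "demo.npy")   # hypothetical file name
np.save(path, np.stack([samples, labels]))

# Unpacking iterates over the leading axis, yielding the two arrays.
loaded_samples, loaded_labels = np.load(path)
print(loaded_samples.shape, loaded_labels.shape)
```

If samples and labels had different shapes, this layout would not work and the two arrays would have to be stored separately.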


## Creating the model
Now we can create a Fireball model object that will be used for training. We first define two ``ResNet`` blocks, ``RES1`` and ``RES2`` (the gray boxes in the diagram above). We then define all the layers in the ``layersInfo`` string.

In [3]:
# Defining the ResNet Blocks
blocks = [
    Block('RES1|k_kernel_ixi,o_outSizes_i*3,s_stride_ixi_1|' +             # RES1
          'add|' +
          'CONV_K1_S%s_O%o0_Pv,BN:ReLU,CONV_K%k_S1_O%o1_Ps,BN:ReLU,CONV_K1_S1_O%o2,BN;ID'),

    Block('RES2|k_kernel_ixi,o_outSizes_i*3,s_stride_ixi_1|' +             # RES2
          'add|' +
          'CONV_K1_S%s_O%o0_Pv,BN:ReLU,CONV_K%k_S1_O%o1_Ps,BN:ReLU,CONV_K1_S1_O%o2,BN;'+
          'CONV_K1_S%s_O%o2_Pv,BN') ]

# Defining the Layers of the neural network:
layersInfo = ("TENSOR_S14/612/2;" +             # Input layer
              "RES2_K11x9_O16/16/64:ReLU," +    # RES2 resnet block (9x11 kernel), ReLU activation function
              "RES1_K7x3_O16/16/64:ReLU;" +     # RES1 resnet block (3x7 kernel), ReLU activation function
              "CONV_K3_O2_Ps::L2R;" +           # Convolutional layer (3x3 kernel), L2 Regularization
              "REG_S14/612/2")                  # Output layer

# Create the model for training:
model = Model(name="ChanEst", layersInfo = layersInfo, blocks = blocks, 
              trainDs=trainDs, validationDs=validDs,
              batchSize=128, 
              numEpochs=400,
              learningRate=(0.002,0.00001),  # Learning rate starts at 0.002 decaying exponentially to 0.00001
              regFactor=0.0,
              dropOutKeep=1,
              optimizer="Adam",
              gpus="0")

model.printLayersInfo()                                           # Print layers
print("Model Complexity:",'{:,} flops'.format(model.getFlops()))  # Get Model Complexity



Scope            InShape       Comments                 OutShape      Activ.   Post Act.        # of Params
---------------  ------------  -----------------------  ------------  -------  ---------------  -----------
IN_TENSOR        14 612 2      Tensor Shape: 14x612x2   14 612 2      None                      0          
S1_L1_RES2       14 612 2      2 Paths, 8 layers        14 612 64     ReLU                      27,328     
S1_L2_RES1       14 612 64     2 Paths, 7 layers        14 612 64     ReLU                      7,904      
S2_L1_CONV       14 612 64     KSP: 3 1 s               14 612 2      None     L2               1,154      
OUT_REG          14 612 2                               14 612 2      None                      0          
---------------------------------------------------------------------------------------------------------
                                                                  Total Number of parameters: 36,386     
Model Complexity: 608,672,256 flops
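
The parameter counts in the table above can be reproduced by hand from the block definitions. The sketch below does this in plain Python. Counting four values per channel for each batch-normalization layer is an assumption on our part, chosen because it makes the totals match the printed table. The helper names are hypothetical and not part of Fireball.

```python
def conv_params(kh, kw, c_in, c_out):
    """Kernel weights plus one bias per output channel."""
    return kh * kw * c_in * c_out + c_out

def bn_params(channels):
    # Assumption: 4 values per channel (scale, offset, moving mean, moving variance).
    return 4 * channels

def bottleneck_params(kh, kw, c_in, outs, conv_shortcut):
    """1x1 conv -> (kh x kw) conv -> 1x1 conv, each followed by batch norm."""
    o0, o1, o2 = outs
    total = conv_params(1, 1, c_in, o0) + bn_params(o0)
    total += conv_params(kh, kw, o0, o1) + bn_params(o1)
    total += conv_params(1, 1, o1, o2) + bn_params(o2)
    if conv_shortcut:
        # RES2 uses a 1x1 conv + batch norm on the shortcut path; RES1 uses identity.
        total += conv_params(1, 1, c_in, o2) + bn_params(o2)
    return total

res2 = bottleneck_params(11, 9, 2, (16, 16, 64), conv_shortcut=True)
res1 = bottleneck_params(7, 3, 64, (16, 16, 64), conv_shortcut=False)
last = conv_params(3, 3, 64, 2)
print(res2, res1, last, res2 + res1 + last)   # 27328 7904 1154 36386
```

The convolutional shortcut in ``RES2`` is needed because the input has 2 channels while the main path outputs 64. ``RES1`` receives 64 channels already, so an identity shortcut is enough.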

## Training the model
**Note**: The following cell can take several hours to complete. A trained model is included in the ``Models`` directory, so you can skip the following cells and proceed to the [next step](MLChestEvaluate.ipynb).

In [4]:
# Start Training:
model.initSession()
model.train()
model.save("Models/ChEstResNet.fbm")   # Save the trained model.

+--------+---------+---------------+-----------+-------------------+
| Epoch  | Batch   | Learning Rate | Loss      | Valid/Test MSE    |
+--------+---------+---------------+-----------+-------------------+
| 1      | 124     | 0.00200000009 | 0.0638999 | 0.010    N/A      |
| 2      | 249     | 0.00200000009 | 0.0086457 | 0.010    N/A      |
| 3      | 374     | 0.00200000009 | 0.0057803 | 0.007    N/A      |
| 4      | 499     | 0.00190000003 | 0.0047276 | 0.005    N/A      |
| 5      | 624     | 0.00190000003 | 0.003938  | 0.007    N/A      |
| 6      | 749     | 0.00190000003 | 0.0035286 | 0.008    N/A      |
| 7      | 874     | 0.00190000003 | 0.0034247 | 0.007    N/A      |
| 8      | 999     | 0.00180500001 | 0.0030548 | 0.008    N/A      |
| 9      | 1124    | 0.00180500001 | 0.0029008 | 0.012    N/A      |
| 10     | 1249    | 0.00180500001 | 0.0027836 | 0.004    N/A      |
| 11     | 1374    | 0.00180500001 | 0.0028083 | 0.010    N/A      |
| 12     | 1499    | 0.00171474996
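
The learning-rate column in the log above shrinks by a factor of 0.95 roughly every 500 batches (0.002, 0.0019, 0.001805, 0.00171475). The sketch below reproduces those values with a staircase exponential schedule. The decay factor and interval are read off the log rather than taken from Fireball's documentation, and `lr_after` is a hypothetical helper.

```python
import math

def lr_after(batches_done, start=0.002, factor=0.95, interval=500):
    """Staircase exponential decay: multiply by `factor` once per `interval` batches."""
    return start * factor ** (batches_done // interval)

batches_per_epoch = math.ceil(16000 / 128)   # 125 batches per epoch
for epoch in (1, 4, 8, 12):
    done = epoch * batches_per_epoch         # batches processed by the end of the epoch
    # Columns: epoch, last batch index (as in the log), learning rate
    print(epoch, done - 1, round(lr_after(done), 8))

# Number of 0.95 decay steps needed to go from 0.002 down to 0.00001
steps_to_floor = math.log(0.00001 / 0.002) / math.log(0.95)
print(round(steps_to_floor, 1))
```

With 125 batches per epoch, 400 epochs amount to 50,000 batches, or about 100 decay intervals of 500 batches, which is in line with the number of decay steps computed at the end of the sketch.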

## Evaluating the model

In [5]:
# Evaluate the model using the test dataset
testDs.evaluateModel(model)


  Processed 2400 Sample. (Time: 2.12 Sec.)                              

NMSE: 0.326738
MSE:  0.001526
RMSE: 0.033944
MAE:  0.026846


{'mse': 0.0015264276,
 'rmse': 0.033943895,
 'mae': 0.026845945,
 'nmse': 0.32673806058297666,
 'gMse': 0.0015264275197926392,
 'gRmse': 0.03906952162226509,
 'gMae': 0.026845943238406472,
 'csvItems': ['mse', 'rmse', 'mae', 'bestMSE', 'bestEpoch', 'trainTime'],
 'bestMSE': 0.001377510605379939,
 'bestEpoch': 370,
 'trainTime': 18354.61710548401}
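
For reference, the reported error measures can be sketched with NumPy as follows. The NMSE definition used here (squared error normalized by the energy of the labels) is a common convention and an assumption; Fireball's exact formula may differ. In the results above, `gRmse` equals the square root of `gMse` (about 0.03907), while `rmse` is smaller. That would be consistent with `rmse` being averaged over batches rather than computed globally, but this reading is also an assumption.

```python
import numpy as np

def regression_metrics(pred, label):
    """Global error measures over all elements of the arrays."""
    err = pred - label
    mse = np.mean(err ** 2)
    return {
        "mse": mse,
        "rmse": np.sqrt(mse),
        "mae": np.mean(np.abs(err)),
        "nmse": np.sum(err ** 2) / np.sum(label ** 2),   # assumed definition
    }

# Tiny toy example (not data from the notebook)
label = np.array([[3.0, 4.0], [0.0, 5.0]])
pred = np.array([[3.0, 2.0], [0.0, 5.0]])
global_metrics = regression_metrics(pred, label)

# The mean of per-row RMSE values never exceeds the global RMSE,
# because the square root is concave (Jensen's inequality).
per_row_rmse = np.mean([regression_metrics(p, l)["rmse"]
                        for p, l in zip(pred, label)])
print(global_metrics, per_row_rmse)
```

MSE and MAE are plain averages, so averaging them per batch (with equal batch sizes) gives the same value as computing them globally, which matches the near-identical `mse`/`gMse` and `mae`/`gMae` entries.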