# Training the model
This notebook trains the deep neural network (`ChEstNet`) for multi-layer channel estimation.  
An **already-trained model** is provided in the `Models` directory; it was obtained through a comprehensive hyperparameter search. You can either use this notebook to train your own model, or skip it and proceed directly to evaluation with the included model. During training, the checkpoint with the lowest validation loss is saved to `Models/Trained.pth`.

In [1]:
import numpy as np
import torch
from torch.optim.lr_scheduler import ExponentialLR

import os
import matplotlib.pyplot as plt

from ChEstNet import ChEstNet, ChEstDataset

In [2]:
# Create the train / validation / test datasets.
# All three are read from `.npy` files under `dataPath` and share the same batch size.
dataPath = "/data/datasets/SelfRefine"      # Replace with the location of your data files
batchSize = 64
trainDS, validDS, testDS = (
    ChEstDataset(os.path.join(dataPath, splitFile), batchSize)
    for splitFile in ("Train.npy", "Valid.npy", "Test.npy")
)
                        

In [3]:
# Training configuration: checkpoint path, schedule length, learning-rate range and seed.
modelFileName = "Models/Trained.pth"        # The model file name after the training
numEpochs = 100
lrStart, lrEnd = 0.002, 0.00002             # Learning rate exponentially decaying from 'lrStart' to 'lrEnd'
randomSeed = 1234                           # Seed for reproducible runs; set to None for a non-deterministic run

if randomSeed is not None:
    # Seeds the NumPy and PyTorch RNGs used from this point on (model weight
    # initialization and, presumably, any shuffling done during training).
    # NOTE: the datasets above were constructed before this seed is applied.
    np.random.seed(randomSeed)
    torch.manual_seed(randomSeed)

# Prefer CUDA, then Apple MPS, then fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = ChEstNet(device)                    # Create the model object

# Ensure the checkpoint directory exists before the training loop writes into it.
os.makedirs(os.path.dirname(modelFileName) or ".", exist_ok=True)

optimizer = torch.optim.Adam(model.parameters(), lr=lrStart)
# Per-step decay factor chosen so that after numEpochs-1 scheduler steps the
# learning rate reaches lrEnd: gamma = (lrEnd/lrStart) ** (1/(numEpochs-1)).
# With a single epoch there are no decay steps to spread over, so use a constant
# rate instead of dividing by zero.
lrGamma = (lrEnd/lrStart) ** (1/(numEpochs-1)) if numEpochs > 1 else 1.0
lrScheduler = ExponentialLR(optimizer, lrGamma)
lossFunction = torch.nn.MSELoss()


In [4]:
# Main Training Loop:
# Trains for `numEpochs` epochs and evaluates on the validation set after each one.
# Whenever the validation loss improves, the parameters are saved to `modelFileName`
# (marked with "*" in the table). After training, the best checkpoint is reloaded,
# so `model` ends up holding the lowest-validation-loss weights.
# NOTE: this cell continues from the current model/optimizer/scheduler state;
# re-run the setup cell above before re-running this one.
lowestLoss, bestEpoch = None, None
print("Epoch   Learning Rate   Training Loss   Validation Loss")
print("-----   -------------   -------------   ---------------")
for epoch in range(numEpochs):
    curLr = lrScheduler.get_last_lr()[0]
    print(f" {epoch+1:-4d}     {curLr:-10f}      ", end="")

    # Train one epoch; only the mean loss is reported (min/max are ignored).
    _, lossMean, _ = model.trainEpoch(trainDS, lossFunction, optimizer)
    print(f"{lossMean:-10f}      ", end="")

    validLoss = model.evaluate(validDS, lossFunction)
    # A NaN loss (diverged epoch) compares False against everything. If it became
    # the baseline, no later epoch could ever be saved, so it is never accepted.
    isBaseline = lowestLoss is None and not np.isnan(float(validLoss))
    improved = lowestLoss is not None and validLoss < lowestLoss
    if isBaseline or improved:
        lowestLoss, bestEpoch = validLoss, epoch+1
        model.saveParams(modelFileName)     # Save the best model so far
    print(f"{validLoss:-10f} {'*' if improved else ' '} ")
    lrScheduler.step()

if bestEpoch is None:
    # Nothing was saved in this run (no epochs, or every validation loss was NaN);
    # loading would silently pick up a stale or missing file.
    raise RuntimeError(f"No checkpoint was saved during training; not loading '{modelFileName}'.")
model.loadParams(modelFileName)             # Load the best model
validLoss = lowestLoss


Epoch   Learning Rate   Training Loss   Validation Loss
-----   -------------   -------------   ---------------
    1       0.002000        0.592596        0.350749   
    2       0.001909        0.286578        0.293276 * 
    3       0.001822        0.229406        0.181759 * 
    4       0.001739        0.201478        0.238252   
    5       0.001660        0.185287        0.170843 * 
    6       0.001585        0.169945        0.162267 * 
    7       0.001513        0.161321        0.140687 * 
    8       0.001444        0.153775        0.134481 * 
    9       0.001379        0.147432        0.120986 * 
   10       0.001316        0.141432        0.117913 * 
   11       0.001256        0.137113        0.124352   
   12       0.001199        0.132559        0.126296   
   13       0.001144        0.129277        0.114175 * 
   14       0.001092        0.126393        0.104231 * 
   15       0.001043        0.122861        0.099872 * 
   16       0.000995        0.119728        0.10