### Training and Validation Loops

**Date:** 30/10/2021  
**Author:** Murad Popattia

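The following example illustrates argument unpacking with the asterisk operators: `*` unpacks a list or tuple into positional arguments, and `**` unpacks a dictionary into keyword arguments.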

```
>>> def foo(x, y, z):
...     print(f"x={x}")
...     print(f"y={y}")
...     print(f"z={z}")
...
>>> mylist = [1,2,3]
>>> foo(*mylist)
x=1
y=2
z=3

>>> mydict = {'x':1,'y':2,'z':3}
>>> foo(**mydict)
x=1
y=2
z=3

>>> mytuple = (1, 2, 3)
>>> foo(*mytuple)
x=1
y=2
z=3
```

In [12]:
# sample training step
def train_one_step(model, data, optimizer, device="cuda"):
    """Run a single optimization step on one batch.

    Args:
        model: Callable accepting keyword arguments ``x`` and ``y`` and
            returning a loss object that supports ``.backward()``.
        data: Mutable mapping for the batch; must contain the keys ``"x"``
            and ``"y"``. Every value must provide ``.to(device)``. The mapping
            is updated in place with the moved values.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        device: Target device the batch values are moved to
            (default ``"cuda"``).

    Returns:
        The loss returned by ``model`` for this batch.
    """
    # PyTorch accumulates gradients across backward() calls, so clear them
    # before computing this step's gradients.
    optimizer.zero_grad()

    # Move every value of the batch to the target device. The comprehension
    # builds the new mapping first and ``update`` writes it back afterwards,
    # so ``data`` is never mutated while it is being iterated.
    data.update({key: value.to(device) for key, value in data.items()})

    # If the model's forward arguments have exactly the same names as the
    # batch keys, this is equivalent to: loss = model(**data)
    loss = model(x=data["x"], y=data["y"])
    loss.backward()
    optimizer.step()

    return loss

# sample validation step
def valid_one_step(model, data, optimizer, device="cuda"):
    """Compute the loss on one validation batch without updating weights.

    Gradient tracking is NOT disabled here; callers should invoke this
    inside ``torch.no_grad()`` to avoid building an autograd graph.

    Args:
        model: Callable accepting keyword arguments ``x`` and ``y`` and
            returning a loss.
        data: Mutable mapping for the batch; must contain the keys ``"x"``
            and ``"y"``. Every value must provide ``.to(device)``. The mapping
            is updated in place with the moved values.
        optimizer: Unused; accepted so the signature mirrors
            ``train_one_step``.
        device: Target device the batch values are moved to
            (default ``"cuda"``).

    Returns:
        The loss returned by ``model`` for this batch.
    """
    # Build the moved mapping first, then write it back, so ``data`` is
    # never mutated while it is being iterated.
    data.update({key: value.to(device) for key, value in data.items()})
    loss = model(x=data["x"], y=data["y"])

    return loss

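`model.train()` tells your model that it is being trained. Layers such as dropout and batch normalization, which behave differently during training and evaluation, can then act accordingly. For example, dropout randomly zeroes activations only in training mode. Batch normalization normalizes with per-batch statistics (and updates its running statistics) only in training mode.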

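To switch the model to evaluation (testing) mode, call either `model.eval()` or `model.train(mode=False)`; the two are equivalent. Note that evaluation mode does not disable gradient tracking. Wrap inference in `torch.no_grad()` for that.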

In [11]:
def train_one_epoch(model, data_loader, optimizer, scheduler=None):
    """Train ``model`` for one full pass over ``data_loader``.

    Args:
        model: Model object exposing ``train()``; it is switched to training
            mode here and each batch is handed to ``train_one_step``.
        data_loader: Iterable of batches accepted by ``train_one_step``.
        optimizer: Optimizer passed through to ``train_one_step``.
        scheduler: Optional learning-rate scheduler. When given,
            ``scheduler.step()`` is called after every batch (suitable for
            per-iteration schedulers). For a per-epoch schedule, leave this as
            ``None`` and call ``scheduler.step()`` once after this function
            returns.

    Returns:
        Sum of the per-batch losses (detached from the autograd graph), or
        ``0`` if ``data_loader`` yields no batches.
    """
    # Enable training-mode behavior for layers such as dropout and batch norm.
    model.train()
    total_loss = 0

    for data in data_loader:
        loss = train_one_step(model, data, optimizer)
        # Per-batch scheduler step, only when a scheduler was supplied.
        if scheduler is not None:
            scheduler.step()
        # Detach before accumulating: summing the raw loss would keep every
        # batch's autograd graph alive until the epoch ends.
        total_loss += loss.detach()
    return total_loss

def valid_one_epoch(model, data_loader, optimizer):
    """Evaluate ``model`` on every batch of ``data_loader``.

    The model is left in evaluation mode when this function returns.

    Args:
        model: Model object exposing ``eval()``; each batch is handed to
            ``valid_one_step``.
        data_loader: Iterable of batches accepted by ``valid_one_step``.
        optimizer: Passed through to ``valid_one_step`` (which does not
            use it).

    Returns:
        Sum of the per-batch validation losses, or ``0`` if ``data_loader``
        yields no batches.
    """
    # Switch the model to evaluation mode (e.g. dropout off, batch norm uses
    # its running statistics).
    model.eval()
    total_loss = 0

    # Disable gradient tracking once for the whole pass: no autograd graph is
    # built, which saves memory and compute during evaluation.
    with torch.no_grad():
        for data in data_loader:
            loss = valid_one_step(model, data, optimizer)
            total_loss += loss
    return total_loss