## Model & Training

A cheat sheet of model- and training-related PyTorch utilities.

### 1. Torch model
The functions and classes in this cheat sheet all operate on a model, so we first build a simple torch model to use in the later examples.

In [10]:
import torch
import torch.nn as nn

class WzzModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Two stacked linear layers; the first one has no bias term.
        self.input_layer = nn.Linear(config['input_size'], config['hidden_size'], bias=False)
        self.output_layer = nn.Linear(config['hidden_size'], config['output_size'])

    def forward(self, feat):
        output = self.input_layer(feat)
        output = self.output_layer(output)
        return output

config = dict(input_size=3, hidden_size=5, output_size=1)
model = WzzModel(config)

feat = torch.rand(10,3)
output = model(feat)
output.shape

torch.Size([10, 1])

### 2. clip_grad_norm_()

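Clip the gradient norm to guard against exploding gradients. Call it after `loss.backward()` and before `optimizer.step()`: it rescales the gradients in place so that their total norm is at most the given maximum, and returns the total norm measured before clipping. In the cell below no backward pass has been run, so there are no gradients yet and the returned norm is `tensor(0.)`.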

In [11]:
import torch.nn as nn

# loss.backward()   # compute the gradients first
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
# optimizer.step()  # then update with the clipped gradients

tensor(0.)
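
To see the clipping take effect, gradients have to exist first. The following is a minimal sketch of one training step, assuming a stand-in `nn.Linear` layer, an SGD optimizer, and artificially large inputs (none of which come from the model above); it compares the total gradient norm before and after clipping.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 1)                      # stand-in model, for illustration only
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

feat = torch.rand(10, 3) * 100               # large inputs to provoke large gradients
target = torch.zeros(10, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(layer(feat), target)
loss.backward()                              # gradients now exist

max_norm = 1.0
# The return value is the total gradient norm measured before clipping.
norm_before = nn.utils.clip_grad_norm_(layer.parameters(), max_norm)

# Recompute the total 2-norm from the gradients as they are after clipping.
norm_after = torch.norm(
    torch.stack([p.grad.norm(2) for p in layer.parameters()]), 2
)

optimizer.step()
print(norm_before.item(), norm_after.item())
```

After the call, `norm_after` should not exceed `max_norm`, while `norm_before` reports how large the gradients were originally.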

### 3. torch.save

Save the parameters of a trained model to a local file.
- `model.state_dict()`: saves only the parameters.
- Passing `model` itself instead of `model.state_dict()` saves the whole model, at the cost of more storage.

In [12]:
import torch

model_path = "./model.pkl"
torch.save(model.state_dict(), model_path)
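
To make "only the parameters" concrete, the sketch below inspects a state dict before saving it. It assumes a stand-in `nn.Linear` layer and a temporary directory rather than the model and path used above.

```python
import os
import tempfile

import torch
import torch.nn as nn

layer = nn.Linear(3, 5, bias=False)          # stand-in for a trained model
state = layer.state_dict()

# A state dict maps each parameter name to its tensor.
shapes = {name: tuple(tensor.shape) for name, tensor in state.items()}
print(shapes)

# Save into a temporary directory so the example leaves no files behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "layer.pkl")
    torch.save(state, path)
    size_bytes = os.path.getsize(path)
    print(size_bytes)
```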

### 4. load_state_dict() & torch.load()

Load trained parameters.
- `torch.load()`: loads the parameters or the model from a local file.
- `model.load_state_dict()`: assigns the loaded parameters to the model.
- Remark: check what was saved in the file (a state dict or a whole model), since `load_state_dict()` expects a state dict.

In [13]:
import torch

model_path = "./model.pkl"
params = torch.load(model_path, map_location='cpu')
model.load_state_dict(params, strict=True)

<All keys matched successfully>
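
With `strict=True`, the saved keys must match the model's keys exactly. As a rough sketch of the alternative, `strict=False` loads whatever matches and reports the rest. The `TwoLayer` class and the hand-built partial state dict below are hypothetical stand-ins, not part of the notebook above.

```python
import torch
import torch.nn as nn


class TwoLayer(nn.Module):  # hypothetical model mirroring the layer names used above
    def __init__(self):
        super().__init__()
        self.input_layer = nn.Linear(3, 5, bias=False)
        self.output_layer = nn.Linear(5, 1)

    def forward(self, feat):
        return self.output_layer(self.input_layer(feat))


model = TwoLayer()

# A state dict that covers only the first layer.
partial = {"input_layer.weight": torch.zeros(5, 3)}

# strict=False copies the matching entries and reports the mismatches
# instead of raising an error.
result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)      # parameters the model has but the dict lacks
print(result.unexpected_keys)   # entries in the dict the model does not have
```

This can be handy when only part of a network is restored from a checkpoint, but it also hides genuine key mismatches, so checking `missing_keys` and `unexpected_keys` afterwards is a sensible habit.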

### 5. argparse

Parse hyperparameters and boolean flags from the command line.

In [None]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-lr', '--learning_rate', type=float, default=1e-5)
parser.add_argument('--with-pretrained', dest="pretrained", action='store_true')
parser.add_argument('--no-pretrained', dest="pretrained", action='store_false')
parser.set_defaults(pretrained=False)
args = parser.parse_args()
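
Called without arguments, `parse_args()` reads `sys.argv`, which inside a notebook holds the kernel's own arguments and typically makes the call fail. A minimal sketch for trying the parser out is to pass an explicit argument list; the values below are made up for illustration.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-lr', '--learning_rate', type=float, default=1e-5)
parser.add_argument('--with-pretrained', dest="pretrained", action='store_true')
parser.add_argument('--no-pretrained', dest="pretrained", action='store_false')
parser.set_defaults(pretrained=False)

# Pass the arguments explicitly instead of relying on sys.argv.
args = parser.parse_args(['-lr', '1e-3', '--with-pretrained'])
print(args.learning_rate, args.pretrained)

# An empty list falls back to the defaults.
defaults = parser.parse_args([])
print(defaults.learning_rate, defaults.pretrained)
```

Both flags write to the same `pretrained` destination, so whichever appears last on the command line wins.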