pip install .
or
pip install -e .
TODO: Replace the content of the `__init__` method.
from torchT import TemplateModel
class Model(TemplateModel):
    """User model built on ``TemplateModel``.

    TODO: replace every ``None`` placeholder in ``__init__`` with a real
    object before training; ``check_init()`` is called at the end of
    ``__init__`` to check that all members have been initialized.
    """

    def __init__(self, args=None):
        """Initialize the members that ``TemplateModel`` expects.

        Args:
            args: Optional run configuration (e.g. parsed command-line
                arguments). It is stored on ``self.args``; it defaults to
                ``None`` so ``Model()`` can be constructed without one.
        """
        # ============== necessary ===============
        self.writer = None
        # Progress counters start at zero for a fresh run.
        self.step = 0
        self.epoch = 0
        # Start at +infinity so any finite error is lower; presumably compared
        # against eval errors to decide when to save the best model — confirm.
        self.best_error = float('Inf')
        self.model = None
        self.optimizer = None
        self.criterion = None
        self.metric = None
        self.train_loader = None
        self.test_loader = None
        self.device = None
        self.ckpt_dir = None
        self.log_per_step = None
        # ============== not necessary ===============
        self.train_logger = None
        self.eval_logger = None
        # Keep the configuration instead of discarding it.
        self.args = args
        # call it to check all members have been initialized
        self.check_init()
Then all you need is a short training loop like this:
# Minimal training loop. `args`, `n_epochs` and `eval_per_epoch` are assumed
# to be defined by the caller; `args` is passed because Model.__init__ takes it.
model = Model(args)
for epoch in range(n_epochs):
    # Presumably one call == one training epoch — confirm against TemplateModel.
    model.train()
    # Evaluate every `eval_per_epoch` epochs (epochs counted from 1).
    if (epoch + 1) % eval_per_epoch == 0:
        model.eval()
print('Done!!!')
Resuming training is just as convenient: load the saved model before entering the loop.
# Resume training from a checkpoint. `args`, `model_path`, `n_epochs` and
# `eval_per_epoch` are assumed to be defined by the caller.
model = Model(args)
if model_path:
    model.load_state(model_path)
for _ in range(n_epochs):
    model.train()
    # Use the model's own epoch counter (presumably restored by load_state and
    # advanced by train() — confirm), so the evaluation schedule continues
    # from the checkpoint instead of restarting at 0.
    if model.epoch % eval_per_epoch == 0:
        model.eval()
Write your own `train_loss()` and `eval_error()` member methods.
Default methods:
def train_loss(self, batch):
    """Default training loss for a single batch.

    Args:
        batch: An ``(inputs, targets)`` pair yielded by the training
            dataloader.

    Returns:
        ``(loss, None)``: the criterion value for this batch, plus a second
        slot that a custom implementation can fill with data for
        ``train_logger``.
    """
    inputs, targets = batch
    # Move both tensors to the model's device before the forward pass.
    inputs, targets = inputs.to(self.device), targets.to(self.device)
    return self.criterion(self.model(inputs), targets), None
def eval_error(self):
    """Default evaluation: apply ``self.metric`` to the whole test set.

    Runs ``self.model`` on every batch of ``self.test_loader`` with gradient
    tracking disabled. It collects predictions and targets on the CPU,
    concatenates them along dim 0, and calls ``self.metric(preds, ys)`` once
    on the full dataset.

    Returns:
        ``(error, None)``: the metric value, plus a second slot that a custom
        implementation can fill with data for ``eval_logger``.
    """
    ys, preds = [], []
    # Evaluation needs no gradients. Without no_grad, every stored prediction
    # would keep its autograd graph alive until the loop finishes.
    with torch.no_grad():
        for x, y in self.test_loader:
            x = x.to(self.device)
            y = y.to(self.device)
            pred = self.model(x)
            ys.append(y.cpu())
            preds.append(pred.cpu())
    ys = torch.cat(ys, dim=0)
    preds = torch.cat(preds, dim=0)
    error = self.metric(preds, ys)
    return error, None
How to write your own methods:
- `train_loss` receives a `batch` from the dataloader as input and returns `loss` and `others`, where `others` can be used as input for `train_logger`.
- `eval_error` returns the `error` on the whole test dataset and `others`, where `others` can be used as input for `eval_logger`.
You can refer to the source code for more details.
-
LeNet: Train a LeNet to classify MNIST handwritten digits.
-
Training procedure:
......
epoch 1 step 3400 loss 0.0434
epoch 1 step 3500 loss 0.0331
epoch 1 step 3600 loss 0.00188
epoch 1 step 3700 loss 0.00341
save model at ../models\best.pth.tar
save model at ../models\1.pth.tar
epoch 1 error 0.0237
epoch 2 step 3800 loss 0.0201
epoch 2 step 3900 loss 0.00523
epoch 2 step 4000 loss 0.0236
......
-
Use tensorboard to visualize the result:
tensorboard --logdir example/LeNet/log
train_loss eval_error -
Resume
load model from checkpoint/9.pth.tar
epoch 10 step 33800 loss 0.000128
epoch 10 step 33900 loss 6.64e-06
epoch 10 step 34000 loss 0.000613
epoch 10 step 34100 loss 2.41e-05
......
-