-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
37 lines (26 loc) · 1.02 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
import torch
class VanillaTrainer:
    """Run single-pass training over a data loader with gradient clipping.

    The trainer only stores the collaborators it is given; all work happens
    in :meth:`train`.
    """

    def __init__(self, args, train_loader, device, optimizer, scheduler, criterion):
        """Store the training collaborators.

        Args:
            args: Configuration object; ``train`` reads ``args.clip`` as the
                maximum gradient norm.
            train_loader: Iterable of batches. Each batch is indexed with the
                keys ``'pose'`` and ``'label'``.
            device: Target passed to ``.to(...)`` for every batch.
            optimizer: Object providing ``zero_grad()`` and ``step()``.
            scheduler: Object providing ``step()``.
            criterion: Callable ``criterion(features, labels)`` returning a
                loss that supports ``backward()`` and ``item()``.
        """
        self.args = args
        self.train_loader = train_loader
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion

    def train(self, model):
        """Train ``model`` for one pass over ``self.train_loader``.

        Args:
            model: Module to optimise; it is switched to training mode and
                called as ``model(poses)``.

        Returns:
            The sum (not the mean) of the per-batch losses, rounded to four
            decimal places with ``np.round``.
        """
        model.train()
        total_train_loss = 0

        for batch in self.train_loader:
            poses = batch['pose'].to(self.device)
            # Only the first entry of the 'label' field is used.
            # NOTE(review): presumably the loader wraps labels in an extra
            # leading container/dimension -- confirm against the dataset.
            labels = batch['label'][0].to(self.device)

            self.optimizer.zero_grad()
            features = model(poses)
            loss = self.criterion(features, labels)
            loss.backward()

            # Clip after backward() and before step() so the update uses the
            # bounded gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.clip)
            self.optimizer.step()

            total_train_loss += loss.item()

        # NOTE(review): the scheduler is advanced once per call to train(),
        # i.e. after the whole pass. If the configured scheduler is meant to
        # be stepped per batch, move this call into the loop.
        self.scheduler.step()

        train_loss = np.round(total_train_loss, 4)
        print(f"TRAIN LOSS: {train_loss}")
        return train_loss