In [1]:
# Automatically re-import edited local modules (feedforwardnet, data, jacobian)
# before each cell runs, so changes to them are picked up without a restart.
%load_ext autoreload
%autoreload 2 

In [3]:
import torch
import torch.optim as optim
from torch.utils import data
import numpy as np
from feedforwardnet import FNet
from data import dataGen
from tqdm import tqdm
from jacobian import JacobianReg
import time
import crocoddyl
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
# Build the training set: `positions` are the network inputs and `costs`
# the regression targets (both are fed to TensorDataset below).
positions, costs = dataGen(size=30000)

In [5]:
# Torch dataloader over the in-memory (positions, costs) tensors.
# num_workers=0: the data already lives in memory, and with the default
# persistent_workers=False a multi-worker loader starts and tears down fresh
# worker processes at the start of every epoch. The training cell runs 12000
# epochs, so that spawning dominates run time. Batch contents and order are
# unchanged (no shuffling either way).
# NOTE(review): shuffle defaults to False, so every epoch sees the same batches
# in the same order; consider shuffle=True.
dataset = torch.utils.data.TensorDataset(positions, costs)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10000, num_workers=0)

In [6]:
# Feed-forward network sized from the data: input width taken from
# `positions`, output width from `costs`, with 256 hidden units.
net = FNet(
    input_features=positions.shape[1],
    output_features=costs.shape[1],
    n_hiddenunits=256,
)

In [7]:
# Cast floating-point parameters to float32 and switch to training mode.
# Both calls return the module itself, so they chain; the bare `net` below
# displays the model summary as the cell output.
net = net.float().train()
net

FNet(
  (fc1): Linear(in_features=3, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=3, bias=True)
  (fc3): Linear(in_features=3, out_features=1, bias=True)
)

In [8]:
# Data-fit term: squared error summed (not averaged) over each batch.
criterion = torch.nn.MSELoss(reduction='sum')
# Adam with weight_decay (an L2 penalty added to the gradients).
optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=0.1)

# Jacobian regularization: JacobianReg is called on (inputs, outputs) during
# training and its value is scaled by lambda_JR in the total loss.
lambda_JR = 0.01
reg = JacobianReg()


In [None]:
# Train the network on MSE + lambda_JR * Jacobian penalty.
# num_epochs used to appear only in the log message and was never defined.
num_epochs = 12000
# Log once every `log_every` epochs. The previous per-batch check
# `(i+1) % 100 == 0` could never fire: with 30000 samples and
# batch_size=10000 there are only 3 batches per epoch.
log_every = 100

t0 = time.perf_counter()
for epoch in tqdm(range(num_epochs)):
    epoch_loss = 0.0
    # `inputs`/`targets` instead of `data`/`target`: a loop variable named
    # `data` would shadow the `torch.utils.data` module imported as `data`.
    for inputs, targets in dataloader:
        # Inputs must require grad so autograd can form d(outputs)/d(inputs)
        # for the Jacobian penalty.
        inputs.requires_grad_(True)

        # Forward pass
        outputs = net(inputs)
        loss1 = criterion(outputs, targets)
        loss2 = reg(inputs, outputs)          # Jacobian regularization
        loss = loss1 + lambda_JR * loss2      # full loss

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    if (epoch + 1) % log_every == 0:
        # Reported loss is summed over all batches of this epoch.
        # tqdm.write prints above the progress bar instead of corrupting it.
        tqdm.write("Epoch [{}/{}], Loss: {:.4f}".format(epoch + 1, num_epochs, epoch_loss))

print('Training lasted = %.0f seconds' % (time.perf_counter() - t0))






  0%|          | 56/12000 [00:14<49:55,  3.99it/s]

In [9]:
# Persist the whole pickled module (not just its state_dict). Loading it back
# requires FNet to be importable; recent PyTorch versions also need
# torch.load(..., weights_only=False) for full-module checkpoints.
torch.save(obj=net, f='fnet.pth')