In [1]:
%load_ext autoreload
%autoreload 2 

In [2]:
import torch
import numpy as np
import torch.optim as optim
from torch.utils import data
from sqnet import SQNet
from data import dataGen
from tqdm import tqdm
from jacobian import JacobianReg
import time
import crocoddyl


In [3]:
# Tensor data for training
positions, costs, grad1, _ = dataGen(size = 1000)

In [4]:
# Torch dataloader
dataset = torch.utils.data.TensorDataset(positions,costs, grad1)
dataloader = torch.utils.data.DataLoader(dataset, batch_size = 1000) 

In [5]:
# Generate a Neural Net
net = SQNet(input_features = positions.shape[1], 
             output_features = costs.shape[1],
             n_hiddenunits = 16)

In [6]:
# Initialize loss and optimizer
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.95)
lam1 = 0.01


In [7]:
for epoch in tqdm(range(5000)):
    epoch_mse0 = 0.0
    epoch_mse1 = 0.0
    for i, (input, target0, target1) in enumerate(dataloader):  
        
        net.eval()

        output0 = net(input)               
        output1 = net.jacobian(input)

        net.train()
        mse0 = criterion(output0, target0)
        mse1 = criterion(output1, target1)
        

        epoch_mse0 += mse0.item() * len(input)
        epoch_mse1 += mse1.item() * len(input)
        
        loss = mse0 + lam1 * mse1 

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    scheduler.step()
 
    epoch_mse0 /= 3000
    epoch_mse1 /= 3000
    
    epoch_loss = epoch_mse0 + lam1*epoch_mse1 
    print('epoch', epoch,
      'lr', '{:.7f}'.format(optimizer.param_groups[0]['lr']),
      'mse0', '{:.5f}'.format(epoch_mse0),
      'mse1', '{:.5f}'.format(epoch_mse1),
      'Total Loss', '{:.5f}'.format(epoch_loss))
  




  0%|          | 1/5000 [00:00<40:37,  2.05it/s]

epoch 0 lr 0.0010000 mse0 318412.29167 mse1 335365.77083 Total Loss 321765.94938


  0%|          | 2/5000 [00:00<40:24,  2.06it/s]

epoch 1 lr 0.0010000 mse0 318316.16667 mse1 335351.77083 Total Loss 321669.68438


  0%|          | 3/5000 [00:01<40:15,  2.07it/s]

epoch 2 lr 0.0010000 mse0 318218.10417 mse1 335337.50000 Total Loss 321571.47917


  0%|          | 4/5000 [00:01<40:09,  2.07it/s]

epoch 3 lr 0.0010000 mse0 318118.00000 mse1 335322.91667 Total Loss 321471.22917


  0%|          | 5/5000 [00:02<40:56,  2.03it/s]

epoch 4 lr 0.0010000 mse0 318015.72917 mse1 335307.91667 Total Loss 321368.80833


  0%|          | 6/5000 [00:02<39:42,  2.10it/s]

epoch 5 lr 0.0010000 mse0 317911.39583 mse1 335292.64583 Total Loss 321264.32229


  0%|          | 7/5000 [00:03<39:34,  2.10it/s]

epoch 6 lr 0.0010000 mse0 317804.77083 mse1 335277.00000 Total Loss 321157.54083


  0%|          | 8/5000 [00:03<38:30,  2.16it/s]

epoch 7 lr 0.0010000 mse0 317695.83333 mse1 335261.02083 Total Loss 321048.44354


  0%|          | 9/5000 [00:04<39:40,  2.10it/s]

epoch 8 lr 0.0010000 mse0 317584.50000 mse1 335244.66667 Total Loss 320936.94667


  0%|          | 10/5000 [00:04<39:29,  2.11it/s]

epoch 9 lr 0.0010000 mse0 317470.66667 mse1 335228.02083 Total Loss 320822.94688


  0%|          | 11/5000 [00:05<39:57,  2.08it/s]

epoch 10 lr 0.0010000 mse0 317354.20833 mse1 335210.91667 Total Loss 320706.31750


  0%|          | 12/5000 [00:05<38:33,  2.16it/s]

epoch 11 lr 0.0010000 mse0 317235.00000 mse1 335193.29167 Total Loss 320586.93292


  0%|          | 13/5000 [00:06<38:17,  2.17it/s]

epoch 12 lr 0.0010000 mse0 317113.02083 mse1 335175.22917 Total Loss 320464.77313


  0%|          | 14/5000 [00:06<38:28,  2.16it/s]

epoch 13 lr 0.0010000 mse0 316988.12500 mse1 335156.60417 Total Loss 320339.69104


  0%|          | 15/5000 [00:07<39:36,  2.10it/s]

epoch 14 lr 0.0010000 mse0 316860.10417 mse1 335137.56250 Total Loss 320211.47979


  0%|          | 16/5000 [00:07<38:32,  2.16it/s]

epoch 15 lr 0.0010000 mse0 316728.81250 mse1 335118.02083 Total Loss 320079.99271


  0%|          | 17/5000 [00:07<37:31,  2.21it/s]

epoch 16 lr 0.0010000 mse0 316594.27083 mse1 335097.89583 Total Loss 319945.24979


  0%|          | 18/5000 [00:08<41:29,  2.00it/s]

epoch 17 lr 0.0010000 mse0 316456.27083 mse1 335077.27083 Total Loss 319807.04354


  0%|          | 19/5000 [00:09<41:16,  2.01it/s]

epoch 18 lr 0.0010000 mse0 316314.66667 mse1 335056.04167 Total Loss 319665.22708


  0%|          | 20/5000 [00:09<40:59,  2.02it/s]

epoch 19 lr 0.0010000 mse0 316169.37500 mse1 335034.37500 Total Loss 319519.71875


  0%|          | 21/5000 [00:10<40:25,  2.05it/s]

epoch 20 lr 0.0010000 mse0 316020.10417 mse1 335012.04167 Total Loss 319370.22458


  0%|          | 22/5000 [00:10<39:51,  2.08it/s]

epoch 21 lr 0.0010000 mse0 315866.93750 mse1 334989.29167 Total Loss 319216.83042


  0%|          | 23/5000 [00:10<38:44,  2.14it/s]

epoch 22 lr 0.0010000 mse0 315709.62500 mse1 334965.87500 Total Loss 319059.28375


  0%|          | 24/5000 [00:11<38:33,  2.15it/s]

epoch 23 lr 0.0010000 mse0 315548.00000 mse1 334941.79167 Total Loss 318897.41792


  0%|          | 25/5000 [00:11<39:48,  2.08it/s]

epoch 24 lr 0.0010000 mse0 315381.97917 mse1 334917.31250 Total Loss 318731.15229


  1%|          | 26/5000 [00:12<39:40,  2.09it/s]

epoch 25 lr 0.0010000 mse0 315211.33333 mse1 334892.16667 Total Loss 318560.25500


  1%|          | 27/5000 [00:12<39:20,  2.11it/s]

epoch 26 lr 0.0010000 mse0 315035.95833 mse1 334866.37500 Total Loss 318384.62208


  1%|          | 28/5000 [00:13<38:56,  2.13it/s]

epoch 27 lr 0.0010000 mse0 314855.64583 mse1 334840.00000 Total Loss 318204.04583


  1%|          | 29/5000 [00:13<38:55,  2.13it/s]

epoch 28 lr 0.0010000 mse0 314670.31250 mse1 334813.10417 Total Loss 318018.44354


  1%|          | 30/5000 [00:14<38:59,  2.12it/s]

epoch 29 lr 0.0010000 mse0 314479.77083 mse1 334785.58333 Total Loss 317827.62667


  1%|          | 31/5000 [00:14<38:48,  2.13it/s]

epoch 30 lr 0.0010000 mse0 314283.81250 mse1 334757.39583 Total Loss 317631.38646


  1%|          | 32/5000 [00:15<39:09,  2.11it/s]

epoch 31 lr 0.0010000 mse0 314082.29167 mse1 334728.66667 Total Loss 317429.57833


  1%|          | 33/5000 [00:15<39:30,  2.10it/s]

epoch 32 lr 0.0010000 mse0 313875.12500 mse1 334699.35417 Total Loss 317222.11854


  1%|          | 34/5000 [00:16<38:44,  2.14it/s]

epoch 33 lr 0.0010000 mse0 313662.08333 mse1 334669.37500 Total Loss 317008.77708


  1%|          | 35/5000 [00:16<38:36,  2.14it/s]

epoch 34 lr 0.0010000 mse0 313443.06250 mse1 334638.91667 Total Loss 316789.45167


  1%|          | 36/5000 [00:17<38:31,  2.15it/s]

epoch 35 lr 0.0010000 mse0 313217.83333 mse1 334607.87500 Total Loss 316563.91208


  1%|          | 37/5000 [00:17<38:24,  2.15it/s]

epoch 36 lr 0.0010000 mse0 312986.27083 mse1 334576.16667 Total Loss 316332.03250


  1%|          | 38/5000 [00:18<38:58,  2.12it/s]

epoch 37 lr 0.0010000 mse0 312748.22917 mse1 334543.97917 Total Loss 316093.66896


  1%|          | 39/5000 [00:18<37:08,  2.23it/s]

epoch 38 lr 0.0010000 mse0 312503.52083 mse1 334511.33333 Total Loss 315848.63417


  1%|          | 40/5000 [00:18<36:08,  2.29it/s]

epoch 39 lr 0.0010000 mse0 312252.08333 mse1 334478.16667 Total Loss 315596.86500


  1%|          | 41/5000 [00:19<37:00,  2.23it/s]

epoch 40 lr 0.0010000 mse0 311993.66667 mse1 334444.47917 Total Loss 315338.11146


  1%|          | 42/5000 [00:19<36:54,  2.24it/s]

epoch 41 lr 0.0010000 mse0 311728.22917 mse1 334410.37500 Total Loss 315072.33292


  1%|          | 43/5000 [00:20<37:53,  2.18it/s]

epoch 42 lr 0.0010000 mse0 311455.54167 mse1 334375.95833 Total Loss 314799.30125


  1%|          | 44/5000 [00:20<37:57,  2.18it/s]

epoch 43 lr 0.0010000 mse0 311175.50000 mse1 334341.14583 Total Loss 314518.91146


  1%|          | 45/5000 [00:21<38:18,  2.16it/s]

epoch 44 lr 0.0010000 mse0 310888.02083 mse1 334305.95833 Total Loss 314231.08042


  1%|          | 46/5000 [00:21<38:32,  2.14it/s]

epoch 45 lr 0.0010000 mse0 310592.91667 mse1 334270.58333 Total Loss 313935.62250


  1%|          | 47/5000 [00:22<38:52,  2.12it/s]

epoch 46 lr 0.0010000 mse0 310290.10417 mse1 334235.04167 Total Loss 313632.45458


  1%|          | 48/5000 [00:22<39:10,  2.11it/s]

epoch 47 lr 0.0010000 mse0 309979.43750 mse1 334199.18750 Total Loss 313321.42938


  1%|          | 49/5000 [00:23<39:19,  2.10it/s]

epoch 48 lr 0.0010000 mse0 309660.72917 mse1 334163.33333 Total Loss 313002.36250


  1%|          | 50/5000 [00:23<37:07,  2.22it/s]

epoch 49 lr 0.0010000 mse0 309334.04167 mse1 334127.37500 Total Loss 312675.31542


  1%|          | 51/5000 [00:23<38:15,  2.16it/s]

epoch 50 lr 0.0010000 mse0 308999.12500 mse1 334091.50000 Total Loss 312340.04000


  1%|          | 52/5000 [00:24<38:56,  2.12it/s]

epoch 51 lr 0.0010000 mse0 308655.93750 mse1 334055.60417 Total Loss 311996.49354


  1%|          | 53/5000 [00:24<38:51,  2.12it/s]

epoch 52 lr 0.0010000 mse0 308304.35417 mse1 334019.75000 Total Loss 311644.55167


  1%|          | 54/5000 [00:25<39:03,  2.11it/s]

epoch 53 lr 0.0010000 mse0 307944.25000 mse1 333984.12500 Total Loss 311284.09125


  1%|          | 55/5000 [00:25<39:51,  2.07it/s]

epoch 54 lr 0.0010000 mse0 307575.64583 mse1 333948.64583 Total Loss 310915.13229


  1%|          | 56/5000 [00:26<39:47,  2.07it/s]

epoch 55 lr 0.0010000 mse0 307198.33333 mse1 333913.43750 Total Loss 310537.46771


  1%|          | 57/5000 [00:26<38:31,  2.14it/s]

epoch 56 lr 0.0010000 mse0 306812.22917 mse1 333878.54167 Total Loss 310151.01458


  1%|          | 58/5000 [00:27<38:41,  2.13it/s]

epoch 57 lr 0.0010000 mse0 306417.29167 mse1 333843.91667 Total Loss 309755.73083


  1%|          | 59/5000 [00:27<39:40,  2.08it/s]

epoch 58 lr 0.0010000 mse0 306013.45833 mse1 333809.83333 Total Loss 309351.55667


  1%|          | 60/5000 [00:28<39:24,  2.09it/s]

epoch 59 lr 0.0010000 mse0 305600.66667 mse1 333776.12500 Total Loss 308938.42792


  1%|          | 61/5000 [00:28<38:56,  2.11it/s]

epoch 60 lr 0.0010000 mse0 305178.77083 mse1 333742.89583 Total Loss 308516.19979


  1%|          | 62/5000 [00:29<39:07,  2.10it/s]

epoch 61 lr 0.0010000 mse0 304747.77083 mse1 333710.25000 Total Loss 308084.87333


  1%|▏         | 63/5000 [00:29<38:43,  2.12it/s]

epoch 62 lr 0.0010000 mse0 304307.56250 mse1 333678.22917 Total Loss 307644.34479


  1%|▏         | 64/5000 [00:30<37:34,  2.19it/s]

epoch 63 lr 0.0010000 mse0 303858.08333 mse1 333646.85417 Total Loss 307194.55188


  1%|▏         | 65/5000 [00:30<37:41,  2.18it/s]

epoch 64 lr 0.0010000 mse0 303399.29167 mse1 333616.16667 Total Loss 306735.45333


  1%|▏         | 66/5000 [00:31<38:32,  2.13it/s]

epoch 65 lr 0.0010000 mse0 302931.25000 mse1 333586.14583 Total Loss 306267.11146


  1%|▏         | 67/5000 [00:31<38:20,  2.14it/s]

epoch 66 lr 0.0010000 mse0 302453.62500 mse1 333557.04167 Total Loss 305789.19542


  1%|▏         | 68/5000 [00:31<37:07,  2.21it/s]

epoch 67 lr 0.0010000 mse0 301966.58333 mse1 333528.77083 Total Loss 305301.87104


  1%|▏         | 69/5000 [00:32<35:23,  2.32it/s]

epoch 68 lr 0.0010000 mse0 301470.04167 mse1 333501.39583 Total Loss 304805.05563


  1%|▏         | 70/5000 [00:32<35:25,  2.32it/s]

epoch 69 lr 0.0010000 mse0 300964.02083 mse1 333474.95833 Total Loss 304298.77042


  1%|▏         | 71/5000 [00:33<36:40,  2.24it/s]

epoch 70 lr 0.0010000 mse0 300448.33333 mse1 333449.50000 Total Loss 303782.82833


  1%|▏         | 72/5000 [00:33<37:24,  2.20it/s]

epoch 71 lr 0.0010000 mse0 299923.06250 mse1 333425.14583 Total Loss 303257.31396


  1%|▏         | 73/5000 [00:34<37:30,  2.19it/s]

epoch 72 lr 0.0010000 mse0 299388.10417 mse1 333401.87500 Total Loss 302722.12292


  1%|▏         | 74/5000 [00:34<37:58,  2.16it/s]

epoch 73 lr 0.0010000 mse0 298843.47917 mse1 333379.83333 Total Loss 302177.27750


  2%|▏         | 75/5000 [00:35<38:31,  2.13it/s]

epoch 74 lr 0.0010000 mse0 298289.14583 mse1 333358.87500 Total Loss 301622.73458


  2%|▏         | 76/5000 [00:35<39:11,  2.09it/s]

epoch 75 lr 0.0010000 mse0 297725.12500 mse1 333339.39583 Total Loss 301058.51896


  2%|▏         | 77/5000 [00:36<39:05,  2.10it/s]

epoch 76 lr 0.0010000 mse0 297151.35417 mse1 333321.04167 Total Loss 300484.56458


  2%|▏         | 78/5000 [00:36<38:27,  2.13it/s]

epoch 77 lr 0.0010000 mse0 296567.83333 mse1 333304.22917 Total Loss 299900.87562


  2%|▏         | 79/5000 [00:37<39:04,  2.10it/s]

epoch 78 lr 0.0010000 mse0 295974.60417 mse1 333288.83333 Total Loss 299307.49250


  2%|▏         | 80/5000 [00:37<39:18,  2.09it/s]

epoch 79 lr 0.0010000 mse0 295371.60417 mse1 333274.93750 Total Loss 298704.35354


  2%|▏         | 81/5000 [00:38<39:32,  2.07it/s]

epoch 80 lr 0.0010000 mse0 294758.87500 mse1 333262.58333 Total Loss 298091.50083


  2%|▏         | 82/5000 [00:38<39:22,  2.08it/s]

epoch 81 lr 0.0010000 mse0 294136.37500 mse1 333251.95833 Total Loss 297468.89458


  2%|▏         | 83/5000 [00:38<39:41,  2.06it/s]

epoch 82 lr 0.0010000 mse0 293504.20833 mse1 333243.02083 Total Loss 296836.63854


  2%|▏         | 84/5000 [00:39<38:34,  2.12it/s]

epoch 83 lr 0.0010000 mse0 292862.29167 mse1 333235.89583 Total Loss 296194.65063


  2%|▏         | 85/5000 [00:39<39:04,  2.10it/s]

epoch 84 lr 0.0010000 mse0 292210.64583 mse1 333230.64583 Total Loss 295542.95229


  2%|▏         | 86/5000 [00:40<38:22,  2.13it/s]

epoch 85 lr 0.0010000 mse0 291549.37500 mse1 333227.25000 Total Loss 294881.64750


  2%|▏         | 87/5000 [00:40<38:34,  2.12it/s]

epoch 86 lr 0.0010000 mse0 290878.45833 mse1 333226.00000 Total Loss 294210.71833


  2%|▏         | 88/5000 [00:41<39:00,  2.10it/s]

epoch 87 lr 0.0010000 mse0 290197.97917 mse1 333226.62500 Total Loss 293530.24542


  2%|▏         | 89/5000 [00:41<39:09,  2.09it/s]

epoch 88 lr 0.0010000 mse0 289507.91667 mse1 333229.52083 Total Loss 292840.21188


  2%|▏         | 90/5000 [00:42<39:14,  2.09it/s]

epoch 89 lr 0.0010000 mse0 288808.31250 mse1 333234.72917 Total Loss 292140.65979


  2%|▏         | 91/5000 [00:42<38:53,  2.10it/s]

epoch 90 lr 0.0010000 mse0 288099.22917 mse1 333242.25000 Total Loss 291431.65167


  2%|▏         | 92/5000 [00:43<39:28,  2.07it/s]

epoch 91 lr 0.0010000 mse0 287380.70833 mse1 333252.04167 Total Loss 290713.22875


  2%|▏         | 93/5000 [00:43<38:24,  2.13it/s]

epoch 92 lr 0.0010000 mse0 286652.85417 mse1 333264.54167 Total Loss 289985.49958


  2%|▏         | 94/5000 [00:44<39:12,  2.09it/s]

epoch 93 lr 0.0010000 mse0 285915.66667 mse1 333279.43750 Total Loss 289248.46104


  2%|▏         | 95/5000 [00:44<39:18,  2.08it/s]

epoch 94 lr 0.0010000 mse0 285169.29167 mse1 333297.08333 Total Loss 288502.26250


  2%|▏         | 96/5000 [00:45<39:33,  2.07it/s]

epoch 95 lr 0.0010000 mse0 284413.64583 mse1 333317.50000 Total Loss 287746.82083


  2%|▏         | 97/5000 [00:45<39:41,  2.06it/s]

epoch 96 lr 0.0010000 mse0 283648.87500 mse1 333340.70833 Total Loss 286982.28208


  2%|▏         | 98/5000 [00:46<39:34,  2.06it/s]

epoch 97 lr 0.0010000 mse0 282875.14583 mse1 333366.89583 Total Loss 286208.81479


  2%|▏         | 99/5000 [00:46<39:48,  2.05it/s]

epoch 98 lr 0.0010000 mse0 282092.45833 mse1 333396.00000 Total Loss 285426.41833


  2%|▏         | 100/5000 [00:47<40:06,  2.04it/s]

epoch 99 lr 0.0010000 mse0 281300.83333 mse1 333428.22917 Total Loss 284635.11562


  2%|▏         | 101/5000 [00:47<39:41,  2.06it/s]

epoch 100 lr 0.0010000 mse0 280500.56250 mse1 333463.62500 Total Loss 283835.19875


  2%|▏         | 102/5000 [00:48<39:44,  2.05it/s]

epoch 101 lr 0.0010000 mse0 279691.50000 mse1 333502.20833 Total Loss 283026.52208


  2%|▏         | 103/5000 [00:48<39:42,  2.06it/s]

epoch 102 lr 0.0010000 mse0 278873.89583 mse1 333544.16667 Total Loss 282209.33750


  2%|▏         | 104/5000 [00:49<39:55,  2.04it/s]

epoch 103 lr 0.0010000 mse0 278047.83333 mse1 333589.39583 Total Loss 281383.72729


  2%|▏         | 105/5000 [00:49<39:44,  2.05it/s]

epoch 104 lr 0.0010000 mse0 277213.37500 mse1 333638.08333 Total Loss 280549.75583


  2%|▏         | 106/5000 [00:49<38:16,  2.13it/s]

epoch 105 lr 0.0010000 mse0 276370.68750 mse1 333690.29167 Total Loss 279707.59042


  2%|▏         | 107/5000 [00:50<38:36,  2.11it/s]

epoch 106 lr 0.0010000 mse0 275519.83333 mse1 333746.12500 Total Loss 278857.29458


  2%|▏         | 108/5000 [00:50<38:32,  2.12it/s]

epoch 107 lr 0.0010000 mse0 274660.93750 mse1 333805.45833 Total Loss 277998.99208


  2%|▏         | 109/5000 [00:51<38:28,  2.12it/s]

epoch 108 lr 0.0010000 mse0 273794.18750 mse1 333868.45833 Total Loss 277132.87208


  2%|▏         | 110/5000 [00:51<38:22,  2.12it/s]

epoch 109 lr 0.0010000 mse0 272919.77083 mse1 333935.27083 Total Loss 276259.12354


  2%|▏         | 111/5000 [00:52<38:13,  2.13it/s]

epoch 110 lr 0.0010000 mse0 272037.58333 mse1 334005.68750 Total Loss 275377.64021


  2%|▏         | 112/5000 [00:52<38:08,  2.14it/s]

epoch 111 lr 0.0010000 mse0 271148.00000 mse1 334079.91667 Total Loss 274488.79917


  2%|▏         | 113/5000 [00:53<38:33,  2.11it/s]

epoch 112 lr 0.0010000 mse0 270251.02083 mse1 334158.00000 Total Loss 273592.60083


  2%|▏         | 114/5000 [00:53<38:26,  2.12it/s]

epoch 113 lr 0.0010000 mse0 269346.91667 mse1 334239.93750 Total Loss 272689.31604


  2%|▏         | 115/5000 [00:54<36:57,  2.20it/s]

epoch 114 lr 0.0010000 mse0 268435.72917 mse1 334325.89583 Total Loss 271778.98813


  2%|▏         | 116/5000 [00:54<36:27,  2.23it/s]

epoch 115 lr 0.0010000 mse0 267517.72917 mse1 334415.58333 Total Loss 270861.88500


  2%|▏         | 117/5000 [00:55<37:31,  2.17it/s]

epoch 116 lr 0.0010000 mse0 266593.00000 mse1 334509.20833 Total Loss 269938.09208


  2%|▏         | 118/5000 [00:55<37:32,  2.17it/s]

epoch 117 lr 0.0010000 mse0 265661.62500 mse1 334606.83333 Total Loss 269007.69333


  2%|▏         | 119/5000 [00:56<37:42,  2.16it/s]

epoch 118 lr 0.0010000 mse0 264723.87500 mse1 334708.35417 Total Loss 268070.95854


  2%|▏         | 120/5000 [00:56<36:07,  2.25it/s]

epoch 119 lr 0.0010000 mse0 263779.91667 mse1 334813.89583 Total Loss 267128.05563


  2%|▏         | 121/5000 [00:56<36:33,  2.22it/s]

epoch 120 lr 0.0010000 mse0 262829.91667 mse1 334923.35417 Total Loss 266179.15021


  2%|▏         | 122/5000 [00:57<35:49,  2.27it/s]

epoch 121 lr 0.0010000 mse0 261874.02083 mse1 335036.77083 Total Loss 265224.38854


  2%|▏         | 123/5000 [00:57<36:14,  2.24it/s]

epoch 122 lr 0.0010000 mse0 260912.43750 mse1 335154.31250 Total Loss 264263.98063


  2%|▏         | 124/5000 [00:58<36:38,  2.22it/s]

epoch 123 lr 0.0010000 mse0 259945.35417 mse1 335275.87500 Total Loss 263298.11292


  2%|▎         | 125/5000 [00:58<36:59,  2.20it/s]

epoch 124 lr 0.0010000 mse0 258972.87500 mse1 335401.45833 Total Loss 262326.88958


  3%|▎         | 126/5000 [00:59<37:27,  2.17it/s]

epoch 125 lr 0.0010000 mse0 257995.29167 mse1 335531.25000 Total Loss 261350.60417


  3%|▎         | 127/5000 [00:59<37:29,  2.17it/s]

epoch 126 lr 0.0010000 mse0 257012.68750 mse1 335665.29167 Total Loss 260369.34042


  3%|▎         | 128/5000 [01:00<37:33,  2.16it/s]

epoch 127 lr 0.0010000 mse0 256025.35417 mse1 335803.52083 Total Loss 259383.38937


  3%|▎         | 129/5000 [01:00<37:54,  2.14it/s]

epoch 128 lr 0.0010000 mse0 255033.35417 mse1 335946.00000 Total Loss 258392.81417


  3%|▎         | 130/5000 [01:01<37:57,  2.14it/s]

epoch 129 lr 0.0010000 mse0 254036.97917 mse1 336092.97917 Total Loss 257397.90896


  3%|▎         | 131/5000 [01:01<37:50,  2.14it/s]

epoch 130 lr 0.0010000 mse0 253036.43750 mse1 336244.37500 Total Loss 256398.88125


  3%|▎         | 132/5000 [01:01<37:34,  2.16it/s]

epoch 131 lr 0.0010000 mse0 252031.87500 mse1 336400.18750 Total Loss 255395.87687


  3%|▎         | 133/5000 [01:02<37:55,  2.14it/s]

epoch 132 lr 0.0010000 mse0 251023.47917 mse1 336560.60417 Total Loss 254389.08521


  3%|▎         | 134/5000 [01:02<37:48,  2.15it/s]

epoch 133 lr 0.0010000 mse0 250011.50000 mse1 336725.64583 Total Loss 253378.75646


  3%|▎         | 135/5000 [01:03<37:41,  2.15it/s]

epoch 134 lr 0.0010000 mse0 248996.00000 mse1 336895.27083 Total Loss 252364.95271


  3%|▎         | 136/5000 [01:03<37:44,  2.15it/s]

epoch 135 lr 0.0010000 mse0 247977.37500 mse1 337069.54167 Total Loss 251348.07042


  3%|▎         | 137/5000 [01:04<37:37,  2.15it/s]

epoch 136 lr 0.0010000 mse0 246955.72917 mse1 337248.68750 Total Loss 250328.21604


  3%|▎         | 138/5000 [01:04<37:39,  2.15it/s]

epoch 137 lr 0.0010000 mse0 245931.14583 mse1 337432.60417 Total Loss 249305.47188


  3%|▎         | 139/5000 [01:05<37:36,  2.15it/s]

epoch 138 lr 0.0010000 mse0 244904.06250 mse1 337621.41667 Total Loss 248280.27667


  3%|▎         | 140/5000 [01:05<37:29,  2.16it/s]

epoch 139 lr 0.0010000 mse0 243874.47917 mse1 337815.08333 Total Loss 247252.63000


  3%|▎         | 141/5000 [01:06<37:28,  2.16it/s]

epoch 140 lr 0.0010000 mse0 242842.68750 mse1 338013.77083 Total Loss 246222.82521


  3%|▎         | 142/5000 [01:06<37:16,  2.17it/s]

epoch 141 lr 0.0010000 mse0 241808.85417 mse1 338217.47917 Total Loss 245191.02896


  3%|▎         | 143/5000 [01:07<37:33,  2.16it/s]

epoch 142 lr 0.0010000 mse0 240773.22917 mse1 338426.35417 Total Loss 244157.49271


  3%|▎         | 144/5000 [01:07<37:39,  2.15it/s]

epoch 143 lr 0.0010000 mse0 239735.87500 mse1 338640.39583 Total Loss 243122.27896


  3%|▎         | 145/5000 [01:08<37:41,  2.15it/s]

epoch 144 lr 0.0010000 mse0 238697.12500 mse1 338859.60417 Total Loss 242085.72104


  3%|▎         | 146/5000 [01:08<37:55,  2.13it/s]

epoch 145 lr 0.0010000 mse0 237657.18750 mse1 339084.20833 Total Loss 241048.02958


  3%|▎         | 147/5000 [01:08<37:55,  2.13it/s]

epoch 146 lr 0.0010000 mse0 236616.10417 mse1 339314.08333 Total Loss 240009.24500


  3%|▎         | 148/5000 [01:09<37:40,  2.15it/s]

epoch 147 lr 0.0010000 mse0 235574.25000 mse1 339549.47917 Total Loss 238969.74479


  3%|▎         | 149/5000 [01:09<37:49,  2.14it/s]

epoch 148 lr 0.0010000 mse0 234531.64583 mse1 339790.33333 Total Loss 237929.54917


  3%|▎         | 150/5000 [01:10<38:04,  2.12it/s]

epoch 149 lr 0.0010000 mse0 233488.62500 mse1 340036.89583 Total Loss 236888.99396


  3%|▎         | 151/5000 [01:10<37:43,  2.14it/s]

epoch 150 lr 0.0010000 mse0 232445.33333 mse1 340288.95833 Total Loss 235848.22292


  3%|▎         | 152/5000 [01:11<37:36,  2.15it/s]

epoch 151 lr 0.0010000 mse0 231401.87500 mse1 340546.81250 Total Loss 234807.34313


  3%|▎         | 153/5000 [01:11<37:29,  2.15it/s]

epoch 152 lr 0.0010000 mse0 230358.56250 mse1 340810.41667 Total Loss 233766.66667


  3%|▎         | 154/5000 [01:12<36:33,  2.21it/s]

epoch 153 lr 0.0010000 mse0 229315.52083 mse1 341079.87500 Total Loss 232726.31958


  3%|▎         | 155/5000 [01:12<36:40,  2.20it/s]

epoch 154 lr 0.0010000 mse0 228272.93750 mse1 341355.25000 Total Loss 231686.49000


  3%|▎         | 156/5000 [01:13<37:09,  2.17it/s]

epoch 155 lr 0.0010000 mse0 227230.95833 mse1 341636.58333 Total Loss 230647.32417


  3%|▎         | 157/5000 [01:13<37:11,  2.17it/s]

epoch 156 lr 0.0010000 mse0 226189.85417 mse1 341923.87500 Total Loss 229609.09292


  3%|▎         | 158/5000 [01:14<37:08,  2.17it/s]

epoch 157 lr 0.0010000 mse0 225149.75000 mse1 342217.27083 Total Loss 228571.92271


  3%|▎         | 159/5000 [01:14<37:04,  2.18it/s]

epoch 158 lr 0.0010000 mse0 224110.83333 mse1 342516.68750 Total Loss 227536.00021


  3%|▎         | 160/5000 [01:14<37:07,  2.17it/s]

epoch 159 lr 0.0010000 mse0 223073.22917 mse1 342822.18750 Total Loss 226501.45104


  3%|▎         | 161/5000 [01:15<36:09,  2.23it/s]

epoch 160 lr 0.0010000 mse0 222037.22917 mse1 343133.95833 Total Loss 225468.56875


  3%|▎         | 162/5000 [01:15<36:23,  2.22it/s]

epoch 161 lr 0.0010000 mse0 221002.91667 mse1 343451.79167 Total Loss 224437.43458


  3%|▎         | 163/5000 [01:16<36:47,  2.19it/s]

epoch 162 lr 0.0010000 mse0 219970.43750 mse1 343775.83333 Total Loss 223408.19583


  3%|▎         | 164/5000 [01:16<36:52,  2.19it/s]

epoch 163 lr 0.0010000 mse0 218940.02083 mse1 344106.04167 Total Loss 222381.08125


  3%|▎         | 165/5000 [01:17<37:02,  2.18it/s]

epoch 164 lr 0.0010000 mse0 217911.77083 mse1 344442.35417 Total Loss 221356.19438


  3%|▎         | 166/5000 [01:17<36:33,  2.20it/s]

epoch 165 lr 0.0010000 mse0 216885.95833 mse1 344784.89583 Total Loss 220333.80729


  3%|▎         | 167/5000 [01:18<36:46,  2.19it/s]

epoch 166 lr 0.0010000 mse0 215862.58333 mse1 345133.54167 Total Loss 219313.91875


  3%|▎         | 168/5000 [01:18<36:55,  2.18it/s]

epoch 167 lr 0.0010000 mse0 214841.91667 mse1 345488.22917 Total Loss 218296.79896


  3%|▎         | 169/5000 [01:19<37:10,  2.17it/s]

epoch 168 lr 0.0010000 mse0 213824.06250 mse1 345848.91667 Total Loss 217282.55167


  3%|▎         | 170/5000 [01:19<37:14,  2.16it/s]

epoch 169 lr 0.0010000 mse0 212809.16667 mse1 346215.64583 Total Loss 216271.32312


  3%|▎         | 171/5000 [01:19<37:11,  2.16it/s]

epoch 170 lr 0.0010000 mse0 211797.29167 mse1 346588.29167 Total Loss 215263.17458


  3%|▎         | 172/5000 [01:20<37:07,  2.17it/s]

epoch 171 lr 0.0010000 mse0 210788.75000 mse1 346966.77083 Total Loss 214258.41771


  3%|▎         | 173/5000 [01:20<37:19,  2.16it/s]

epoch 172 lr 0.0010000 mse0 209783.47917 mse1 347351.08333 Total Loss 213256.99000


  3%|▎         | 174/5000 [01:21<36:55,  2.18it/s]

epoch 173 lr 0.0010000 mse0 208781.75000 mse1 347740.95833 Total Loss 212259.15958


  4%|▎         | 175/5000 [01:21<35:59,  2.23it/s]

epoch 174 lr 0.0010000 mse0 207783.62500 mse1 348136.39583 Total Loss 211264.98896


  4%|▎         | 176/5000 [01:22<35:09,  2.29it/s]

epoch 175 lr 0.0010000 mse0 206789.25000 mse1 348537.18750 Total Loss 210274.62188


  4%|▎         | 177/5000 [01:22<35:41,  2.25it/s]

epoch 176 lr 0.0010000 mse0 205798.72917 mse1 348943.35417 Total Loss 209288.16271


  4%|▎         | 178/5000 [01:23<36:19,  2.21it/s]

epoch 177 lr 0.0010000 mse0 204812.18750 mse1 349354.54167 Total Loss 208305.73292


  4%|▎         | 179/5000 [01:23<36:37,  2.19it/s]

epoch 178 lr 0.0010000 mse0 203829.58333 mse1 349770.79167 Total Loss 207327.29125


  4%|▎         | 180/5000 [01:24<36:38,  2.19it/s]

epoch 179 lr 0.0010000 mse0 202851.22917 mse1 350191.75000 Total Loss 206353.14667


  4%|▎         | 181/5000 [01:24<36:52,  2.18it/s]

epoch 180 lr 0.0010000 mse0 201877.06250 mse1 350617.29167 Total Loss 205383.23542


  4%|▎         | 182/5000 [01:24<37:05,  2.16it/s]

epoch 181 lr 0.0010000 mse0 200907.25000 mse1 351047.16667 Total Loss 204417.72167


  4%|▎         | 183/5000 [01:25<37:03,  2.17it/s]

epoch 182 lr 0.0010000 mse0 199941.81250 mse1 351481.20833 Total Loss 203456.62458


  4%|▎         | 184/5000 [01:25<37:12,  2.16it/s]

epoch 183 lr 0.0010000 mse0 198980.87500 mse1 351919.20833 Total Loss 202500.06708


  4%|▎         | 185/5000 [01:26<37:39,  2.13it/s]

epoch 184 lr 0.0010000 mse0 198024.52083 mse1 352361.00000 Total Loss 201548.13083


  4%|▎         | 186/5000 [01:26<37:59,  2.11it/s]

epoch 185 lr 0.0010000 mse0 197072.72917 mse1 352805.95833 Total Loss 200600.78875


  4%|▎         | 187/5000 [01:27<39:08,  2.05it/s]

epoch 186 lr 0.0010000 mse0 196125.54167 mse1 353254.29167 Total Loss 199658.08458


  4%|▍         | 188/5000 [01:27<38:47,  2.07it/s]

epoch 187 lr 0.0010000 mse0 195183.14583 mse1 353705.29167 Total Loss 198720.19875


  4%|▍         | 189/5000 [01:28<38:21,  2.09it/s]

epoch 188 lr 0.0010000 mse0 194245.45833 mse1 354158.87500 Total Loss 197787.04708


  4%|▍         | 190/5000 [01:28<38:38,  2.07it/s]

epoch 189 lr 0.0010000 mse0 193312.54167 mse1 354614.58333 Total Loss 196858.68750


  4%|▍         | 191/5000 [01:29<39:04,  2.05it/s]

epoch 190 lr 0.0010000 mse0 192384.52083 mse1 355072.29167 Total Loss 195935.24375


  4%|▍         | 192/5000 [01:29<39:27,  2.03it/s]

epoch 191 lr 0.0010000 mse0 191461.18750 mse1 355531.33333 Total Loss 195016.50083


  4%|▍         | 193/5000 [01:30<39:21,  2.04it/s]

epoch 192 lr 0.0010000 mse0 190542.77083 mse1 355991.66667 Total Loss 194102.68750


  4%|▍         | 194/5000 [01:30<39:19,  2.04it/s]

epoch 193 lr 0.0010000 mse0 189629.18750 mse1 356452.54167 Total Loss 193193.71292


  4%|▍         | 195/5000 [01:31<39:03,  2.05it/s]

epoch 194 lr 0.0010000 mse0 188720.43750 mse1 356913.95833 Total Loss 192289.57708


  4%|▍         | 196/5000 [01:31<37:18,  2.15it/s]

epoch 195 lr 0.0010000 mse0 187816.56250 mse1 357375.29167 Total Loss 191390.31542


  4%|▍         | 197/5000 [01:32<37:54,  2.11it/s]

epoch 196 lr 0.0010000 mse0 186917.47917 mse1 357835.95833 Total Loss 190495.83875


  4%|▍         | 198/5000 [01:32<37:51,  2.11it/s]

epoch 197 lr 0.0010000 mse0 186023.20833 mse1 358295.83333 Total Loss 189606.16667


  4%|▍         | 199/5000 [01:33<37:57,  2.11it/s]

epoch 198 lr 0.0010000 mse0 185133.70833 mse1 358754.12500 Total Loss 188721.24958


  4%|▍         | 200/5000 [01:33<38:00,  2.10it/s]

epoch 199 lr 0.0010000 mse0 184248.95833 mse1 359210.79167 Total Loss 187841.06625


  4%|▍         | 201/5000 [01:34<37:16,  2.15it/s]

epoch 200 lr 0.0010000 mse0 183368.87500 mse1 359665.20833 Total Loss 186965.52708


  4%|▍         | 202/5000 [01:34<37:10,  2.15it/s]

epoch 201 lr 0.0010000 mse0 182493.52083 mse1 360116.75000 Total Loss 186094.68833


  4%|▍         | 203/5000 [01:35<37:40,  2.12it/s]

epoch 202 lr 0.0010000 mse0 181622.62500 mse1 360564.95833 Total Loss 185228.27458


  4%|▍         | 204/5000 [01:35<36:48,  2.17it/s]

epoch 203 lr 0.0010000 mse0 180756.31250 mse1 361009.41667 Total Loss 184366.40667


  4%|▍         | 205/5000 [01:35<36:20,  2.20it/s]

epoch 204 lr 0.0010000 mse0 179894.52083 mse1 361449.70833 Total Loss 183509.01792


  4%|▍         | 206/5000 [01:36<36:33,  2.19it/s]

epoch 205 lr 0.0010000 mse0 179037.04167 mse1 361885.12500 Total Loss 182655.89292


  4%|▍         | 207/5000 [01:36<35:53,  2.23it/s]

epoch 206 lr 0.0010000 mse0 178183.85417 mse1 362315.41667 Total Loss 181807.00833


  4%|▍         | 208/5000 [01:37<36:57,  2.16it/s]

epoch 207 lr 0.0010000 mse0 177334.89583 mse1 362739.91667 Total Loss 180962.29500


  4%|▍         | 209/5000 [01:37<37:06,  2.15it/s]

epoch 208 lr 0.0010000 mse0 176490.02083 mse1 363158.29167 Total Loss 180121.60375


  4%|▍         | 210/5000 [01:38<37:23,  2.13it/s]

epoch 209 lr 0.0010000 mse0 175649.20833 mse1 363570.00000 Total Loss 179284.90833


  4%|▍         | 211/5000 [01:38<37:27,  2.13it/s]

epoch 210 lr 0.0010000 mse0 174812.25000 mse1 363974.54167 Total Loss 178451.99542


  4%|▍         | 212/5000 [01:39<36:13,  2.20it/s]

epoch 211 lr 0.0010000 mse0 173979.06250 mse1 364371.66667 Total Loss 177622.77917


  4%|▍         | 213/5000 [01:39<36:39,  2.18it/s]

epoch 212 lr 0.0010000 mse0 173149.54167 mse1 364760.95833 Total Loss 176797.15125


  4%|▍         | 214/5000 [01:40<37:00,  2.16it/s]

epoch 213 lr 0.0010000 mse0 172323.55208 mse1 365141.91667 Total Loss 175974.97125


  4%|▍         | 215/5000 [01:40<37:06,  2.15it/s]

epoch 214 lr 0.0010000 mse0 171500.91667 mse1 365514.41667 Total Loss 175156.06083


  4%|▍         | 216/5000 [01:41<37:09,  2.15it/s]

epoch 215 lr 0.0010000 mse0 170681.57292 mse1 365878.00000 Total Loss 174340.35292


  4%|▍         | 217/5000 [01:41<37:13,  2.14it/s]

epoch 216 lr 0.0010000 mse0 169865.33333 mse1 366232.54167 Total Loss 173527.65875


  4%|▍         | 218/5000 [01:41<37:04,  2.15it/s]

epoch 217 lr 0.0010000 mse0 169052.02083 mse1 366577.95833 Total Loss 172717.80042


  4%|▍         | 219/5000 [01:42<36:40,  2.17it/s]

epoch 218 lr 0.0010000 mse0 168241.52083 mse1 366914.12500 Total Loss 171910.66208


  4%|▍         | 220/5000 [01:42<36:49,  2.16it/s]

epoch 219 lr 0.0010000 mse0 167433.65625 mse1 367241.08333 Total Loss 171106.06708


  4%|▍         | 221/5000 [01:43<37:18,  2.14it/s]

epoch 220 lr 0.0010000 mse0 166628.23958 mse1 367559.08333 Total Loss 170303.83042


  4%|▍         | 222/5000 [01:43<37:39,  2.11it/s]

epoch 221 lr 0.0010000 mse0 165825.12500 mse1 367868.12500 Total Loss 169503.80625


  4%|▍         | 223/5000 [01:44<37:40,  2.11it/s]

epoch 222 lr 0.0010000 mse0 165024.06250 mse1 368168.83333 Total Loss 168705.75083


  4%|▍         | 224/5000 [01:44<38:00,  2.09it/s]

epoch 223 lr 0.0010000 mse0 164224.92708 mse1 368461.62500 Total Loss 167909.54333


  4%|▍         | 225/5000 [01:45<38:12,  2.08it/s]

epoch 224 lr 0.0010000 mse0 163427.52083 mse1 368747.04167 Total Loss 167114.99125


  5%|▍         | 226/5000 [01:45<38:22,  2.07it/s]

epoch 225 lr 0.0010000 mse0 162631.61458 mse1 369026.12500 Total Loss 166321.87583


  5%|▍         | 227/5000 [01:46<39:00,  2.04it/s]

epoch 226 lr 0.0010000 mse0 161837.02083 mse1 369299.70833 Total Loss 165530.01792


  5%|▍         | 228/5000 [01:46<38:15,  2.08it/s]

epoch 227 lr 0.0010000 mse0 161043.51042 mse1 369568.95833 Total Loss 164739.20000


  5%|▍         | 229/5000 [01:47<37:57,  2.09it/s]

epoch 228 lr 0.0010000 mse0 160250.89583 mse1 369835.37500 Total Loss 163949.24958


  5%|▍         | 230/5000 [01:47<37:58,  2.09it/s]

epoch 229 lr 0.0010000 mse0 159458.91667 mse1 370100.20833 Total Loss 163159.91875


  5%|▍         | 231/5000 [01:48<38:01,  2.09it/s]

epoch 230 lr 0.0010000 mse0 158667.37500 mse1 370365.50000 Total Loss 162371.03000


  5%|▍         | 231/5000 [01:48<37:21,  2.13it/s]


KeyboardInterrupt: 

In [None]:
torch.save(net, "sqnet2.pth")