In [1]:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import trimesh
import argparse
import os
import time
import numpy as np
import torch
import torch.optim as optim
import lib.utils as utils
from lib.utils import standard_normal_logprob
from lib.utils import count_nfe, count_total_time
from lib.utils import build_model_tabular

from lib.visualize_flow import visualize_transform
import lib.layers.odefunc as odefunc
from matplotlib import pyplot as plt
%matplotlib inline 

In [2]:
SOLVERS = ["dopri5"]
parser = argparse.ArgumentParser('SoftFlow')
parser.add_argument(
    '--data', choices=['2spirals_1d','2spirals_2d', 'swissroll_1d','swissroll_2d', 'circles_1d', 'circles_2d', '2sines_1d', 'target_1d'],
    type=str, default='3d_microgel'
)
parser.add_argument("--layer_type", type=str, default="concatsquash", choices=["concatsquash"])
parser.add_argument('--dims', type=str, default='64-64-64')
parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.')
parser.add_argument('--time_length', type=float, default=0.5)
parser.add_argument('--train_T', type=eval, default=True)
parser.add_argument("--divergence_fn", type=str, default="brute_force", choices=["brute_force", "approximate"])
parser.add_argument("--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES)

parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
parser.add_argument('--atol', type=float, default=1e-5)
parser.add_argument('--rtol', type=float, default=1e-5)

parser.add_argument('--residual', type=eval, default=False, choices=[True, False])
parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False])
parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False])
parser.add_argument('--niters', type=int, default=36000)
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=1e-5)

# for the proposed method
parser.add_argument('--std_min', type=float, default=0.0)
parser.add_argument('--std_max', type=float, default=0.1)
parser.add_argument('--std_weight', type=float, default=2)

parser.add_argument('--viz_freq', type=int, default=100)
parser.add_argument('--val_freq', type=int, default=400)
parser.add_argument('--log_freq', type=int, default=10)
parser.add_argument('--gpu', type=int, default=0)
# args = parser.parse_args()
args, unknown = parser.parse_known_args()


In [3]:
# logger
save_path = './results/' + args.data + '/SoftFlow'
utils.makedirs(save_path)

device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

In [4]:
def get_transforms(model):

    def sample_fn(z, logpz=None):
        zeors_std = torch.zeros(z.shape[0], 1).to(z)
        if logpz is not None:
            return model(z, zeors_std, logpz, reverse=True)
        else:
            return model(z, zeors_std, reverse=True)

    def density_fn(x, logpx=None):
        zeors_std = torch.zeros(x.shape[0], 1).to(x)
        if logpx is not None:
            return model(x, zeors_std, logpx, reverse=False)
        else:
            return model(x, zeors_std, reverse=False)

    return sample_fn, density_fn

DATA = trimesh.load('ANDREY_FOCTS1.ply')[:]
 
def compute_loss(args, model, batch_size=None):
    if batch_size is None: batch_size = args.batch_size

    # load data
#     x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = DATA[:,0:3]


    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    std = (args.std_max - args.std_min) * torch.rand_like(x[:,0]).view(-1,1) + args.std_min
    eps = torch.randn_like(x) * std
    std_in = std / args.std_max * args.std_weight
#     print(x.shape, std_in.shape, zero.shape)
    z, delta_logp = model(x+eps, std_in, zero)

    # compute log q(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss

In [5]:
model = build_model_tabular(args, 3).to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

time_meter = utils.RunningAverageMeter(0.93)
loss_meter = utils.RunningAverageMeter(0.93)
nfef_meter = utils.RunningAverageMeter(0.93)
nfeb_meter = utils.RunningAverageMeter(0.93)
tt_meter = utils.RunningAverageMeter(0.93)

end = time.time()
best_loss = float('inf')

In [6]:
LOW = -4
HIGH = 4
def visualize3d(samples, prior_sample, prior_logdensity, n_samples, transform, inverse_transform, memory=100, title1="$x ~ p(x)$", title2="$q(x)$", title3="$x ~ q(x)$", device="cpu", npts = 100):
    ## first one
    fig = plt.figure(figsize=(25,5))
    ax1 = fig.add_subplot(1, 3, 1, projection="3d")
    ax1.scatter(samples[:, 0], samples[:, 1], samples[:, 2], s=1, c='b', marker="s", facecolor="red", lw=0, alpha=1)
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')
    ax1.set_title(title1)
    
    ## second one
    side = np.linspace(LOW, HIGH, npts)
    xx, yy, zz = np.meshgrid(side, side, side)
    memory=100
    x = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)])
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zeros = torch.zeros(x.shape[0], 1).to(x)

    z, delta_logp = [], []
    inds = torch.arange(0, x.shape[0]).to(torch.int64)
    for ii in torch.split(inds, int(memory**2)):
        z_, delta_logp_ = inverse_transform(x[ii], zeros[ii])
        z.append(z_)
        delta_logp.append(delta_logp_)
    z = torch.cat(z, 0)
    delta_logp = torch.cat(delta_logp, 0)

    logpz = prior_logdensity(z).view(z.shape[0], -1).sum(1, keepdim=True)  # logp(z)
    logpx = logpz - delta_logp

    px = np.exp(logpx.detach().cpu().numpy()).reshape(npts, npts,npts)
    ax = fig.add_subplot(1, 3, 2)
    ax.imshow(px[:,:,1])
    ax.set_title(title2)
    
    ## Third one
    z = prior_sample(n_samples, 3).type(torch.float32).to(device)
    zk = []
    inds = torch.arange(0, z.shape[0]).to(torch.int64)
    for ii in torch.split(inds, int(memory**2)):
        zk.append(transform(z[ii]))

    zk = torch.cat(zk, 0).detach().cpu().numpy()

    ax1 = fig.add_subplot(1, 3, 3, projection="3d")
    ax1.scatter(zk[:, 0], zk[:, 1], zk[:, 2], s=1, c='r', marker="s", facecolor="red", lw=0, alpha=1)
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')
    ax1.set_title(title3)
    
# plt.clf()    
# p_samples = DATA[:,0:3]
# visualize3d(p_samples, torch.randn, standard_normal_logprob, n_samples, sample_fn,density_fn, device=device)
# # plt.show()
# fig_filename = os.path.join('./', 'figs', '{:04d}.jpg'.format(1))
# utils.makedirs(os.path.dirname(fig_filename))
# plt.savefig(fig_filename, format='png', dpi=1200)
# plt.close()

In [None]:
model.train()
for itr in range(1, args.niters + 1):
    optimizer.zero_grad()

    loss = compute_loss(args, model)
    

    loss_meter.update(loss.item())

    total_time = count_total_time(model)
    nfe_forward = count_nfe(model)

    loss.backward()
    optimizer.step()

    nfe_total = count_nfe(model)
    nfe_backward = nfe_total - nfe_forward
    nfef_meter.update(nfe_forward)
    nfeb_meter.update(nfe_backward)

    time_meter.update(time.time() - end)
    tt_meter.update(total_time)

    log_message = (
        'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
        ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
            itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg,
            nfeb_meter.val, nfeb_meter.avg, tt_meter.val, tt_meter.avg
        )
    )

    print(log_message)
    
    if itr % args.val_freq == 0 or itr == args.niters:
        with torch.no_grad():
            model.eval()
            test_loss = compute_loss(args, model, batch_size=args.test_batch_size)
            test_nfe = count_nfe(model)
            log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe)
            print(log_message)

            if test_loss.item() < best_loss:
                best_loss = test_loss.item()
                utils.makedirs(save_path)
                torch.save({
                    'args': args,
                    'state_dict': model.state_dict(),
                }, os.path.join(save_path, 'checkpt.pth'))
            model.train()

    if itr % args.viz_freq == 0:
        with torch.no_grad():
            model.eval()
#             p_samples = toy_data.inf_train_gen(args.data, batch_size=2000)
            p_samples = DATA[:,0:3]

            sample_fn, density_fn = get_transforms(model)

            plt.figure(figsize=(9, 3))
            visualize3d(p_samples, torch.randn, standard_normal_logprob, p_samples.shape[0], sample_fn,density_fn, device=device)
            fig_filename = os.path.join(save_path, 'figs', '{:04d}.jpg'.format(itr))
            utils.makedirs(os.path.dirname(fig_filename))
            plt.savefig(fig_filename, format='png', dpi=1200)
            plt.close()
            model.train()

    end = time.time()

print('Training has finished.')

Iter 0001 | Time 88.0963(7.1036) | Loss -0.566695(0.131992) | NFE Forward 38(36.8) | NFE Backward 44(40.3) | CNF Time 0.5000(0.5000)
Iter 0002 | Time 1.0714(6.6814) | Loss -0.588127(0.081584) | NFE Forward 38(36.9) | NFE Backward 44(40.5) | CNF Time 0.5000(0.5000)
Iter 0003 | Time 1.0694(6.2886) | Loss -0.607913(0.033319) | NFE Forward 38(37.0) | NFE Backward 44(40.8) | CNF Time 0.5000(0.5000)
Iter 0004 | Time 1.0683(5.9231) | Loss -0.623924(-0.012688) | NFE Forward 38(37.1) | NFE Backward 44(41.0) | CNF Time 0.5000(0.5000)
Iter 0005 | Time 1.0689(5.5833) | Loss -0.635768(-0.056304) | NFE Forward 38(37.1) | NFE Backward 44(41.2) | CNF Time 0.5000(0.5000)
Iter 0006 | Time 1.0750(5.2678) | Loss -0.650307(-0.097884) | NFE Forward 38(37.2) | NFE Backward 44(41.4) | CNF Time 0.5000(0.5000)
Iter 0007 | Time 1.1773(4.9814) | Loss -0.656779(-0.137006) | NFE Forward 38(37.2) | NFE Backward 50(42.0) | CNF Time 0.5000(0.5000)
Iter 0008 | Time 1.1732(4.7149) | Loss -0.663374(-0.173852) | NFE Forwa

Iter 0063 | Time 1.0755(1.1406) | Loss -0.945792(-0.885305) | NFE Forward 38(38.0) | NFE Backward 44(44.0) | CNF Time 0.5000(0.5000)
Iter 0064 | Time 1.0759(1.1361) | Loss -0.949931(-0.889829) | NFE Forward 38(38.0) | NFE Backward 44(44.0) | CNF Time 0.5000(0.5000)
Iter 0065 | Time 1.0742(1.1318) | Loss -0.959375(-0.894697) | NFE Forward 38(38.0) | NFE Backward 44(44.0) | CNF Time 0.5000(0.5000)
Iter 0066 | Time 1.0744(1.1278) | Loss -0.959736(-0.899250) | NFE Forward 38(38.0) | NFE Backward 44(44.0) | CNF Time 0.5000(0.5000)
Iter 0067 | Time 1.0750(1.1241) | Loss -0.962586(-0.903684) | NFE Forward 38(38.0) | NFE Backward 44(44.0) | CNF Time 0.5000(0.5000)
Iter 0068 | Time 1.0764(1.1207) | Loss -0.967235(-0.908132) | NFE Forward 38(38.0) | NFE Backward 44(44.0) | CNF Time 0.5000(0.5000)
Iter 0069 | Time 1.0743(1.1175) | Loss -0.970916(-0.912527) | NFE Forward 38(38.0) | NFE Backward 44(44.0) | CNF Time 0.5000(0.5000)
Iter 0070 | Time 1.0736(1.1144) | Loss -0.970099(-0.916557) | NFE For

Iter 0125 | Time 1.3320(1.3181) | Loss -1.140777(-1.104668) | NFE Forward 44(43.8) | NFE Backward 56(55.3) | CNF Time 0.5000(0.5000)
Iter 0126 | Time 1.3347(1.3193) | Loss -1.148244(-1.107718) | NFE Forward 44(43.8) | NFE Backward 56(55.4) | CNF Time 0.5000(0.5000)
Iter 0127 | Time 1.3268(1.3198) | Loss -1.149400(-1.110636) | NFE Forward 44(43.9) | NFE Backward 56(55.4) | CNF Time 0.5000(0.5000)
Iter 0128 | Time 1.3244(1.3201) | Loss -1.147625(-1.113225) | NFE Forward 44(43.9) | NFE Backward 56(55.5) | CNF Time 0.5000(0.5000)
Iter 0129 | Time 1.3274(1.3206) | Loss -1.152082(-1.115945) | NFE Forward 44(43.9) | NFE Backward 56(55.5) | CNF Time 0.5000(0.5000)
Iter 0130 | Time 1.3444(1.3223) | Loss -1.158681(-1.118937) | NFE Forward 44(43.9) | NFE Backward 56(55.5) | CNF Time 0.5000(0.5000)
Iter 0131 | Time 1.3408(1.3236) | Loss -1.160707(-1.121861) | NFE Forward 44(43.9) | NFE Backward 56(55.6) | CNF Time 0.5000(0.5000)
Iter 0132 | Time 1.3328(1.3243) | Loss -1.161701(-1.124649) | NFE For

Iter 0187 | Time 1.3316(1.3533) | Loss -1.282629(-1.256765) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0188 | Time 1.3338(1.3519) | Loss -1.283933(-1.258667) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0189 | Time 1.3265(1.3502) | Loss -1.289182(-1.260803) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0190 | Time 1.3303(1.3488) | Loss -1.289520(-1.262813) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0191 | Time 1.3357(1.3478) | Loss -1.291216(-1.264801) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0192 | Time 1.3386(1.3472) | Loss -1.299143(-1.267205) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0193 | Time 1.3284(1.3459) | Loss -1.294411(-1.269110) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0194 | Time 1.3320(1.3449) | Loss -1.302255(-1.271430) | NFE For

Iter 0249 | Time 1.3273(1.3325) | Loss -1.373490(-1.358721) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0250 | Time 1.3326(1.3325) | Loss -1.373605(-1.359763) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0251 | Time 1.3331(1.3325) | Loss -1.377375(-1.360996) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0252 | Time 1.3329(1.3326) | Loss -1.373293(-1.361857) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0253 | Time 1.3325(1.3326) | Loss -1.383817(-1.363394) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0254 | Time 1.3322(1.3325) | Loss -1.377068(-1.364351) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0255 | Time 1.3341(1.3326) | Loss -1.374746(-1.365079) | NFE Forward 44(44.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0256 | Time 1.3270(1.3323) | Loss -1.379850(-1.366113) | NFE For

Iter 0311 | Time 1.4631(1.3786) | Loss -1.409263(-1.406057) | NFE Forward 50(48.5) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0312 | Time 1.4009(1.3801) | Loss -1.412108(-1.406480) | NFE Forward 50(48.6) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0313 | Time 1.3755(1.3798) | Loss -1.412314(-1.406889) | NFE Forward 50(48.7) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0314 | Time 1.3996(1.3812) | Loss -1.414541(-1.407424) | NFE Forward 50(48.8) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0315 | Time 1.4012(1.3826) | Loss -1.413072(-1.407820) | NFE Forward 50(48.9) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0316 | Time 1.3730(1.3819) | Loss -1.416695(-1.408441) | NFE Forward 50(48.9) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0317 | Time 1.3695(1.3811) | Loss -1.412418(-1.408719) | NFE Forward 50(49.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0318 | Time 1.3681(1.3801) | Loss -1.411880(-1.408941) | NFE For

Iter 0373 | Time 1.3779(1.3761) | Loss -1.431711(-1.429367) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0374 | Time 1.3786(1.3763) | Loss -1.434256(-1.429709) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0375 | Time 1.3705(1.3759) | Loss -1.431683(-1.429848) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0376 | Time 1.3723(1.3756) | Loss -1.435545(-1.430246) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0377 | Time 1.3769(1.3757) | Loss -1.436133(-1.430658) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0378 | Time 1.3727(1.3755) | Loss -1.432526(-1.430789) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0379 | Time 1.3770(1.3756) | Loss -1.435940(-1.431150) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0380 | Time 1.3781(1.3758) | Loss -1.433767(-1.431333) | NFE For

Iter 0435 | Time 1.3724(1.3745) | Loss -1.446566(-1.442969) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0436 | Time 1.3693(1.3742) | Loss -1.445422(-1.443141) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0437 | Time 1.3695(1.3738) | Loss -1.445973(-1.443339) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0438 | Time 1.3695(1.3735) | Loss -1.448839(-1.443724) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0439 | Time 1.3749(1.3736) | Loss -1.446653(-1.443929) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0440 | Time 1.3784(1.3740) | Loss -1.442754(-1.443847) | NFE Forward 50(50.0) | NFE Backward 56(56.0) | CNF Time 0.5000(0.5000)
Iter 0441 | Time 1.4769(1.3812) | Loss -1.448241(-1.444154) | NFE Forward 50(50.0) | NFE Backward 62(56.4) | CNF Time 0.5000(0.5000)
Iter 0442 | Time 1.3738(1.3807) | Loss -1.442796(-1.444059) | NFE For

Iter 0497 | Time 1.5271(1.4537) | Loss -1.453966(-1.451628) | NFE Forward 56(50.8) | NFE Backward 62(60.0) | CNF Time 0.5000(0.5000)
Iter 0498 | Time 1.4197(1.4514) | Loss -1.457101(-1.452011) | NFE Forward 56(51.2) | NFE Backward 56(59.7) | CNF Time 0.5000(0.5000)
Iter 0499 | Time 1.4854(1.4537) | Loss -1.455349(-1.452245) | NFE Forward 50(51.1) | NFE Backward 62(59.9) | CNF Time 0.5000(0.5000)
Iter 0500 | Time 1.4793(1.4555) | Loss -1.456102(-1.452515) | NFE Forward 50(51.0) | NFE Backward 62(60.1) | CNF Time 0.5000(0.5000)
Iter 0501 | Time 1.4354(1.4541) | Loss -1.454025(-1.452621) | NFE Forward 56(51.4) | NFE Backward 56(59.8) | CNF Time 0.5000(0.5000)
Iter 0502 | Time 1.5423(1.4603) | Loss -1.453158(-1.452658) | NFE Forward 50(51.3) | NFE Backward 62(59.9) | CNF Time 0.5000(0.5000)
Iter 0503 | Time 1.4227(1.4577) | Loss -1.453371(-1.452708) | NFE Forward 50(51.2) | NFE Backward 56(59.6) | CNF Time 0.5000(0.5000)
Iter 0504 | Time 1.5056(1.4610) | Loss -1.452225(-1.452674) | NFE For

Iter 0559 | Time 1.3752(1.4475) | Loss -1.456967(-1.459538) | NFE Forward 50(50.3) | NFE Backward 56(60.0) | CNF Time 0.5000(0.5000)
Iter 0560 | Time 1.3739(1.4423) | Loss -1.463054(-1.459784) | NFE Forward 50(50.2) | NFE Backward 56(59.7) | CNF Time 0.5000(0.5000)
Iter 0561 | Time 1.3761(1.4377) | Loss -1.464034(-1.460082) | NFE Forward 50(50.2) | NFE Backward 56(59.5) | CNF Time 0.5000(0.5000)
Iter 0562 | Time 1.7012(1.4561) | Loss -1.465761(-1.460479) | NFE Forward 50(50.2) | NFE Backward 74(60.5) | CNF Time 0.5000(0.5000)
Iter 0563 | Time 1.3746(1.4504) | Loss -1.461354(-1.460540) | NFE Forward 50(50.2) | NFE Backward 56(60.2) | CNF Time 0.5000(0.5000)
Iter 0564 | Time 1.3756(1.4452) | Loss -1.461566(-1.460612) | NFE Forward 50(50.2) | NFE Backward 56(59.9) | CNF Time 0.5000(0.5000)
Iter 0565 | Time 1.3696(1.4399) | Loss -1.464082(-1.460855) | NFE Forward 50(50.2) | NFE Backward 56(59.6) | CNF Time 0.5000(0.5000)
Iter 0566 | Time 1.3716(1.4351) | Loss -1.466853(-1.461275) | NFE For