In [1]:
from tqdm.auto import trange

import numpy as np
import plotly.graph_objects as go

import torch
from torch.utils.data import DataLoader

import os
import sys
import datetime
import gc

from IPython.display import SVG, display

def imshow(fig):
    return display(SVG(fig.to_image(format="svg")))

%matplotlib inline
%config InlineBackend.figure_format='retina'

eps = 1e-10  # A negligible positive number
np.random.seed(0)


In [2]:
%cd /home/kevin/AutoInt
from src.utils import get_device
from src.integration.autoint import *
#from src.models.st_model import AutoIntSTPPSameInfluence
from src.utils import AverageMeter
from tqdm.contrib import tenumerate
from copy import deepcopy
from src.data.data import SlidingWindowWrapper, TPPWrapper, pad_collate

device = get_device(free=False, min_ram=0)

/home/kevin/AutoInt


In [3]:
def eval_loss(model, test_loader):
    model.eval()
    sll_meter = AverageMeter()
    tll_meter = AverageMeter()
    loss_meter = AverageMeter()

    for index, data in enumerate(test_loader):
        st_x, st_y, _, _, _ = data
        loss, sll, tll = model(st_x, st_y)

        loss_meter.update(loss.item())
        sll_meter.update(sll.mean().item())
        tll_meter.update(tll.mean().item())

    return loss_meter.avg, sll_meter.avg, tll_meter.avg

In [4]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

config = Namespace(hid_dim=128, emb_dim=128, out_dim=0, n_layers=1,
                   lr=0.001, momentum=0.9, epochs=100, batch=128, opt='Adam', generate_type=True,
                   read_model=False, seq_len=20, eval_epoch=5, s_min=1e-3, b_max=20,
                   lookahead=1, alpha=0.1, z_dim=128, beta=1e-3, dropout=0, num_head=2,
                   nlayers=3, num_points=20, infer_nstep=10000, infer_limit=13, clip=1.0,
                   constrain_b='sigmoid', sample=True, decoder_n_layer=3)

In [5]:
import torch
from torch import nn

from src.integration.autoint import Cuboid


class AutoIntSTPPSameInfluence(nn.Module):

    def __init__(self, hidden_size, device):
        """
        :param: hidden_size: the dimension of linear hidden layer
        :param: t_end: the time when observation terminates
                if is None, then time after last event is not considered
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.device = device

        # log background intensity
        self.background = torch.nn.Parameter(torch.ones(1))

        # ∫_0^t λ
        self.F = Cuboid().to(device)

        self.project()

    def project(self):
        """
        Employ non-negative constraint
        """
        self.F.project()

    def forward(self, st_x, st_y):
        """
        Calculate NLL for a batch of sliding windows

        :param st_x: [batch, seq_len, 3], the event timings
        :param st_y: [batch, 1, 3], the time to forecast
        :return: nll: scalar, the average negative log likelihood
        """

        # Calculate spatiotemporal distance to previous events
        t_x_cum = torch.cumsum(st_x[..., -1], -1)  # [batch, seq_len]
        t_diff = t_x_cum[:, -1:] - t_x_cum + st_y[..., -1:, -1]  # [batch, seq_len]

        if not torch.all(t_diff >= 0):
            idx = torch.argmin(t_diff)
            print(t_diff[idx // t_diff.shape[1]])
            raise

        s_x = st_x[..., :2]
        s_y = st_y[..., :2]
        s_diff = s_y - s_x   # [batch, seq_len, 2]
        st_diff = torch.cat([s_diff, t_diff.unsqueeze(-1)], -1)

        ########## Calculate intensity ############
        # [batch, seq_len]
        batch, seq_len, _ = st_diff.shape
        lambs = self.F.forward(st_diff.view(-1, 3)).view([batch, seq_len])
        temp = lambs  # TODO
        # print(lambs[0])
        lambs = torch.sum(lambs, -1) + torch.exp(self.background)  # sum up all events' influence

        ########## Calculate temporal intensity ############
        lamb_t = self.F.lamb_t(s_x.view(-1, 2), t_diff.view(-1)).view([batch, seq_len])
        lamb_t = torch.sum(lamb_t, -1) + torch.exp(self.background)  # sum up all events' influence

        ######### Calculate integral intensity ##########
        # [batch, seq_len]
        # cumulative intensity of every event
        lamb_ints = self.F.int_lamb(s_x.view(-1, 2), (t_x_cum[:, -1:] - t_x_cum).view(-1), t_diff.view(-1)).view([batch, seq_len])

        ######### Calculate loss ########
        lamb_ints = torch.sum(lamb_ints, -1)
        background_int = st_y[..., -1, -1] * torch.exp(self.background)
        lamb_ints += background_int  # Add background intensities' integral

        tll = torch.log(lamb_t).mean() - lamb_ints.mean()
        ll = torch.log(lambs).mean() - lamb_ints.mean()

        if not torch.all(lambs > 0):
            idx = torch.argmax(torch.isnan(torch.log(lambs)).float())
            print(idx)
            print(t_diff[idx])
            print(temp[idx])
            print(torch.sum(temp[idx]))
            print(torch.log(torch.sum(temp[idx])))
            print(lambs)
            print(torch.log(lambs))
            print('--------------------------------------------')
            raise

        sll = ll - tll

        return -ll, sll, tll


## COVID NJ Case

In [6]:
dataset = 'covid_nj_cases'
npzf = np.load(f'data/spatiotemporal/{dataset}.npz', allow_pickle=True)

In [7]:
trainset = SlidingWindowWrapper(npzf['train'], normalized=True)
valset   = SlidingWindowWrapper(npzf['val'],   normalized=True, min=trainset.min, max=trainset.max)
testset  = SlidingWindowWrapper(npzf['test'],  normalized=True, min=trainset.min, max=trainset.max)

In [8]:
train_loader = DataLoader(trainset, batch_size=config.batch, shuffle=True)
val_loader   = DataLoader(valset,   batch_size=config.batch, shuffle=False)
test_loader  = DataLoader(testset,  batch_size=config.batch, shuffle=False)

In [9]:
# st_x = torch.rand(128, 20, 2).to(device)
# st_y = torch.rand(128, 1, 2).to(device)
#
# model = AutoIntSTPPSameInfluence(config.hid_dim, t_end=1000.0, device=device).to(device)
# model(st_x, st_y)

In [10]:
model = AutoIntSTPPSameInfluence(config.hid_dim, device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.2)
best_eval = np.infty
sll_meter = AverageMeter()
tll_meter = AverageMeter()
loss_meter = AverageMeter()

time_now = str(datetime.datetime.now())
parent_dir = f'models/AutoInt-STPP-Same-Influence-{dataset}-{time_now}'
os.mkdir(parent_dir)

for epoch in trange(config.epochs):
    loss_total = 0
    model.train()
    for index, data in tenumerate(train_loader):
        st_x, st_y, _, _, _ = data
        optimizer.zero_grad()
        loss, sll, tll = model(st_x, st_y)

        if torch.isnan(loss):
            print("Numerical error, quiting...")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
        optimizer.step()

        # Project to feasible set
        model.project()

        loss_meter.update(loss.item())
        sll_meter.update(sll.mean().item())
        tll_meter.update(tll.mean().item())

    scheduler.step()

    print("In epochs {} | "
                "total loss: {:5f} | Space: {:5f} | Time: {:5f}".format(
        epoch, loss_meter.avg, sll_meter.avg , tll_meter.avg
    ))

    if (epoch + 1) % config.eval_epoch == 0:
        model.eval()
        valloss, valspace, valtime = eval_loss(model, val_loader)
        print("Evaluate   | Val Loss {:5f} | Space: {:5f} | Time: {:5f}".format(valloss, valspace, valtime))
        if valloss < best_eval:
            best_eval = valloss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, f'{parent_dir}/AutoIntSTPP-{epoch}.mod')

print("training done!")

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/901 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Earthquakes JP

In [None]:
config = Namespace(hid_dim=128, emb_dim=128, out_dim=0,
                   lr=0.001, momentum=0.9, epochs=100, batch=512, opt='Adam', generate_type=True,
                   read_model=False, seq_len=20, eval_epoch=5, s_min=1e-3, b_max=20,
                   lookahead=1, alpha=0.1, z_dim=128, beta=1e-3, dropout=0, num_head=2,
                   nlayers=3, num_points=20, infer_nstep=10000, infer_limit=13, clip=0.5,
                   constrain_b='sigmoid', sample=True, decoder_n_layer=3)

In [None]:
dataset = 'earthquakes_jp'
npzf = np.load(f'data/spatiotemporal/{dataset}.npz', allow_pickle=True)

In [None]:
trainset = SlidingWindowWrapper(npzf['train'], normalized=True)
valset   = SlidingWindowWrapper(npzf['val'],   normalized=True, min=trainset.min, max=trainset.max)
testset  = SlidingWindowWrapper(npzf['test'],  normalized=True, min=trainset.min, max=trainset.max)

In [None]:
train_loader = DataLoader(trainset, batch_size=config.batch, shuffle=True)
val_loader   = DataLoader(valset,   batch_size=config.batch, shuffle=False)
test_loader  = DataLoader(testset,  batch_size=config.batch, shuffle=False)

In [None]:
from tqdm import tqdm

In [None]:
model = AutoIntSTPPSameInfluence(config.hid_dim, device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
best_eval = np.infty
sll_meter = AverageMeter()
tll_meter = AverageMeter()
loss_meter = AverageMeter()

time_now = str(datetime.datetime.now())
parent_dir = f'models/AutoInt-STPP-Same-Influence-{dataset}-{time_now}'
os.mkdir(parent_dir)

for epoch in trange(config.epochs):
    loss_total = 0
    model.train()
    for index, data in tenumerate(train_loader):
        st_x, st_y, _, _, _ = data
        optimizer.zero_grad()
        loss, sll, tll = model(st_x, st_y)

        if torch.isnan(loss):
            print("Numerical error, quiting...")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
        optimizer.step()

        # Project to feasible set
        model.project()

        loss_meter.update(loss.item())
        sll_meter.update(sll.mean().item())
        tll_meter.update(tll.mean().item())

    scheduler.step()

    print("In epochs {} | "
                "total loss: {:5f} | Space: {:5f} | Time: {:5f}".format(
        epoch, loss_meter.avg, sll_meter.avg , tll_meter.avg
    ))

    if (epoch + 1) % config.eval_epoch == 0:
        model.eval()
        valloss, valspace, valtime = eval_loss(model, val_loader)
        print("Evaluate     | val loss:   {:5f} | Space: {:5f} | Time: {:5f}".format(valloss, valspace, valtime))
        if valloss < best_eval:
            best_eval = valloss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, f'{parent_dir}/AutoIntSTPP-{epoch}.mod')

print("training done!")

In epochs 5 | total loss: -1.881053 | Space: 0.077642 | Time: 1.803411


  0%|          | 0/104 [00:00<?, ?it/s]

In epochs 6 | total loss: -1.885927 | Space: 0.081679 | Time: 1.804248


  0%|          | 0/104 [00:00<?, ?it/s]

In epochs 7 | total loss: -1.889637 | Space: 0.084764 | Time: 1.804873


  0%|          | 0/104 [00:00<?, ?it/s]

In [14]:
model.F.L.layers.layers[0].weight

Parameter containing:
tensor([[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
   

In [None]:

            print(t_diff[idx])
            print(temp[idx])
            print(torch.log(lambs))
            print(torch.log(lambs).mean())
            print(lamb_t)
            print(lamb_ints)
            print(ll)
            print(tll)

## STHP TS1

In [35]:
config = Namespace(hid_dim=128, emb_dim=128, out_dim=0, n_layers=1,
                   lr=0.0003, momentum=0.9, epochs=50, batch=128, opt='Adam', generate_type=True,
                   read_model=False, seq_len=20, eval_epoch=5, s_min=1e-4, b_max=20,
                   lookahead=1, alpha=0.1, z_dim=128, beta=1e-3, dropout=0, num_head=2,
                   nlayers=3, num_points=20, infer_nstep=10000, infer_limit=13, clip=1.0,
                   constrain_b=False, sample=False, decoder_n_layer=3)

In [36]:
dataset = 'sthp1'
npzf = np.load(f'data/spatiotemporal/{dataset}.npz', allow_pickle=True)

In [37]:
trainset = SlidingWindowWrapper(npzf['train'], normalized=True)
valset   = SlidingWindowWrapper(npzf['val'],   normalized=True, min=trainset.min, max=trainset.max)
testset  = SlidingWindowWrapper(npzf['test'],  normalized=True, min=trainset.min, max=trainset.max)

In [38]:
train_loader = DataLoader(trainset, batch_size=config.batch, shuffle=True)
val_loader   = DataLoader(valset,   batch_size=config.batch, shuffle=False)
test_loader  = DataLoader(testset,  batch_size=config.batch, shuffle=False)


In [39]:
model = AutoIntSTPPSameInfluence(config.hid_dim, device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
best_eval = np.infty
sll_meter = AverageMeter()
tll_meter = AverageMeter()
loss_meter = AverageMeter()

time_now = str(datetime.datetime.now())
parent_dir = f'models/AutoInt-STPP-Same-Influence-{dataset}-{time_now}'
os.mkdir(parent_dir)

for epoch in trange(config.epochs):
    loss_total = 0
    model.train()
    for index, data in tenumerate(train_loader):
        st_x, st_y, _, _, _ = data
        optimizer.zero_grad()
        loss, sll, tll = model(st_x, st_y)

        if torch.isnan(loss):
            print("Numerical error, quiting...")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
        optimizer.step()

        # Project to feasible set
        model.project()

        loss_meter.update(loss.item())
        sll_meter.update(sll.mean().item())
        tll_meter.update(tll.mean().item())

    scheduler.step()

    print("In epochs {} | "
                "total loss: {:5f} | Space: {:5f} | Time: {:5f}".format(
        epoch, loss_meter.avg, sll_meter.avg , tll_meter.avg
    ))

    if (epoch + 1) % config.eval_epoch == 0:
        model.eval()
        valloss, valspace, valtime = eval_loss(model, val_loader)
        print("Evaluate     | val loss:   {:5f} | Space: {:5f} | Time: {:5f}".format(valloss, valspace, valtime))
        if valloss < best_eval:
            best_eval = valloss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, f'{parent_dir}/AutoIntSTPP-{epoch}.mod')

print("training done!")


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 0 | total loss: -1.767626 | Space: 0.001709 | Time: 1.765916


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 1 | total loss: -2.311365 | Space: 0.001964 | Time: 2.309401


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 2 | total loss: -2.529467 | Space: 0.002144 | Time: 2.527323


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 3 | total loss: -2.639349 | Space: 0.002354 | Time: 2.636995


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 4 | total loss: -2.704216 | Space: 0.002501 | Time: 2.701715
Evaluate     | val loss:   -2.222175 | Space: 0.017257 | Time: 2.204918


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 5 | total loss: -2.745920 | Space: 0.002538 | Time: 2.743382


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 6 | total loss: -2.777150 | Space: 0.002622 | Time: 2.774528


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 7 | total loss: -2.801980 | Space: 0.002632 | Time: 2.799347


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 8 | total loss: -2.820732 | Space: 0.002617 | Time: 2.818115


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 9 | total loss: -2.835508 | Space: 0.002619 | Time: 2.832890
Evaluate     | val loss:   -2.205857 | Space: 0.017368 | Time: 2.188489


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 10 | total loss: -2.847785 | Space: 0.002650 | Time: 2.845134


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 11 | total loss: -2.858325 | Space: 0.002650 | Time: 2.855675


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 12 | total loss: -2.867311 | Space: 0.002683 | Time: 2.864628


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 13 | total loss: -2.874045 | Space: 0.002689 | Time: 2.871356


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 14 | total loss: -2.880618 | Space: 0.002701 | Time: 2.877917
Evaluate     | val loss:   -2.205031 | Space: 0.017376 | Time: 2.187655


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 15 | total loss: -2.886276 | Space: 0.002707 | Time: 2.883569


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 16 | total loss: -2.891075 | Space: 0.002722 | Time: 2.888354


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 17 | total loss: -2.895272 | Space: 0.002734 | Time: 2.892537


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 18 | total loss: -2.899233 | Space: 0.002743 | Time: 2.896490


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 19 | total loss: -2.902888 | Space: 0.002751 | Time: 2.900136
Evaluate     | val loss:   -2.204557 | Space: 0.017378 | Time: 2.187178


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 20 | total loss: -2.904761 | Space: 0.002762 | Time: 2.901999


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 21 | total loss: -2.907703 | Space: 0.002766 | Time: 2.904938


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 22 | total loss: -2.910394 | Space: 0.002771 | Time: 2.907623


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 23 | total loss: -2.913009 | Space: 0.002769 | Time: 2.910240


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 24 | total loss: -2.915303 | Space: 0.002776 | Time: 2.912527
Evaluate     | val loss:   -2.204487 | Space: 0.017379 | Time: 2.187108


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 25 | total loss: -2.917167 | Space: 0.002783 | Time: 2.914384


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 26 | total loss: -2.919127 | Space: 0.002784 | Time: 2.916344


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 27 | total loss: -2.920913 | Space: 0.002788 | Time: 2.918125


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 28 | total loss: -2.922653 | Space: 0.002787 | Time: 2.919866


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 29 | total loss: -2.924175 | Space: 0.002793 | Time: 2.921383
Evaluate     | val loss:   -2.204478 | Space: 0.017379 | Time: 2.187099


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 30 | total loss: -2.925695 | Space: 0.002787 | Time: 2.922909


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 31 | total loss: -2.927079 | Space: 0.002784 | Time: 2.924296


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 32 | total loss: -2.927894 | Space: 0.002791 | Time: 2.925103


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 33 | total loss: -2.929217 | Space: 0.002798 | Time: 2.926419


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 34 | total loss: -2.930166 | Space: 0.002790 | Time: 2.927375
Evaluate     | val loss:   -2.204478 | Space: 0.017379 | Time: 2.187100


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 35 | total loss: -2.930903 | Space: 0.002792 | Time: 2.928111


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 36 | total loss: -2.931986 | Space: 0.002789 | Time: 2.929198


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 37 | total loss: -2.933133 | Space: 0.002793 | Time: 2.930340


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 38 | total loss: -2.934240 | Space: 0.002787 | Time: 2.931452


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 39 | total loss: -2.935206 | Space: 0.002796 | Time: 2.932410
Evaluate     | val loss:   -2.204477 | Space: 0.017379 | Time: 2.187099


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 40 | total loss: -2.936049 | Space: 0.002794 | Time: 2.933255


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 41 | total loss: -2.936923 | Space: 0.002789 | Time: 2.934133


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 42 | total loss: -2.937688 | Space: 0.002794 | Time: 2.934895


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 43 | total loss: -2.938367 | Space: 0.002799 | Time: 2.935568


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 44 | total loss: -2.939044 | Space: 0.002805 | Time: 2.936239
Evaluate     | val loss:   -2.204477 | Space: 0.017379 | Time: 2.187099


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 45 | total loss: -2.939549 | Space: 0.002809 | Time: 2.936740


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 46 | total loss: -2.940173 | Space: 0.002810 | Time: 2.937363


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 47 | total loss: -2.940779 | Space: 0.002802 | Time: 2.937977


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 48 | total loss: -2.941175 | Space: 0.002805 | Time: 2.938370


  0%|          | 0/50 [00:00<?, ?it/s]

In epochs 49 | total loss: -2.941272 | Space: 0.002809 | Time: 2.938463
Evaluate     | val loss:   -2.204477 | Space: 0.017379 | Time: 2.187099
training done!
