In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import scipy.sparse as sp
import scipy.linalg as linalg

from haversine import haversine
from tqdm import tqdm
from einops import rearrange
from time import time
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

project_dir = os.path.dirname(os.path.dirname(os.getcwd()))
if project_dir not in sys.path:
    sys.path.append(project_dir)

from utils.data import load_data
from utils.tool import prediction_summary

# Prepare Data

In [2]:
os.chdir(project_dir)
train_loader, val_loader, test_loader = load_data(window=2, batch_size=32)
os.chdir(os.path.join(project_dir, 'models/GNN10'))

File1: 3BAGEmnnQ2K4zF49Dkkoxg.csv contains missing hours
File4: 4XEJFVFOS761cvyEjOYf0g.csv contains outliers
File5: 6kzhfU9xTKCUVJMz492l2g.csv contains outliers
File6: 6nBLCf6WT06TOuUExPkBtA.csv contains missing hours
File17: JQ1px-xqQx-xKh3Oa5h9nA.csv contains missing hours
File21: OfAvTbS1SiOjQo4WKSAP9g.csv contains missing hours
File24: R2ebpAblQHylOjteA-2hlQ.csv contains missing hours
File37: jDYxIP2JQL2br5aTIAR7JQ.csv contains outliers
File38: kyRUtBOTTaK7V_-dxOJTwg.csv contains outliers
File45: wSo2iRgjT36eWC4a2joWZg.csv contains outliers


In [3]:
class MyGraphDataset(Dataset):
    def __init__(self, graphs):
        super(MyGraphDataset, self).__init__()
        self.graphs = graphs

    def len(self):
        return len(self.graphs)

    def get(self, idx):
        return self.graphs[idx]

# Build Model

In [4]:
# functions
def get_edge_index(n):
    """
    Create edge index for a complete graph of n nodes with self loops
    :param n: number of nodes
    :return: edge index tensor
    """
    edge_index = torch.tensor([[i, j] for i in range(n) for j in range(n)], dtype=torch.long).t().contiguous()
    return edge_index

def get_distance_matrix(locs):
    """
    Calculate the Euclidean distance matrix
    :param locs: a torch tensor of size (batch_size, num_nodes, 2) or (num_nodes, 2)
    """
    if len(locs.size()) == 2:
        batch = False
        locs = locs.unsqueeze(0)
    else:
        batch = True


    B, N, _ = locs.size()
    locs_i = locs.unsqueeze(2).expand(B, N, N, 2)
    locs_j = locs.unsqueeze(1).expand(B, N, N, 2)
    distances = torch.sqrt(((locs_i - locs_j) ** 2).sum(dim=-1))

    if not batch:
        distances = distances.squeeze(0)
    
    return distances

# Following 2 functions adapted from https://github.com/nnzhan/Graph-WaveNet/blob/master/util.py
def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=False):
    def calculate_normalized_laplacian(adj):
        """
        
        """
        adj = sp.coo_matrix(adj)
        d = np.array(adj.sum(1))
        d_inv_sqrt = np.power(d, -0.5).flatten()
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
        normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
        return normalized_laplacian
    if undirected:
        adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
    L = calculate_normalized_laplacian(adj_mx)
    if lambda_max is None:
        lambda_max, _ = linalg.eigsh(L, 1, which='LM')
        lambda_max = lambda_max[0]
    L = sp.csr_matrix(L)
    M, _ = L.shape
    I = sp.identity(M, format='csr', dtype=L.dtype)
    L = (2 / lambda_max * L) - I
    return L.astype(np.float32).todense()

In [17]:
class Mlp(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Mlp, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, 2 * out_dim),
            nn.BatchNorm1d(2 * out_dim),
            nn.ReLU(),
            nn.Linear(2 * out_dim, out_dim),
            nn.BatchNorm1d(out_dim),
        )

    def forward(self, x):
        return self.mlp(x)
    

class LocalLayer(MessagePassing):
    def __init__(self, in_dim, out_dim):
        super(LocalLayer, self).__init__(aggr='mean')
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.linear = Mlp(in_dim, out_dim)

    def forward(self, x, edge_index):
        """
        :param x: node features of shape (num_nodes, in_dim)
        :param edge_index: edge index tensor of shape (2, num_edges)
        """
        N = x.size(0)
        # transform node features
        x = self.linear(x)

        # compute normalization
        row, col = edge_index
        deg = degree(row, N, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # propagate message
        outputs = self.propagate(edge_index, x=x, norm=norm)

        return outputs
    
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
    

class DiffusionLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(DiffusionLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.mlp = Mlp(in_dim, out_dim)

        self.k = Parameter(torch.Tensor(1, out_dim))

    def forward(self, locs, features):
        """
        :param locs: (B, N, 2)
        :param features: (B, N, in_dim)
        """
        B, N, D = features.size()

        # apply mlp to features
        features = rearrange(features, 'b n d -> (b n) d')
        features = self.mlp(features)
        features = rearrange(features, '(b n) d -> b n d', b=B)

        # calculate the scaled laplacian
        L = np.zeros((B, N, N))
        for b in range(B):
            D = get_distance_matrix(locs[b])
            W = 1 / (D + 1e-8)
            # Set diagonal elements to 0
            mask = torch.ones(N, N).to(features.device) - torch.eye(N).to(features.device)
            W = W * mask
            L_temp = calculate_scaled_laplacian(W.cpu().numpy())
            L[b] = torch.tensor(L_temp).to(features.device)

        # apply diffusion
        L = torch.from_numpy(L).to(features.device).float()   # (B, N, N)
        outputs = torch.bmm(L, features) * self.k

        return outputs
    

class ConvectionLayer(MessagePassing):
    def __init__(self, in_dim, out_dim, edge_dim):
        super(ConvectionLayer, self).__init__(aggr='mean')
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.edge_in_dim = edge_dim

        self.node_linear = Mlp(in_dim, out_dim)
        self.edge_linear = Mlp(edge_dim, out_dim)
        self.message_linear = Mlp(2 * out_dim, out_dim)
        self.update_linear = Mlp(2 * out_dim, out_dim)

    def forward(self, x, edge_index, edge_attr):
        """
        :param x: node features of shape (num_nodes, in_dim)
        :param edge_index: edge index tensor of shape (2, num_edges)
        :param edge_attr: edge features of shape (num_edges, edge_dim)
        """
        assert len(x.size()) == len(edge_attr.size()) == 2, "x and edge_attr must have 2 dimensions"
        assert edge_index.size(1) == edge_attr.size(0), "edge_index and edge_attr must have the same number of edges"

        N = x.size(0)
        # transform node and edge features
        x = self.node_linear(x)
        edge_attr = self.edge_linear(edge_attr)

        outputs = self.propagate(edge_index, x=x, edge_attr=edge_attr)

        return outputs, edge_attr
    
    def message(self, x_i, x_j, edge_attr):
        """
        Edge from j -> i
        """
        x_i = x_i + edge_attr
        x_j = x_j + edge_attr
        x = torch.cat([x_i, x_j], dim=-1)
        return self.message_linear(x) + x_i
    
    def update(self, aggr_out, x):
        x_cat = torch.cat([aggr_out, x], dim=-1)
        return self.update_linear(x_cat) + x
    

class LCDLayer(nn.Module):
    def __init__(self, in_dim, out_dim, edge_dim):
        super(LCDLayer, self).__init__()

        self.local = LocalLayer(in_dim, out_dim)
        self.convection = ConvectionLayer(in_dim, out_dim, edge_dim)
        self.diffusion = DiffusionLayer(in_dim, out_dim)

        self.fusion = nn.Sequential(
            nn.Linear(3 * out_dim, 3),
            nn.Softmax(dim=-1)
        )

    def forward(self, locs, features, winds_feature, edge_index):
        """
        :param locs: a torch tensor of size (B, N, 2)
        :param features: a torch tensor of size (B, N, in_dim)
        :param winds_feature: a torch tensor of size (B, N^2, in_dim)
        :param edge_index: a torch tensor of size (2, N^2)
        """
        B, N, _ = locs.size()
        # preprocess the input data into a batch
        batch = self.batch_preprocess(locs, features, winds_feature, edge_index)

        # input original feature to nn.Module and batch features to MessagePassing
        l_features = self.local(batch.features, batch.edge_index)
        c_features, edge_attr = self.convection(batch.features, batch.edge_index, batch.winds_feature)
        d_features = self.diffusion(locs, features)

        # reverse the batched outputs
        l_features = self.reverse_batch_process(l_features, B=B)
        c_features, edge_attr = self.reverse_batch_process(c_features, edge_attr, B=B)

        # fusion
        weights = self.fusion(torch.cat([l_features, c_features, d_features], dim=-1))   # (B, N, 3)
        outputs = l_features * weights[:, :, 0].unsqueeze(-1) + c_features * weights[:, :, 1].unsqueeze(-1) + d_features * weights[:, :, 2].unsqueeze(-1)

        return outputs, edge_attr


    @staticmethod
    def batch_preprocess(locs, features, winds_feature, edge_index):
        """
        Preprocess the input data into a single batch
        """
        B, N, _ = locs.size()
        graphs = []
        for b in range(B):
            graph = Data(features=features[b],
                         edge_index=edge_index,
                         locs=locs[b],
                         winds_feature=winds_feature[b],
                         num_nodes=N)
            graphs.append(graph)
        graphs = MyGraphDataset(graphs)
        loader = DataLoader(graphs, batch_size=B, shuffle=False)
        return next(iter(loader))
    
    @staticmethod
    def reverse_batch_process(features, edge_attr=None, B=32):
        """
        Reverse the batch process
        :param features: a torch tensor of size (B*N, out_dim)
        :param edge_attr: a torch tensor of size (B*N**2, edge_dim)
        """
        N = features.size(0) // B
        features = rearrange(features, '(b n) d -> b n d', b=B)
        if edge_attr is not None:
            edge_attr = rearrange(edge_attr, '(b m) d -> b m d', b=B)
            return features, edge_attr
        return features


class LCDGCN(nn.Module):
    def __init__(self,
                 in_dim=1,
                 hidden_dim=64,
                 out_dim=1,
                 num_layers=3,
                 edge_dim=3):
        super(LCDGCN, self).__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.edge_dim = edge_dim

        self.layers = nn.ModuleList()
        self.layers.append(LCDLayer(in_dim, hidden_dim, edge_dim))
        for _ in range(num_layers - 2):
            self.layers.append(LCDLayer(hidden_dim, hidden_dim, hidden_dim))
        self.layers.append(LCDLayer(hidden_dim, out_dim, hidden_dim))

    def forward(self, locs, readings, target_loc, winds_feature, edge_index):
        """
        :param locs: a torch tensor of size (B, N, 2)
        :param readings: a torch tensor of size (B, N, in_dim, 2). Here 2 stands for 2 time steps
        :param target_loc: a torch tensor of size (B, 2)
        :param winds_feature: a torch tensor of size (B, N+1, N+1, edge_dim)
        """
        B, N, D, T = readings.size()
        assert N == winds_feature.size(1) - 1, f"The number of nodes in winds_feature({winds_feature.size(1)}) should be 1 plus that in readings({N})"

        # preprocess edge features
        winds_feature = self.preprocess_edge_features(winds_feature)

        # concat and mask
        lcd_locs = torch.cat([locs, target_loc[:, None, :]], dim=1)   # (B, N+1, 2)
        lcd_readings = torch.cat([readings[:, :, :, 1],
                                  torch.zeros(B, 1, D).to(readings.device)],
                                  dim=1)   # (B, N+1, D)
        
        # forward pass
        for layer in self.layers:
            lcd_readings, winds_feature = layer(lcd_locs, lcd_readings, winds_feature, edge_index)
        outputs = lcd_readings

        return outputs

    @staticmethod
    def preprocess_edge_features(edge_features):
        """
        Transform edge features from shape (B, N, N, D) to (B, N^2, D)
        """
        assert len(edge_features.size()) == 4
        assert edge_features.size(1) == edge_features.size(2)
        B, N, _, D = edge_features.size()
        edge_features = edge_features.view(B, N * N, D)
        return edge_features

# Model Training

In [26]:
hidden_dim = 256
num_layers = 3

# set seed for model initialization
seedi = 45
torch.manual_seed(seedi)
model = LCDGCN(in_dim=1,
               hidden_dim=hidden_dim,
               out_dim=1,
               num_layers=num_layers,
               edge_dim=3)

# set up training
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 100

In [27]:
# set seed for training
seedt = 45
torch.manual_seed(seedt)

# set weight path
save_path = f"./weights/LCDGCN_h{hidden_dim}_l{num_layers}_si{seedi}_st{seedt}.pth"

# training
best_val = 1e10
for epoch in range(epochs):
    model.train()
    losses = []
    for batch in tqdm(train_loader):
        locs, readings, target_loc, target_reading, winds_feature = batch
        readings_input = readings[..., None].permute(0, 2, 3, 1)   # (B, N, 1, 2)
        winds_feature_input = winds_feature[:, :, :, 0, :]   # (B, N+1, N+1, 3)
        pred_target = torch.cat([readings.permute(0, 2, 1)[:, :, 1:],
                                 target_reading[:, 1:].unsqueeze(1)],
                                 dim=1)   # (B, N+1, 1)
        edge_index = get_edge_index(pred_target.size(1))

        # forward
        optimizer.zero_grad()
        pred = model(locs, readings_input, target_loc, winds_feature_input, edge_index)

        # loss
        loss = criterion(pred[:, -1], pred_target[:, -1])
        losses.append(loss.item())

        # backward
        loss.backward()
        optimizer.step()

    # validation
    val_losses = []
    with torch.no_grad():
        model.eval()
        for batch in val_loader:
            locs, readings, target_loc, target_reading, winds_feature = batch
            readings_input = readings[..., None].permute(0, 2, 3, 1)
            winds_feature_input = winds_feature[:, :, :, 0, :]
            pred_target = torch.cat([readings.permute(0, 2, 1)[:, :, 1:],
                                     target_reading[:, 1:].unsqueeze(1)],
                                     dim=1)
            edge_index = get_edge_index(pred_target.size(1))

            # forward
            pred = model(locs, readings_input, target_loc, winds_feature_input, edge_index)

            # loss
            loss = criterion(pred[:, -1], pred_target[:, -1])
            val_losses.append(loss.item())
        
        # print training and validation losses
        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {np.mean(losses):.4f}, Validation Loss: {np.mean(val_losses):.4f}")

        # save the best model
        if np.mean(val_losses) < best_val:
            best_val = np.mean(val_losses)
            torch.save(model.state_dict(), save_path)
            print(f"Save the best model at epoch {epoch+1}")

100%|██████████| 2562/2562 [38:31<00:00,  1.11it/s] 


Epoch 1/100, Training Loss: 84.3499, Validation Loss: 118.2162
Save the best model at epoch 1


100%|██████████| 2562/2562 [37:06<00:00,  1.15it/s]


Epoch 2/100, Training Loss: 72.1961, Validation Loss: 133.4166


100%|██████████| 2562/2562 [35:26<00:00,  1.20it/s]


Epoch 3/100, Training Loss: 57.2132, Validation Loss: 92.0953
Save the best model at epoch 3


100%|██████████| 2562/2562 [34:24<00:00,  1.24it/s]


Epoch 4/100, Training Loss: 40.4890, Validation Loss: 83.8644
Save the best model at epoch 4


100%|██████████| 2562/2562 [39:29<00:00,  1.08it/s] 


Epoch 5/100, Training Loss: 30.1686, Validation Loss: 59.8693
Save the best model at epoch 5


100%|██████████| 2562/2562 [1:50:56<00:00,  2.60s/it]     


Epoch 6/100, Training Loss: 24.0600, Validation Loss: 89.2014


100%|██████████| 2562/2562 [33:49<00:00,  1.26it/s]


Epoch 7/100, Training Loss: 20.8710, Validation Loss: 35.7950
Save the best model at epoch 7


100%|██████████| 2562/2562 [34:07<00:00,  1.25it/s]


Epoch 8/100, Training Loss: 19.6122, Validation Loss: 82.6269


100%|██████████| 2562/2562 [34:25<00:00,  1.24it/s]


Epoch 9/100, Training Loss: 18.8339, Validation Loss: 67.3956


100%|██████████| 2562/2562 [34:23<00:00,  1.24it/s]


Epoch 10/100, Training Loss: 18.2112, Validation Loss: 71.8867


100%|██████████| 2562/2562 [34:12<00:00,  1.25it/s]


Epoch 11/100, Training Loss: 17.6261, Validation Loss: 47.8603


100%|██████████| 2562/2562 [33:44<00:00,  1.27it/s]


Epoch 12/100, Training Loss: 17.0283, Validation Loss: 34.0973
Save the best model at epoch 12


100%|██████████| 2562/2562 [34:10<00:00,  1.25it/s]


Epoch 13/100, Training Loss: 16.7674, Validation Loss: 50.7085


100%|██████████| 2562/2562 [34:48<00:00,  1.23it/s]


Epoch 14/100, Training Loss: 16.4371, Validation Loss: 43.8942


100%|██████████| 2562/2562 [34:37<00:00,  1.23it/s]


Epoch 15/100, Training Loss: 16.1615, Validation Loss: 37.9589


100%|██████████| 2562/2562 [34:29<00:00,  1.24it/s]


Epoch 16/100, Training Loss: 15.7449, Validation Loss: 45.7286


100%|██████████| 2562/2562 [34:19<00:00,  1.24it/s]


Epoch 17/100, Training Loss: 15.6610, Validation Loss: 47.9254


100%|██████████| 2562/2562 [34:13<00:00,  1.25it/s]


Epoch 18/100, Training Loss: 15.2748, Validation Loss: 39.9649


100%|██████████| 2562/2562 [34:00<00:00,  1.26it/s]


Epoch 19/100, Training Loss: 15.2563, Validation Loss: 38.1343


100%|██████████| 2562/2562 [33:53<00:00,  1.26it/s]


Epoch 20/100, Training Loss: 14.9470, Validation Loss: 56.6695


100%|██████████| 2562/2562 [34:10<00:00,  1.25it/s]


Epoch 21/100, Training Loss: 14.7973, Validation Loss: 51.4525


100%|██████████| 2562/2562 [34:29<00:00,  1.24it/s]


Epoch 22/100, Training Loss: 14.6412, Validation Loss: 37.8790


100%|██████████| 2562/2562 [34:02<00:00,  1.25it/s]


Epoch 23/100, Training Loss: 14.5304, Validation Loss: 45.0389


100%|██████████| 2562/2562 [34:12<00:00,  1.25it/s]


Epoch 24/100, Training Loss: 14.2392, Validation Loss: 47.7665


100%|██████████| 2562/2562 [34:06<00:00,  1.25it/s]


Epoch 25/100, Training Loss: 14.1232, Validation Loss: 49.5083


100%|██████████| 2562/2562 [34:07<00:00,  1.25it/s]


Epoch 26/100, Training Loss: 14.1156, Validation Loss: 40.0093


100%|██████████| 2562/2562 [1:15:30<00:00,  1.77s/it]  


Epoch 27/100, Training Loss: 13.8735, Validation Loss: 34.6778


100%|██████████| 2562/2562 [35:31<00:00,  1.20it/s]


Epoch 28/100, Training Loss: 13.6924, Validation Loss: 30.8716
Save the best model at epoch 28


100%|██████████| 2562/2562 [36:18<00:00,  1.18it/s]


Epoch 29/100, Training Loss: 13.6654, Validation Loss: 46.5630


100%|██████████| 2562/2562 [35:00<00:00,  1.22it/s]


Epoch 30/100, Training Loss: 13.6129, Validation Loss: 39.5386


100%|██████████| 2562/2562 [35:58<00:00,  1.19it/s]


Epoch 31/100, Training Loss: 13.3589, Validation Loss: 43.6015


100%|██████████| 2562/2562 [35:23<00:00,  1.21it/s]


Epoch 32/100, Training Loss: 13.4133, Validation Loss: 40.4149


100%|██████████| 2562/2562 [35:40<00:00,  1.20it/s]


Epoch 33/100, Training Loss: 13.1808, Validation Loss: 38.0693


100%|██████████| 2562/2562 [37:01<00:00,  1.15it/s]


Epoch 34/100, Training Loss: 13.4672, Validation Loss: 43.9062


100%|██████████| 2562/2562 [38:25<00:00,  1.11it/s]


Epoch 35/100, Training Loss: 13.1662, Validation Loss: 42.0393


100%|██████████| 2562/2562 [38:50<00:00,  1.10it/s]


Epoch 36/100, Training Loss: 12.9831, Validation Loss: 50.0595


100%|██████████| 2562/2562 [36:22<00:00,  1.17it/s]


Epoch 37/100, Training Loss: 12.9189, Validation Loss: 32.1560


100%|██████████| 2562/2562 [1:04:05<00:00,  1.50s/it] 


Epoch 38/100, Training Loss: 12.8622, Validation Loss: 46.0518


100%|██████████| 2562/2562 [2:45:52<00:00,  3.88s/it]     


Epoch 39/100, Training Loss: 12.7521, Validation Loss: 39.4328


100%|██████████| 2562/2562 [2:37:21<00:00,  3.69s/it]     


Epoch 40/100, Training Loss: 12.7487, Validation Loss: 42.7607


100%|██████████| 2562/2562 [2:04:05<00:00,  2.91s/it]    


Epoch 41/100, Training Loss: 12.6906, Validation Loss: 40.4115


100%|██████████| 2562/2562 [3:02:22<00:00,  4.27s/it]    


Epoch 42/100, Training Loss: 12.5558, Validation Loss: 40.5704


100%|██████████| 2562/2562 [3:05:34<00:00,  4.35s/it]     


Epoch 43/100, Training Loss: 12.4821, Validation Loss: 38.4468


100%|██████████| 2562/2562 [33:01<00:00,  1.29it/s]


Epoch 44/100, Training Loss: 12.5049, Validation Loss: 41.3121


100%|██████████| 2562/2562 [34:01<00:00,  1.25it/s]


Epoch 45/100, Training Loss: 12.4312, Validation Loss: 35.1496


100%|██████████| 2562/2562 [34:34<00:00,  1.23it/s]


Epoch 46/100, Training Loss: 12.2788, Validation Loss: 40.1089


100%|██████████| 2562/2562 [34:33<00:00,  1.24it/s]


Epoch 47/100, Training Loss: 12.1918, Validation Loss: 37.7654


100%|██████████| 2562/2562 [34:01<00:00,  1.26it/s]


Epoch 48/100, Training Loss: 12.2281, Validation Loss: 34.4510


100%|██████████| 2562/2562 [34:55<00:00,  1.22it/s]


Epoch 49/100, Training Loss: 12.1507, Validation Loss: 41.8935


100%|██████████| 2562/2562 [35:00<00:00,  1.22it/s]


Epoch 50/100, Training Loss: 12.0280, Validation Loss: 40.5827


100%|██████████| 2562/2562 [35:19<00:00,  1.21it/s]


Epoch 51/100, Training Loss: 11.9421, Validation Loss: 40.7425


100%|██████████| 2562/2562 [35:24<00:00,  1.21it/s]


Epoch 52/100, Training Loss: 11.9439, Validation Loss: 35.4958


100%|██████████| 2562/2562 [35:40<00:00,  1.20it/s]


Epoch 53/100, Training Loss: 11.9891, Validation Loss: 38.5287


100%|██████████| 2562/2562 [35:59<00:00,  1.19it/s]


Epoch 54/100, Training Loss: 11.8715, Validation Loss: 34.1053


100%|██████████| 2562/2562 [7:44:50<00:00, 10.89s/it]     


Epoch 55/100, Training Loss: 11.8793, Validation Loss: 40.1893


100%|██████████| 2562/2562 [32:40<00:00,  1.31it/s]


Epoch 56/100, Training Loss: 11.7221, Validation Loss: 41.0052


100%|██████████| 2562/2562 [33:06<00:00,  1.29it/s]


Epoch 57/100, Training Loss: 11.7074, Validation Loss: 33.6951


100%|██████████| 2562/2562 [32:26<00:00,  1.32it/s]


Epoch 58/100, Training Loss: 11.5768, Validation Loss: 35.2859


100%|██████████| 2562/2562 [37:06<00:00,  1.15it/s]


Epoch 59/100, Training Loss: 11.5906, Validation Loss: 36.2005


100%|██████████| 2562/2562 [44:51<00:00,  1.05s/it] 


Epoch 60/100, Training Loss: 11.5780, Validation Loss: 34.9215


100%|██████████| 2562/2562 [36:05<00:00,  1.18it/s]


Epoch 61/100, Training Loss: 11.5360, Validation Loss: 45.1561


100%|██████████| 2562/2562 [35:14<00:00,  1.21it/s]


Epoch 62/100, Training Loss: 11.4430, Validation Loss: 43.4283


100%|██████████| 2562/2562 [34:56<00:00,  1.22it/s]


Epoch 63/100, Training Loss: 11.3548, Validation Loss: 38.4152


100%|██████████| 2562/2562 [34:48<00:00,  1.23it/s]


Epoch 64/100, Training Loss: 11.3924, Validation Loss: 27.6226
Save the best model at epoch 64


100%|██████████| 2562/2562 [34:43<00:00,  1.23it/s]


Epoch 65/100, Training Loss: 11.4316, Validation Loss: 46.5950


100%|██████████| 2562/2562 [34:33<00:00,  1.24it/s]


Epoch 66/100, Training Loss: 11.1835, Validation Loss: 34.0721


100%|██████████| 2562/2562 [34:23<00:00,  1.24it/s]


Epoch 67/100, Training Loss: 11.2409, Validation Loss: 44.9724


100%|██████████| 2562/2562 [34:12<00:00,  1.25it/s]


Epoch 68/100, Training Loss: 11.1634, Validation Loss: 34.1908


100%|██████████| 2562/2562 [34:04<00:00,  1.25it/s]


Epoch 69/100, Training Loss: 11.1623, Validation Loss: 42.0777


100%|██████████| 2562/2562 [33:50<00:00,  1.26it/s]


Epoch 70/100, Training Loss: 11.1093, Validation Loss: 36.7891


100%|██████████| 2562/2562 [33:45<00:00,  1.26it/s]


Epoch 71/100, Training Loss: 10.9420, Validation Loss: 37.0275


100%|██████████| 2562/2562 [33:37<00:00,  1.27it/s]


Epoch 72/100, Training Loss: 11.0140, Validation Loss: 35.8603


100%|██████████| 2562/2562 [33:48<00:00,  1.26it/s]


Epoch 73/100, Training Loss: 10.9956, Validation Loss: 35.3214


100%|██████████| 2562/2562 [33:45<00:00,  1.27it/s]


Epoch 74/100, Training Loss: 10.9667, Validation Loss: 35.6557


100%|██████████| 2562/2562 [33:41<00:00,  1.27it/s]


Epoch 75/100, Training Loss: 10.8370, Validation Loss: 31.7134


100%|██████████| 2562/2562 [34:10<00:00,  1.25it/s]


Epoch 76/100, Training Loss: 10.7959, Validation Loss: 33.3634


100%|██████████| 2562/2562 [34:27<00:00,  1.24it/s]


Epoch 77/100, Training Loss: 10.8310, Validation Loss: 44.4679


100%|██████████| 2562/2562 [34:57<00:00,  1.22it/s]


Epoch 78/100, Training Loss: 10.6609, Validation Loss: 36.8477


100%|██████████| 2562/2562 [34:59<00:00,  1.22it/s]


Epoch 79/100, Training Loss: 10.6336, Validation Loss: 39.0540


100%|██████████| 2562/2562 [35:16<00:00,  1.21it/s]


Epoch 80/100, Training Loss: 10.6979, Validation Loss: 30.9144


100%|██████████| 2562/2562 [35:31<00:00,  1.20it/s]


Epoch 81/100, Training Loss: 10.5094, Validation Loss: 38.1976


100%|██████████| 2562/2562 [35:58<00:00,  1.19it/s]


Epoch 82/100, Training Loss: 10.6187, Validation Loss: 37.1008


100%|██████████| 2562/2562 [36:22<00:00,  1.17it/s]


Epoch 83/100, Training Loss: 10.5748, Validation Loss: 28.3630


100%|██████████| 2562/2562 [36:43<00:00,  1.16it/s]


Epoch 84/100, Training Loss: 10.4619, Validation Loss: 38.0519


100%|██████████| 2562/2562 [36:35<00:00,  1.17it/s]


Epoch 85/100, Training Loss: 10.4856, Validation Loss: 28.7691


100%|██████████| 2562/2562 [35:55<00:00,  1.19it/s]


Epoch 86/100, Training Loss: 10.4046, Validation Loss: 36.5708


100%|██████████| 2562/2562 [35:48<00:00,  1.19it/s]


Epoch 87/100, Training Loss: 10.3606, Validation Loss: 32.9915


100%|██████████| 2562/2562 [35:44<00:00,  1.19it/s]


Epoch 88/100, Training Loss: 10.3565, Validation Loss: 37.2388


100%|██████████| 2562/2562 [35:44<00:00,  1.19it/s]


Epoch 89/100, Training Loss: 10.3616, Validation Loss: 36.2984


100%|██████████| 2562/2562 [35:28<00:00,  1.20it/s]


Epoch 90/100, Training Loss: 10.2488, Validation Loss: 29.4794


100%|██████████| 2562/2562 [35:42<00:00,  1.20it/s]


Epoch 91/100, Training Loss: 10.1865, Validation Loss: 30.0110


100%|██████████| 2562/2562 [35:54<00:00,  1.19it/s]


Epoch 92/100, Training Loss: 10.1977, Validation Loss: 39.1411


100%|██████████| 2562/2562 [36:07<00:00,  1.18it/s]


Epoch 93/100, Training Loss: 10.1227, Validation Loss: 30.9168


100%|██████████| 2562/2562 [36:57<00:00,  1.16it/s]


Epoch 94/100, Training Loss: 10.1736, Validation Loss: 33.3626


100%|██████████| 2562/2562 [35:36<00:00,  1.20it/s]


Epoch 95/100, Training Loss: 10.0331, Validation Loss: 35.9251


100%|██████████| 2562/2562 [35:45<00:00,  1.19it/s]


Epoch 96/100, Training Loss: 9.8845, Validation Loss: 33.3112


100%|██████████| 2562/2562 [35:49<00:00,  1.19it/s]


Epoch 97/100, Training Loss: 9.9549, Validation Loss: 39.6084


100%|██████████| 2562/2562 [35:46<00:00,  1.19it/s]


Epoch 98/100, Training Loss: 9.9182, Validation Loss: 37.2306


100%|██████████| 2562/2562 [35:11<00:00,  1.21it/s]


Epoch 99/100, Training Loss: 9.9461, Validation Loss: 35.8994


100%|██████████| 2562/2562 [34:54<00:00,  1.22it/s]


Epoch 100/100, Training Loss: 9.7846, Validation Loss: 38.7947
