In [1]:
!pip install vtk



In [2]:
import numpy as np
import torch
from torch import nn, optim
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
import seaborn as sns
import matplotlib.pyplot as plt
import vtk
from vtk import *
from vtk.util.numpy_support import vtk_to_numpy
import random
import os
import sys
import time

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Device running:', device)

Device running: cuda


In [4]:
class SineLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        # self.enable_dropout = enable_dropout
        # self.dropout_prob = dropout_prob
        self.in_features = in_features
        # if enable_dropout:
        #     if not self.is_first:
        #         self.dropout = nn.Dropout(dropout_prob)
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features,
                                             1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                             np.sqrt(6 / self.in_features) / self.omega_0)


    def forward(self, x):
        x = self.linear(x)
        # if self.enable_dropout:
        #     if not self.is_first:
        #         x = self.dropout(x)
        return torch.sin(self.omega_0 * x)

In [5]:
class ResidualSineLayer(nn.Module):
    def __init__(self, features, bias=True, ave_first=False, ave_second=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        # self.enable_dropout = enable_dropout
        # self.dropout_prob = dropout_prob
        self.features = features
        # if enable_dropout:
        #     self.dropout_1 = nn.Dropout(dropout_prob)
        self.linear_1 = nn.Linear(features, features, bias=bias)
        self.linear_2 = nn.Linear(features, features, bias=bias)
        self.weight_1 = .5 if ave_first else 1
        self.weight_2 = .5 if ave_second else 1

        self.init_weights()


    def init_weights(self):
        with torch.no_grad():
            self.linear_1.weight.uniform_(-np.sqrt(6 / self.features) / self.omega_0,
                                           np.sqrt(6 / self.features) / self.omega_0)
            self.linear_2.weight.uniform_(-np.sqrt(6 / self.features) / self.omega_0,
                                           np.sqrt(6 / self.features) / self.omega_0)

    def forward(self, input):
        linear_1 = self.linear_1(self.weight_1*input)
        # if self.enable_dropout:
        #     linear_1 = self.dropout_1(linear_1)
        sine_1 = torch.sin(self.omega_0 * linear_1)
        sine_2 = torch.sin(self.omega_0 * self.linear_2(sine_1))
        return self.weight_2*(input+sine_2)

In [6]:
class MyResidualSirenNet(nn.Module):
    def __init__(self, obj):
        super(MyResidualSirenNet, self).__init__()
        # self.enable_dropout = obj['enable_dropout']
        # self.dropout_prob = obj['dropout_prob']
        self.Omega_0=30
        self.n_layers = obj['n_layers']
        self.input_dim = obj['dim']
        self.output_dim = obj['total_vars']
        self.neurons_per_layer = obj['n_neurons']
        self.layers = [self.input_dim]
        for i in range(self.n_layers-1):
            self.layers.append(self.neurons_per_layer)
        self.layers.append(self.output_dim)
        self.net_layers = nn.ModuleList()
        for idx in np.arange(self.n_layers):
            layer_in = self.layers[idx]
            layer_out = self.layers[idx+1]
            ## if not the final layer
            if idx != self.n_layers-1:
                ## if first layer
                if idx==0:
                    self.net_layers.append(SineLayer(layer_in,layer_out,bias=True,is_first=idx==0))
                ## if an intermdeiate layer
                else:
                    self.net_layers.append(ResidualSineLayer(layer_in,bias=True,ave_first=idx>1,ave_second=idx==(self.n_layers-2)))
            ## if final layer
            else:
                final_linear = nn.Linear(layer_in,layer_out)
                ## initialize weights for the final layer
                with torch.no_grad():
                    final_linear.weight.uniform_(-np.sqrt(6 / (layer_in)) / self.Omega_0, np.sqrt(6 / (layer_in)) / self.Omega_0)
                self.net_layers.append(final_linear)

    def forward(self,x):
        for net_layer in self.net_layers:
            x = net_layer(x)
        return x

In [7]:
def size_of_network(n_layers, n_neurons, d_in, d_out, is_residual = True):
    # Adding input layer
    layers = [d_in]
    # layers = [3]

    # Adding hidden layers
    layers.extend([n_neurons]*n_layers)
    # layers = [3, 5, 5, 5]

    # Adding output layer
    layers.append(d_out)
    # layers = [3, 5, 5, 5, 1]

    # Number of steps
    n_layers = len(layers)-1
    # n_layers = 5 - 1 = 4

    n_params = 0

    # np.arange(4) = [0, 1, 2, 3]
    for ndx in np.arange(n_layers):

        # number of neurons in below layer
        layer_in = layers[ndx]

        # number of neurons in above layer
        layer_out = layers[ndx+1]

        # max number of neurons in both the layer
        og_layer_in = max(layer_in,layer_out)

        # if lower layer is the input layer
        # or the upper layer is the output layer
        if ndx==0 or ndx==(n_layers-1):
            # Adding weight corresponding to every neuron for every input neuron
            # Adding bias for every neuron in the upper layer
            n_params += ((layer_in+1)*layer_out)

        else:

            # If the layer is residual then proceed as follows as there will be more weights if residual layer is included
            if is_residual:
                # doubt in the following two lines
                n_params += (layer_in*og_layer_in)+og_layer_in
                n_params += (og_layer_in*layer_out)+layer_out

            # if the layer is non residual then simply add number of weights and biases as follows
            else:
                n_params += ((layer_in+1)*layer_out)
            #
        #
    #

    return n_params

In [8]:
def compute_PSNR(arrgt,arr_recon):
    diff = arrgt - arr_recon
    sqd_max_diff = (np.max(arrgt)-np.min(arrgt))**2
    snr = 10*np.log10(sqd_max_diff/np.mean(diff**2))
    return snr

In [9]:
def srs(numOfPoints, valid_pts, percentage, isMaskPresent, mask_array):

    # getting total number of sampled points
    numberOfSampledPoints = int((valid_pts/100) * percentage)

    # storing corner indices in indices variable
    indices = set()

    # As long as we don't get the required amount of sample points keep finding the random numbers
    while(len(indices) < numberOfSampledPoints):
        rp = random.randint(0, numOfPoints-1)
        if isMaskPresent and mask_array[rp] == 0:
            continue
        indices.add(rp)

    # return indices
    return indices

In [10]:
def findMultiVariatePSNR(var_name, total_vars, actual, pred):
    # print('Printing PSNR')
    tot = 0
    psnr_list = []
    for j in range(total_vars):
        psnr = compute_PSNR(actual[:,j], pred[:,j])
        psnr_list.append(psnr)
        tot += psnr
        print(var_name[j], ' PSNR:', psnr)
    avg_psnr = tot/total_vars
    print('\nAverage psnr : ', avg_psnr)
     #this function is calculating the psnr of final epoch (or whenever it is called) of each variable and then averaging it
     #Thus individual epochs psnr is not calculated

    return psnr_list, avg_psnr

In [11]:
def compute_rmse(actual, predicted):
    mse = np.mean((actual - predicted) ** 2)
    return np.sqrt(mse)

def denormalizeValue(total_vars, to, ref):
    to_arr = np.array(to)
    for i in range(total_vars):
        min_data = np.min(ref[:, i])
        max_data = np.max(ref[:, i])
        to_arr[:, i] = (((to[:, i] * 0.5) + 0.5) * (max_data - min_data)) + min_data
    return to_arr

In [12]:
def makeVTI(data, val, n_predictions, n_pts, total_vars, var_name, dim, isMaskPresent, mask_arr, vti_path, vti_name, normalizedVersion = False):
    nn_predictions = denormalizeValue(total_vars, n_predictions, val) if not normalizedVersion else n_predictions
    writer = vtkXMLImageDataWriter()
    writer.SetFileName(vti_path + vti_name)
    img = vtkImageData()
    img.CopyStructure(data)
    if not isMaskPresent:
        for i in range(total_vars):
            f = var_name[i]
            temp = nn_predictions[:, i]
            arr = vtkFloatArray()
            for j in range(n_pts):
                arr.InsertNextValue(temp[j])
            arr.SetName(f)
            img.GetPointData().AddArray(arr)
        # print(img)
        writer.SetInputData(img)
        writer.Write()
        print(f'Vti File written successfully at {vti_path}{vti_name}')
    else:
        for i in range(total_vars):
            f = var_name[i]
            temp = nn_predictions[:, i]
            idx = 0
            arr = vtkFloatArray()
            for j in range(n_pts):
                if(mask_arr[j] == 1):
                    arr.InsertNextValue(temp[idx])
                    idx += 1
                else:
                    arr.InsertNextValue(0.0)
            arr.SetName('p_' + f)
            data.GetPointData().AddArray(arr)
        # print(data)
        writer.SetInputData(data)
        writer.Write()
        print(f'Vti File written successfully at {vti_path}{vti_name}')

In [13]:
def getImageData(actual_img, val, n_pts, var_name, isMaskPresent, mask_arr):
    img = vtkImageData()
    img.CopyStructure(actual_img)
    # if isMaskPresent:
    #     img.DeepCopy(actual_img)
    # img.SetDimensions(dim)
    # img.SetOrigin(actual_img.GetOrigin())
    # img.SetSpacing(actual_img.GetSpacing())
    if not isMaskPresent:
        f = var_name
        data = val
        arr = vtkFloatArray()
        for j in range(n_pts):
            arr.InsertNextValue(data[j])
        arr.SetName(f)
        img.GetPointData().SetScalars(arr)
    else:
        f = var_name
        data = val
        idx = 0
        arr = vtkFloatArray()
        for j in range(n_pts):
            if(mask_arr[j] == 1):
                arr.InsertNextValue(data[idx])
                idx += 1
            else:
                arr.InsertNextValue(0.0)
        arr.SetName(f)
        img.GetPointData().SetScalars(arr)
    return img

In [26]:
from argparse import Namespace

# Parameters (simulating argparse in a Jupyter Notebook)
args = Namespace(
    n_neurons=150,
    n_layers=6,
    epochs=400,  # Required argument: Set the number of epochs
    batchsize=512,
    lr=0.00005,
    no_decay=False,
    decay_rate=0.8,
    decay_at_interval=True,
    decay_interval=20,
    datapath='/content/Teardrop_Gaussian.vti',  # Required: Set the path to your data
    outpath='./models/',
    exp_path='../logs/',
    modified_data_path='./data/',
    dataset_name='3d_data',  # Required: Set the dataset name
    vti_name='predicted_vti',  # Required: Name of the dataset
    vti_path='./data/pred.vti'
)

print(args, end='\n\n')

# Assigning parameters to variables
LR = args.lr
BATCH_SIZE = args.batchsize
decay_rate = args.decay_rate
decay_at_equal_interval = args.decay_at_interval

decay = not args.no_decay
MAX_EPOCH = args.epochs

n_neurons = args.n_neurons
n_layers = args.n_layers + 2
decay_interval = args.decay_interval
outpath = args.outpath
exp_path = args.exp_path
datapath = args.datapath
modified_data_path = args.modified_data_path
dataset_name = args.dataset_name
vti_name = args.vti_name
vti_path = args.vti_path

# Displaying the final configuration
print(f"Learning Rate: {LR}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Decay Rate: {decay_rate}")
print(f"Max Epochs: {MAX_EPOCH}")
print(f"Number of Neurons per Layer: {n_neurons}")
print(f"Number of Layers (including input/output): {n_layers}")
print(f"Data Path: {datapath}")
print(f"Output Path: {outpath}")
print(f"Dataset Name: {dataset_name}")
print(f"Vti Name: {vti_name}")

Namespace(n_neurons=150, n_layers=6, epochs=400, batchsize=512, lr=5e-05, no_decay=False, decay_rate=0.8, decay_at_interval=True, decay_interval=20, datapath='/content/Teardrop_Gaussian.vti', outpath='./models/', exp_path='../logs/', modified_data_path='./data/', dataset_name='3d_data', vti_name='predicted_vti', vti_path='./data/pred.vti')

Learning Rate: 5e-05
Batch Size: 512
Decay Rate: 0.8
Max Epochs: 400
Number of Neurons per Layer: 150
Number of Layers (including input/output): 8
Data Path: /content/Teardrop_Gaussian.vti
Output Path: ./models/
Dataset Name: 3d_data
Vti Name: predicted_vti


In [27]:
# Variable Initialization
var_name = []
total_vars = None  # Number of variables
univariate = None  # True if dataset has one variable, else False
group_size = 5000  # Group size during testing


# Constructing the log file name
log_file = (
    f'train_{dataset_name}_{n_layers-2}rb_{n_neurons}n_{BATCH_SIZE}bs_'
    f'{LR}lr_{decay}decay_{decay_rate}dr_'
    f'{"decayingAtInterval" + str(decay_interval) if decay_at_equal_interval else "decayingWhenLossIncr"}'
)

print(log_file)

train_3d_data_6rb_150n_512bs_5e-05lr_Truedecay_0.8dr_decayingAtInterval20


In [28]:
n_pts = None  # Number of points in the dataset
n_dim = None  # Dimensionality of the data
dim = None  # Other dimension-specific information

print("Decay:", decay)
print(f'Extracting variables from path: {datapath}', end="\n\n")

# Placeholder for data
data_array = []
scalar_data = None

Decay: True
Extracting variables from path: /content/Teardrop_Gaussian.vti



In [29]:

# # Reading values from .vti files
# reader = vtk.vtkXMLImageDataReader()
# reader.SetFileName(datapath)
# reader.Update()

# data = reader.GetOutput()
# scalar_data = data
# pdata = data.GetPointData()
# n_pts = data.GetNumberOfPoints()
# dim = data.GetDimensions()
# n_dim = len(dim)
# total_arr = pdata.GetNumberOfArrays()

# print("n_pts:", n_pts, "dim:", dim, "n_dim:", n_dim, "total_arr:", total_arr)

# mask_arr = []
# valid_pts = 0
# var_name = []
# data_array = []

# # Extracting data from the .vti file
# for i in range(total_arr):
#     a_name = pdata.GetArrayName(i)
#     if a_name in ['vtkValidPointMask', 'Swirl']:
#         continue

#     cur_arr = pdata.GetArray(a_name)
#     n_components = cur_arr.GetNumberOfComponents()

#     if n_components == 1:
#         var_name.append(a_name)
#         data_array.append(vtk_to_numpy(cur_arr))
#     else:
#         component_names = [f"{a_name}_{c}" for c in ['x', 'y', 'z'][:n_components]]
#         var_name.extend(component_names)
#         for c in range(n_components):
#             c_data = [cur_arr.GetComponent(j, c) for j in range(n_pts)]
#             data_array.append(np.array(c_data))

# valid_pts = n_pts  # Assume all points are valid for simplicity
# total_vars = len(var_name)
# univariate = total_vars == 1

# # Prepare numpy arrays for coordinates and variable values
# cord = np.zeros((valid_pts, n_dim))
# val = np.zeros((valid_pts, total_vars))

# # Store data in numpy arrays
# for i in range(n_pts):
#     pt = scalar_data.GetPoint(i)
#     cord[i, :] = pt
#     val[i, :] = [arr[i] for arr in data_array]

# # Display final information
# print("Total Variables:", total_vars)
# print("Univariate:", univariate)
# print("Coordinates Shape:", cord.shape)
# print("Values Shape:", val.shape)

# Reading values from .vti files
reader = vtk.vtkXMLImageDataReader()
reader.SetFileName(datapath)
reader.Update()

data = reader.GetOutput()
scalar_data = data
pdata = data.GetPointData()
n_pts = data.GetNumberOfPoints()
dim = data.GetDimensions()
n_dim = len(dim)
total_arr = pdata.GetNumberOfArrays()

print("n_pts:", n_pts, "dim:", dim, "n_dim:", n_dim, "total_arr:", total_arr)

var_name = []
data_array = []

# Extracting data from the .vti file
for i in range(total_arr):
    a_name = pdata.GetArrayName(i)

    cur_arr = pdata.GetArray(a_name)
    n_components = cur_arr.GetNumberOfComponents()

    if n_components == 1:
        var_name.append(a_name)
        data_array.append(vtk_to_numpy(cur_arr))
    else:
        component_names = [f"{a_name}_{c}" for c in ['x', 'y', 'z'][:n_components]]
        var_name.extend(component_names)
        for c in range(n_components):
            c_data = [cur_arr.GetComponent(j, c) for j in range(n_pts)]
            data_array.append(np.array(c_data))

total_vars = len(var_name)
univariate = total_vars == 1

# Prepare numpy arrays for coordinates and variable values
cord = np.zeros((n_pts, n_dim))
val = np.zeros((n_pts, total_vars))

# Store data in numpy arrays
for i in range(n_pts):
    pt = scalar_data.GetPoint(i)
    cord[i, :] = pt
    val[i, :] = [arr[i] for arr in data_array]

# Display final information
print("Total Variables:", total_vars)
print("Univariate:", univariate)
print("Coordinates Shape:", cord.shape)
print("Values Shape:", val.shape)

n_pts: 262144 dim: (64, 64, 64) n_dim: 3 total_arr: 2
Total Variables: 2
Univariate: False
Coordinates Shape: (262144, 3)
Values Shape: (262144, 2)


In [30]:

# # Ensure modified data path exists
# if not os.path.exists(modified_data_path):
#     os.mkdir(modified_data_path)

# Save raw coordinates and values
# np.save(f'{modified_data_path}cord.npy', cord)
# np.save(f'{modified_data_path}val.npy', val)

# # Create copies of non-normalized data
# nn_cord = cord.copy()
# nn_val = val.copy()

# === Separate Normalization for Values ===
# We assume the variable order is:
#   - Means: indices 0,1,2
#   - Std Devs: indices 3,4,5
#   - Weights: indices 6,7,8

# # We'll store normalization parameters so that we can invert normalization later.
# norm_params = {}
# epsilon = 1e-8  # to avoid log(0)

# # Normalize Means to [-1,1] using min–max normalization
# for i in range(3):
#     min_val = np.min(val[:, i])
#     max_val = np.max(val[:, i])
#     norm_params[var_name[i]] = (min_val, max_val)
#     val[:, i] = 2.0 * ((val[:, i] - min_val) / (max_val - min_val) - 0.5)

# # Normalize Std Devs: first take log, then min–max to [-1,1]
# for i in range(3, 6):
#     log_vals = np.log(val[:, i] + epsilon)
#     min_val = np.min(log_vals)
#     max_val = np.max(log_vals)
#     norm_params[var_name[i]] = (min_val, max_val)
#     val[:, i] = 2.0 * ((log_vals - min_val) / (max_val - min_val) - 0.5)

# # Normalize Weights: take log, then min–max to [-1,1]
# for i in range(6, 9):
#     log_vals = np.log(val[:, i] + epsilon)
#     min_val = np.min(log_vals)
#     max_val = np.max(log_vals)
#     norm_params[var_name[i]] = (min_val, max_val)
#     val[:, i] = 2.0 * ((log_vals - min_val) / (max_val - min_val) - 0.5)

# norm_params = {}
real_data=val.copy()
print(real_data)
# Normalize values between -1 and 1
for i in range(total_vars):
    min_data = np.min(val[:, i])
    max_data = np.max(val[:, i])
    # norm_params[var_name[i]] = (min_data, max_data)
    val[:, i] = 2.0 * ((val[:, i] - min_data) / (max_data - min_data) - 0.5)

# Normalize Coordinates to [-1,1]
for i in range(n_dim):
    # Use (dim[i]-1] so that coordinates go from 0 to dim[i]-1.
    cord[:, i] = 2.0 * (cord[:, i] / (dim[i] - 1) - 0.5)

# # Normalize coordinates between 0 and 1
# for i in range(n_dim):
#     cord[:, i] = cord[:, i] / dim[i]


# # Save normalized values and coordinates
# np.save(f'{modified_data_path}n_cord.npy', cord)
# np.save(f'{modified_data_path}n_val.npy', val)
n_cord = cord.copy()
n_val = val.copy()

# # Reload data for verification
# n_cord = np.load(f'{modified_data_path}n_cord.npy')
# n_val = np.load(f'{modified_data_path}n_val.npy')
# cord = np.load(f'{modified_data_path}cord.npy')
# val = np.load(f'{modified_data_path}val.npy')

# Convert normalized data to PyTorch tensors
torch_coords = torch.from_numpy(n_cord)
torch_vals = torch.from_numpy(n_val)

# Display dataset details
print('Dataset Name:', dataset_name)
print('Total Variables:', total_vars)
print('Variables Name:', var_name, end="\n\n")
print('Total Points in Data:', n_pts)
print('Dimension of the Dataset:', dim)
print('Number of Dimensions:', n_dim)
print('Coordinate Tensor Shape:', torch_coords.shape)
print('Scalar Values Tensor Shape:', torch_vals.shape)

print('\n###### Data setup is complete, now starting training ######\n')

[[7.71155238e-01 1.39814585e-01]
 [2.51629448e+00 6.46227449e-02]
 [4.82296085e+00 5.97186051e-02]
 ...
 [8.69416428e+01 4.36629914e-02]
 [8.45070496e+01 7.31490031e-02]
 [8.21956940e+01 1.30982995e-01]]
Dataset Name: 3d_data
Total Variables: 2
Variables Name: ['Average', 'Standard_Deviation']

Total Points in Data: 262144
Dimension of the Dataset: (64, 64, 64)
Number of Dimensions: 3
Coordinate Tensor Shape: torch.Size([262144, 3])
Scalar Values Tensor Shape: torch.Size([262144, 2])

###### Data setup is complete, now starting training ######



In [31]:
# Prepare the DataLoader
train_dataloader = DataLoader(
    TensorDataset(torch_coords, torch_vals),
    batch_size=BATCH_SIZE,
    pin_memory=True,
    shuffle=True,
    num_workers=4
)

# Model configuration
obj = {
    'total_vars': total_vars,
    'dim': n_dim,
    'n_neurons': n_neurons,
    'n_layers': n_layers
}

# Initialize the model, optimizer, and loss function
model = MyResidualSirenNet(obj).to(device)
print(model)

optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999))
print(optimizer)

criterion = nn.MSELoss()
print(criterion)

# Training configuration summary
print('\nLearning Rate:', LR)
print('Max Epochs:', MAX_EPOCH)
print('Batch Size:', BATCH_SIZE)
print('Number of Hidden Layers:', obj['n_layers'] - 2)
print('Number of Neurons per Layer:', obj['n_neurons'])

if decay:
    print('Decay Rate:', decay_rate)
    if decay_at_equal_interval:
        print(f'Rate decays every {decay_interval} epochs.')
    else:
        print('Rate decays when the current epoch loss is greater than the previous epoch loss.')
else:
    print('No decay!')
print()

MyResidualSirenNet(
  (net_layers): ModuleList(
    (0): SineLayer(
      (linear): Linear(in_features=3, out_features=150, bias=True)
    )
    (1-6): 6 x ResidualSineLayer(
      (linear_1): Linear(in_features=150, out_features=150, bias=True)
      (linear_2): Linear(in_features=150, out_features=150, bias=True)
    )
    (7): Linear(in_features=150, out_features=2, bias=True)
  )
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 5e-05
    maximize: False
    weight_decay: 0
)
MSELoss()

Learning Rate: 5e-05
Max Epochs: 400
Batch Size: 512
Number of Hidden Layers: 6
Number of Neurons per Layer: 150
Decay Rate: 0.8
Rate decays every 20 epochs.



In [33]:
train_loss_list = []
best_epoch = -1
best_loss = 1e8
from tqdm import tqdm
# Ensure the output path exists
if not os.path.exists(outpath):
    os.makedirs(outpath)

# Training loop
for epoch in tqdm(range(MAX_EPOCH)):
    model.train()
    temp_loss_list = []
    start = time.time()

    # Batch-by-batch training
    for X_train, y_train in train_dataloader:
        X_train = X_train.type(torch.float32).to(device)
        y_train = y_train.type(torch.float32).to(device)

        if univariate:
            y_train = y_train.squeeze()

        optimizer.zero_grad()
        predictions = model(X_train)
        predictions = predictions.squeeze()
        loss = criterion(predictions, y_train)
        loss.backward()
        optimizer.step()

        # Track batch loss
        temp_loss_list.append(loss.detach().cpu().numpy())

    # Calculate epoch loss
    epoch_loss = np.average(temp_loss_list)

    # Learning rate decay
    if decay:
        if decay_at_equal_interval:
            if epoch >= decay_interval and epoch % decay_interval == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= decay_rate
        # else:
        #     if epoch > 0 and epoch_loss > train_loss_list[-1]:
        #         for param_group in optimizer.param_groups:
        #             param_group['lr'] *= decay_rate
        if epoch > 0 and epoch_loss > train_loss_list[-1]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= decay_rate

    # Track losses and best model
    train_loss_list.append(epoch_loss)
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_epoch = epoch+1

    end = time.time()
    print(
        f"Epoch: {epoch + 1}/{MAX_EPOCH} | Train Loss: {train_loss_list[-1]} | "
        f"Time: {round(end - start, 2)}s ({device}) | LR: {optimizer.param_groups[0]['lr']}"
    )

    # Save model at intervals
    if (epoch + 1) % 50 == 0:
        model_name = (
            f'train_{dataset_name}_{epoch + 1}ep_{n_layers - 2}rb_{n_neurons}n_'
            f'{BATCH_SIZE}bs_{LR}lr_{decay}decay_{decay_rate}dr_'
            f'{"decayingAtInterval" + str(decay_interval) if decay_at_equal_interval else "decayingWhenLossIncr"}'
        )
        torch.save(
            {"epoch": epoch + 1, "model_state_dict": model.state_dict()},
            os.path.join(outpath, f'{model_name}.pth')
        )

# Final summary
print('\nEpoch with Least Loss:', best_epoch, '| Loss:', best_loss, '\n')

# Save the final model
model_name = f'siren_compressor'
torch.save(
    {"epoch": MAX_EPOCH, "model_state_dict": model.state_dict()},
    os.path.join(outpath, f'{model_name}.pth')
)

  0%|          | 1/400 [00:04<30:52,  4.64s/it]

Epoch: 1/400 | Train Loss: 0.0015109451487660408 | Time: 4.64s (cuda) | LR: 5e-05


  0%|          | 2/400 [00:10<34:12,  5.16s/it]

Epoch: 2/400 | Train Loss: 0.001430923119187355 | Time: 5.52s (cuda) | LR: 5e-05


  1%|          | 3/400 [00:14<32:16,  4.88s/it]

Epoch: 3/400 | Train Loss: 0.0013596420176327229 | Time: 4.54s (cuda) | LR: 5e-05


  1%|          | 4/400 [00:19<31:02,  4.70s/it]

Epoch: 4/400 | Train Loss: 0.0012895341496914625 | Time: 4.43s (cuda) | LR: 5e-05


  1%|▏         | 5/400 [00:24<32:29,  4.94s/it]

Epoch: 5/400 | Train Loss: 0.0012349775061011314 | Time: 5.34s (cuda) | LR: 5e-05


  2%|▏         | 6/400 [00:28<31:17,  4.77s/it]

Epoch: 6/400 | Train Loss: 0.0011870970483869314 | Time: 4.44s (cuda) | LR: 5e-05


  2%|▏         | 7/400 [00:34<31:57,  4.88s/it]

Epoch: 7/400 | Train Loss: 0.001140365726314485 | Time: 5.11s (cuda) | LR: 5e-05


  2%|▏         | 8/400 [00:38<31:23,  4.81s/it]

Epoch: 8/400 | Train Loss: 0.0011093220673501492 | Time: 4.65s (cuda) | LR: 5e-05


  2%|▏         | 9/400 [00:43<30:41,  4.71s/it]

Epoch: 9/400 | Train Loss: 0.0010674785589799285 | Time: 4.49s (cuda) | LR: 5e-05


  2%|▎         | 10/400 [00:48<31:50,  4.90s/it]

Epoch: 10/400 | Train Loss: 0.0010389942908659577 | Time: 5.32s (cuda) | LR: 5e-05


  3%|▎         | 11/400 [00:52<30:48,  4.75s/it]

Epoch: 11/400 | Train Loss: 0.000993955647572875 | Time: 4.41s (cuda) | LR: 5e-05


  3%|▎         | 12/400 [00:59<34:55,  5.40s/it]

Epoch: 12/400 | Train Loss: 0.0009818655671551824 | Time: 6.89s (cuda) | LR: 5e-05


  3%|▎         | 13/400 [01:04<32:59,  5.12s/it]

Epoch: 13/400 | Train Loss: 0.0009569481480866671 | Time: 4.46s (cuda) | LR: 5e-05


  4%|▎         | 14/400 [01:08<31:31,  4.90s/it]

Epoch: 14/400 | Train Loss: 0.0009353189379908144 | Time: 4.4s (cuda) | LR: 5e-05


  4%|▍         | 15/400 [01:14<32:31,  5.07s/it]

Epoch: 15/400 | Train Loss: 0.0009048109641298652 | Time: 5.46s (cuda) | LR: 5e-05


  4%|▍         | 16/400 [01:18<31:08,  4.87s/it]

Epoch: 16/400 | Train Loss: 0.0008955832454375923 | Time: 4.4s (cuda) | LR: 5e-05


  4%|▍         | 17/400 [01:23<30:56,  4.85s/it]

Epoch: 17/400 | Train Loss: 0.0008766036480665207 | Time: 4.81s (cuda) | LR: 5e-05


  4%|▍         | 18/400 [01:28<30:57,  4.86s/it]

Epoch: 18/400 | Train Loss: 0.00085640192264691 | Time: 4.89s (cuda) | LR: 5e-05


  5%|▍         | 19/400 [01:32<29:52,  4.70s/it]

Epoch: 19/400 | Train Loss: 0.0008555861422792077 | Time: 4.33s (cuda) | LR: 5e-05


  5%|▌         | 20/400 [01:37<31:01,  4.90s/it]

Epoch: 20/400 | Train Loss: 0.0008389002759940922 | Time: 5.35s (cuda) | LR: 5e-05


  5%|▌         | 21/400 [01:42<30:04,  4.76s/it]

Epoch: 21/400 | Train Loss: 0.000831748591735959 | Time: 4.44s (cuda) | LR: 4e-05


  6%|▌         | 22/400 [01:46<29:30,  4.68s/it]

Epoch: 22/400 | Train Loss: 0.0006646616966463625 | Time: 4.5s (cuda) | LR: 4e-05


  6%|▌         | 23/400 [01:52<30:46,  4.90s/it]

Epoch: 23/400 | Train Loss: 0.000616923556663096 | Time: 5.4s (cuda) | LR: 4e-05


  6%|▌         | 24/400 [01:56<29:43,  4.74s/it]

Epoch: 24/400 | Train Loss: 0.0006139766192063689 | Time: 4.38s (cuda) | LR: 4e-05


  6%|▋         | 25/400 [02:01<30:28,  4.88s/it]

Epoch: 25/400 | Train Loss: 0.0006168707623146474 | Time: 5.18s (cuda) | LR: 3.2000000000000005e-05


  6%|▋         | 26/400 [02:06<29:47,  4.78s/it]

Epoch: 26/400 | Train Loss: 0.0004946518456563354 | Time: 4.56s (cuda) | LR: 3.2000000000000005e-05


  7%|▋         | 27/400 [02:10<28:56,  4.66s/it]

Epoch: 27/400 | Train Loss: 0.0004573682672344148 | Time: 4.36s (cuda) | LR: 3.2000000000000005e-05


  7%|▋         | 28/400 [02:16<30:30,  4.92s/it]

Epoch: 28/400 | Train Loss: 0.00045074959052726626 | Time: 5.54s (cuda) | LR: 3.2000000000000005e-05


  7%|▋         | 29/400 [02:20<29:23,  4.75s/it]

Epoch: 29/400 | Train Loss: 0.00045675228466279805 | Time: 4.36s (cuda) | LR: 2.5600000000000006e-05


  8%|▊         | 30/400 [02:25<28:38,  4.65s/it]

Epoch: 30/400 | Train Loss: 0.00036327383713796735 | Time: 4.39s (cuda) | LR: 2.5600000000000006e-05


  8%|▊         | 31/400 [02:30<29:51,  4.86s/it]

Epoch: 31/400 | Train Loss: 0.00033373068436048925 | Time: 5.34s (cuda) | LR: 2.5600000000000006e-05


  8%|▊         | 32/400 [02:34<28:53,  4.71s/it]

Epoch: 32/400 | Train Loss: 0.00032952011679299176 | Time: 4.37s (cuda) | LR: 2.5600000000000006e-05


  8%|▊         | 33/400 [02:40<29:57,  4.90s/it]

Epoch: 33/400 | Train Loss: 0.00032696404377929866 | Time: 5.34s (cuda) | LR: 2.5600000000000006e-05


  8%|▊         | 34/400 [02:44<29:06,  4.77s/it]

Epoch: 34/400 | Train Loss: 0.0003294984344393015 | Time: 4.47s (cuda) | LR: 2.0480000000000007e-05


  9%|▉         | 35/400 [02:48<28:20,  4.66s/it]

Epoch: 35/400 | Train Loss: 0.0002583148016128689 | Time: 4.39s (cuda) | LR: 2.0480000000000007e-05


  9%|▉         | 36/400 [02:54<29:28,  4.86s/it]

Epoch: 36/400 | Train Loss: 0.00023537920787930489 | Time: 5.33s (cuda) | LR: 2.0480000000000007e-05


  9%|▉         | 37/400 [02:58<28:26,  4.70s/it]

Epoch: 37/400 | Train Loss: 0.00023374591546598822 | Time: 4.33s (cuda) | LR: 2.0480000000000007e-05


 10%|▉         | 38/400 [03:03<28:16,  4.69s/it]

Epoch: 38/400 | Train Loss: 0.00023726672225166112 | Time: 4.65s (cuda) | LR: 1.6384000000000008e-05


 10%|▉         | 39/400 [03:08<28:44,  4.78s/it]

Epoch: 39/400 | Train Loss: 0.00019308601622469723 | Time: 4.98s (cuda) | LR: 1.6384000000000008e-05


 10%|█         | 40/400 [03:12<27:57,  4.66s/it]

Epoch: 40/400 | Train Loss: 0.00017504199058748782 | Time: 4.38s (cuda) | LR: 1.6384000000000008e-05


 10%|█         | 41/400 [03:18<29:14,  4.89s/it]

Epoch: 41/400 | Train Loss: 0.0001773138646967709 | Time: 5.42s (cuda) | LR: 1.0485760000000006e-05


 10%|█         | 42/400 [03:22<28:12,  4.73s/it]

Epoch: 42/400 | Train Loss: 0.0001303690078202635 | Time: 4.35s (cuda) | LR: 1.0485760000000006e-05


 11%|█         | 43/400 [03:26<27:28,  4.62s/it]

Epoch: 43/400 | Train Loss: 0.00011573566735023633 | Time: 4.36s (cuda) | LR: 1.0485760000000006e-05


 11%|█         | 44/400 [03:32<28:36,  4.82s/it]

Epoch: 44/400 | Train Loss: 0.00011595742398640141 | Time: 5.29s (cuda) | LR: 8.388608000000005e-06


 11%|█▏        | 45/400 [03:36<27:55,  4.72s/it]

Epoch: 45/400 | Train Loss: 0.00010152315371669829 | Time: 4.48s (cuda) | LR: 8.388608000000005e-06


 12%|█▏        | 46/400 [03:41<28:25,  4.82s/it]

Epoch: 46/400 | Train Loss: 9.633637091610581e-05 | Time: 5.04s (cuda) | LR: 8.388608000000005e-06


 12%|█▏        | 47/400 [03:46<28:19,  4.81s/it]

Epoch: 47/400 | Train Loss: 9.548530215397477e-05 | Time: 4.8s (cuda) | LR: 8.388608000000005e-06


 12%|█▏        | 48/400 [03:50<27:37,  4.71s/it]

Epoch: 48/400 | Train Loss: 9.375203808303922e-05 | Time: 4.47s (cuda) | LR: 8.388608000000005e-06


 12%|█▏        | 49/400 [03:56<28:30,  4.87s/it]

Epoch: 49/400 | Train Loss: 9.187936666421592e-05 | Time: 5.25s (cuda) | LR: 8.388608000000005e-06


 12%|█▎        | 50/400 [04:00<27:26,  4.71s/it]

Epoch: 50/400 | Train Loss: 8.971226634457707e-05 | Time: 4.31s (cuda) | LR: 8.388608000000005e-06


 13%|█▎        | 51/400 [04:04<26:50,  4.61s/it]

Epoch: 51/400 | Train Loss: 8.843866817187518e-05 | Time: 4.4s (cuda) | LR: 8.388608000000005e-06


 13%|█▎        | 52/400 [04:10<27:57,  4.82s/it]

Epoch: 52/400 | Train Loss: 8.982128201751038e-05 | Time: 5.3s (cuda) | LR: 6.7108864000000044e-06


 13%|█▎        | 53/400 [04:14<27:07,  4.69s/it]

Epoch: 53/400 | Train Loss: 7.60410912334919e-05 | Time: 4.38s (cuda) | LR: 6.7108864000000044e-06


 14%|█▎        | 54/400 [04:19<28:12,  4.89s/it]

Epoch: 54/400 | Train Loss: 7.145704876165837e-05 | Time: 5.36s (cuda) | LR: 6.7108864000000044e-06


 14%|█▍        | 55/400 [04:24<27:12,  4.73s/it]

Epoch: 55/400 | Train Loss: 7.057008042465895e-05 | Time: 4.36s (cuda) | LR: 6.7108864000000044e-06


 14%|█▍        | 56/400 [04:28<26:25,  4.61s/it]

Epoch: 56/400 | Train Loss: 7.034775626379997e-05 | Time: 4.32s (cuda) | LR: 6.7108864000000044e-06


 14%|█▍        | 57/400 [04:33<27:27,  4.80s/it]

Epoch: 57/400 | Train Loss: 6.953949196031317e-05 | Time: 5.25s (cuda) | LR: 6.7108864000000044e-06


 14%|█▍        | 58/400 [04:38<26:41,  4.68s/it]

Epoch: 58/400 | Train Loss: 6.878306885482743e-05 | Time: 4.4s (cuda) | LR: 6.7108864000000044e-06


 15%|█▍        | 59/400 [04:42<26:20,  4.63s/it]

Epoch: 59/400 | Train Loss: 6.735972419846803e-05 | Time: 4.52s (cuda) | LR: 6.7108864000000044e-06


 15%|█▌        | 60/400 [04:48<27:18,  4.82s/it]

Epoch: 60/400 | Train Loss: 6.636633770540357e-05 | Time: 5.24s (cuda) | LR: 6.7108864000000044e-06


 15%|█▌        | 61/400 [04:52<26:32,  4.70s/it]

Epoch: 61/400 | Train Loss: 6.548961391672492e-05 | Time: 4.42s (cuda) | LR: 5.368709120000004e-06


 16%|█▌        | 62/400 [04:57<27:27,  4.87s/it]

Epoch: 62/400 | Train Loss: 5.750222044298425e-05 | Time: 5.28s (cuda) | LR: 5.368709120000004e-06


 16%|█▌        | 63/400 [05:02<26:33,  4.73s/it]

Epoch: 63/400 | Train Loss: 5.520549893844873e-05 | Time: 4.39s (cuda) | LR: 5.368709120000004e-06


 16%|█▌        | 64/400 [05:06<25:55,  4.63s/it]

Epoch: 64/400 | Train Loss: 5.482362030306831e-05 | Time: 4.4s (cuda) | LR: 5.368709120000004e-06


 16%|█▋        | 65/400 [05:11<26:48,  4.80s/it]

Epoch: 65/400 | Train Loss: 5.461931141326204e-05 | Time: 5.2s (cuda) | LR: 5.368709120000004e-06


 16%|█▋        | 66/400 [05:16<25:56,  4.66s/it]

Epoch: 66/400 | Train Loss: 5.39950342499651e-05 | Time: 4.33s (cuda) | LR: 5.368709120000004e-06


 17%|█▋        | 67/400 [05:20<26:12,  4.72s/it]

Epoch: 67/400 | Train Loss: 5.3423893405124545e-05 | Time: 4.86s (cuda) | LR: 5.368709120000004e-06


 17%|█▋        | 68/400 [05:25<26:08,  4.73s/it]

Epoch: 68/400 | Train Loss: 5.285254155751318e-05 | Time: 4.73s (cuda) | LR: 5.368709120000004e-06


 17%|█▋        | 69/400 [05:29<25:25,  4.61s/it]

Epoch: 69/400 | Train Loss: 5.246912769507617e-05 | Time: 4.33s (cuda) | LR: 5.368709120000004e-06


 18%|█▊        | 70/400 [05:35<26:21,  4.79s/it]

Epoch: 70/400 | Train Loss: 5.179506479180418e-05 | Time: 5.23s (cuda) | LR: 5.368709120000004e-06


 18%|█▊        | 71/400 [05:39<25:34,  4.66s/it]

Epoch: 71/400 | Train Loss: 5.1257207815069705e-05 | Time: 4.36s (cuda) | LR: 5.368709120000004e-06


 18%|█▊        | 72/400 [05:43<25:00,  4.57s/it]

Epoch: 72/400 | Train Loss: 5.053350469097495e-05 | Time: 4.37s (cuda) | LR: 5.368709120000004e-06


 18%|█▊        | 73/400 [05:49<26:12,  4.81s/it]

Epoch: 73/400 | Train Loss: 5.012039764551446e-05 | Time: 5.35s (cuda) | LR: 5.368709120000004e-06


 18%|█▊        | 74/400 [05:53<25:27,  4.68s/it]

Epoch: 74/400 | Train Loss: 4.955410986440256e-05 | Time: 4.4s (cuda) | LR: 5.368709120000004e-06


 19%|█▉        | 75/400 [05:58<25:48,  4.76s/it]

Epoch: 75/400 | Train Loss: 4.900166095467284e-05 | Time: 4.95s (cuda) | LR: 5.368709120000004e-06


 19%|█▉        | 76/400 [06:03<25:30,  4.72s/it]

Epoch: 76/400 | Train Loss: 4.854093276662752e-05 | Time: 4.62s (cuda) | LR: 5.368709120000004e-06


 19%|█▉        | 77/400 [06:07<24:49,  4.61s/it]

Epoch: 77/400 | Train Loss: 4.841479312744923e-05 | Time: 4.35s (cuda) | LR: 5.368709120000004e-06


 20%|█▉        | 78/400 [06:12<25:47,  4.80s/it]

Epoch: 78/400 | Train Loss: 4.758798604598269e-05 | Time: 5.25s (cuda) | LR: 5.368709120000004e-06


 20%|█▉        | 79/400 [06:17<24:57,  4.66s/it]

Epoch: 79/400 | Train Loss: 4.7154804633464664e-05 | Time: 4.33s (cuda) | LR: 5.368709120000004e-06


 20%|██        | 80/400 [06:21<24:34,  4.61s/it]

Epoch: 80/400 | Train Loss: 4.684032319346443e-05 | Time: 4.48s (cuda) | LR: 5.368709120000004e-06


 20%|██        | 81/400 [06:26<25:33,  4.81s/it]

Epoch: 81/400 | Train Loss: 4.6536806621588767e-05 | Time: 5.27s (cuda) | LR: 4.294967296000004e-06


 20%|██        | 82/400 [06:31<24:40,  4.66s/it]

Epoch: 82/400 | Train Loss: 4.1174083889927715e-05 | Time: 4.31s (cuda) | LR: 4.294967296000004e-06


 21%|██        | 83/400 [06:36<25:29,  4.82s/it]

Epoch: 83/400 | Train Loss: 3.959013338317163e-05 | Time: 5.21s (cuda) | LR: 4.294967296000004e-06


 21%|██        | 84/400 [06:41<25:02,  4.75s/it]

Epoch: 84/400 | Train Loss: 3.957794979214668e-05 | Time: 4.59s (cuda) | LR: 4.294967296000004e-06


 21%|██▏       | 85/400 [06:45<24:22,  4.64s/it]

Epoch: 85/400 | Train Loss: 3.9584276237292215e-05 | Time: 4.38s (cuda) | LR: 3.4359738368000033e-06


 22%|██▏       | 86/400 [06:50<25:33,  4.88s/it]

Epoch: 86/400 | Train Loss: 3.569632099242881e-05 | Time: 5.45s (cuda) | LR: 3.4359738368000033e-06


 22%|██▏       | 87/400 [06:55<24:40,  4.73s/it]

Epoch: 87/400 | Train Loss: 3.484329499769956e-05 | Time: 4.36s (cuda) | LR: 3.4359738368000033e-06


 22%|██▏       | 88/400 [06:59<24:01,  4.62s/it]

Epoch: 88/400 | Train Loss: 3.486318382783793e-05 | Time: 4.37s (cuda) | LR: 2.7487790694400027e-06


 22%|██▏       | 89/400 [07:04<25:03,  4.83s/it]

Epoch: 89/400 | Train Loss: 3.2126117730513215e-05 | Time: 5.33s (cuda) | LR: 2.7487790694400027e-06


 22%|██▎       | 90/400 [07:09<24:11,  4.68s/it]

Epoch: 90/400 | Train Loss: 3.159030165988952e-05 | Time: 4.33s (cuda) | LR: 2.7487790694400027e-06


 23%|██▎       | 91/400 [07:14<24:53,  4.83s/it]

Epoch: 91/400 | Train Loss: 3.149168333038688e-05 | Time: 5.18s (cuda) | LR: 2.7487790694400027e-06


 23%|██▎       | 92/400 [07:18<24:13,  4.72s/it]

Epoch: 92/400 | Train Loss: 3.136107625323348e-05 | Time: 4.46s (cuda) | LR: 2.7487790694400027e-06


 23%|██▎       | 93/400 [07:23<23:42,  4.63s/it]

Epoch: 93/400 | Train Loss: 3.1173171009868383e-05 | Time: 4.43s (cuda) | LR: 2.7487790694400027e-06


 24%|██▎       | 94/400 [07:28<24:37,  4.83s/it]

Epoch: 94/400 | Train Loss: 3.105166615569033e-05 | Time: 5.28s (cuda) | LR: 2.7487790694400027e-06


 24%|██▍       | 95/400 [07:32<23:45,  4.67s/it]

Epoch: 95/400 | Train Loss: 3.0950082873459905e-05 | Time: 4.31s (cuda) | LR: 2.7487790694400027e-06


 24%|██▍       | 96/400 [07:37<23:07,  4.56s/it]

Epoch: 96/400 | Train Loss: 3.0673534638481215e-05 | Time: 4.3s (cuda) | LR: 2.7487790694400027e-06


 24%|██▍       | 97/400 [07:43<25:06,  4.97s/it]

Epoch: 97/400 | Train Loss: 3.0507166229654104e-05 | Time: 5.93s (cuda) | LR: 2.7487790694400027e-06


 24%|██▍       | 98/400 [07:47<24:03,  4.78s/it]

Epoch: 98/400 | Train Loss: 3.03201286442345e-05 | Time: 4.32s (cuda) | LR: 2.7487790694400027e-06


 25%|██▍       | 99/400 [07:53<25:02,  4.99s/it]

Epoch: 99/400 | Train Loss: 3.0144630727590993e-05 | Time: 5.49s (cuda) | LR: 2.7487790694400027e-06


 25%|██▌       | 100/400 [07:57<24:02,  4.81s/it]

Epoch: 100/400 | Train Loss: 3.0035478630452417e-05 | Time: 4.37s (cuda) | LR: 2.7487790694400027e-06


 25%|██▌       | 101/400 [08:02<24:10,  4.85s/it]

Epoch: 101/400 | Train Loss: 2.9768390959361568e-05 | Time: 4.96s (cuda) | LR: 2.1990232555520023e-06


 26%|██▌       | 102/400 [08:07<25:09,  5.07s/it]

Epoch: 102/400 | Train Loss: 2.7722506274585612e-05 | Time: 5.56s (cuda) | LR: 2.1990232555520023e-06


 26%|██▌       | 103/400 [08:12<24:57,  5.04s/it]

Epoch: 103/400 | Train Loss: 2.7270521968603134e-05 | Time: 4.99s (cuda) | LR: 2.1990232555520023e-06


 26%|██▌       | 104/400 [08:18<25:40,  5.20s/it]

Epoch: 104/400 | Train Loss: 2.7206675440538675e-05 | Time: 5.58s (cuda) | LR: 2.1990232555520023e-06


 26%|██▋       | 105/400 [08:23<25:27,  5.18s/it]

Epoch: 105/400 | Train Loss: 2.7087877242593095e-05 | Time: 5.12s (cuda) | LR: 2.1990232555520023e-06


 26%|██▋       | 106/400 [08:28<25:04,  5.12s/it]

Epoch: 106/400 | Train Loss: 2.7021507776225917e-05 | Time: 4.97s (cuda) | LR: 2.1990232555520023e-06


 27%|██▋       | 107/400 [08:34<25:51,  5.30s/it]

Epoch: 107/400 | Train Loss: 2.6921834432869218e-05 | Time: 5.71s (cuda) | LR: 2.1990232555520023e-06


 27%|██▋       | 108/400 [08:38<24:51,  5.11s/it]

Epoch: 108/400 | Train Loss: 2.6753539714263752e-05 | Time: 4.67s (cuda) | LR: 2.1990232555520023e-06


 27%|██▋       | 109/400 [08:44<25:53,  5.34s/it]

Epoch: 109/400 | Train Loss: 2.661073267518077e-05 | Time: 5.87s (cuda) | LR: 2.1990232555520023e-06


 28%|██▊       | 110/400 [08:49<24:18,  5.03s/it]

Epoch: 110/400 | Train Loss: 2.647850851644762e-05 | Time: 4.31s (cuda) | LR: 2.1990232555520023e-06


 28%|██▊       | 111/400 [08:55<26:02,  5.41s/it]

Epoch: 111/400 | Train Loss: 2.6393165171612054e-05 | Time: 6.28s (cuda) | LR: 2.1990232555520023e-06


 28%|██▊       | 112/400 [09:00<24:52,  5.18s/it]

Epoch: 112/400 | Train Loss: 2.628159199957736e-05 | Time: 4.66s (cuda) | LR: 2.1990232555520023e-06


 28%|██▊       | 113/400 [09:05<24:41,  5.16s/it]

Epoch: 113/400 | Train Loss: 2.614855475258082e-05 | Time: 5.11s (cuda) | LR: 2.1990232555520023e-06


 28%|██▊       | 114/400 [09:10<25:06,  5.27s/it]

Epoch: 114/400 | Train Loss: 2.601681808300782e-05 | Time: 5.51s (cuda) | LR: 2.1990232555520023e-06


 29%|██▉       | 115/400 [09:15<24:57,  5.26s/it]

Epoch: 115/400 | Train Loss: 2.5891702534863725e-05 | Time: 5.23s (cuda) | LR: 2.1990232555520023e-06


 29%|██▉       | 116/400 [09:21<24:50,  5.25s/it]

Epoch: 116/400 | Train Loss: 2.5761106371646747e-05 | Time: 5.22s (cuda) | LR: 2.1990232555520023e-06


 29%|██▉       | 117/400 [09:25<24:03,  5.10s/it]

Epoch: 117/400 | Train Loss: 2.5621957320254296e-05 | Time: 4.76s (cuda) | LR: 2.1990232555520023e-06


 30%|██▉       | 118/400 [09:30<23:02,  4.90s/it]

Epoch: 118/400 | Train Loss: 2.5524288503220305e-05 | Time: 4.44s (cuda) | LR: 2.1990232555520023e-06


 30%|██▉       | 119/400 [09:36<24:24,  5.21s/it]

Epoch: 119/400 | Train Loss: 2.544107337598689e-05 | Time: 5.93s (cuda) | LR: 2.1990232555520023e-06


 30%|███       | 120/400 [09:40<23:09,  4.96s/it]

Epoch: 120/400 | Train Loss: 2.5327890398330055e-05 | Time: 4.38s (cuda) | LR: 2.1990232555520023e-06


 30%|███       | 121/400 [09:46<24:34,  5.29s/it]

Epoch: 121/400 | Train Loss: 2.5213968910975382e-05 | Time: 6.04s (cuda) | LR: 1.7592186044416019e-06


 30%|███       | 122/400 [09:51<23:15,  5.02s/it]

Epoch: 122/400 | Train Loss: 2.372755807300564e-05 | Time: 4.39s (cuda) | LR: 1.7592186044416019e-06


 31%|███       | 123/400 [09:55<22:53,  4.96s/it]

Epoch: 123/400 | Train Loss: 2.3390650312649086e-05 | Time: 4.82s (cuda) | LR: 1.7592186044416019e-06


 31%|███       | 124/400 [10:01<23:05,  5.02s/it]

Epoch: 124/400 | Train Loss: 2.3360720661003143e-05 | Time: 5.15s (cuda) | LR: 1.7592186044416019e-06


 31%|███▏      | 125/400 [10:05<22:01,  4.81s/it]

Epoch: 125/400 | Train Loss: 2.334826422156766e-05 | Time: 4.31s (cuda) | LR: 1.7592186044416019e-06


 32%|███▏      | 126/400 [10:11<23:30,  5.15s/it]

Epoch: 126/400 | Train Loss: 2.3237666027853265e-05 | Time: 5.94s (cuda) | LR: 1.7592186044416019e-06


 32%|███▏      | 127/400 [10:15<22:21,  4.91s/it]

Epoch: 127/400 | Train Loss: 2.3160038836067542e-05 | Time: 4.37s (cuda) | LR: 1.7592186044416019e-06


 32%|███▏      | 128/400 [10:21<23:25,  5.17s/it]

Epoch: 128/400 | Train Loss: 2.3101991246221587e-05 | Time: 5.76s (cuda) | LR: 1.7592186044416019e-06


 32%|███▏      | 129/400 [10:26<23:09,  5.13s/it]

Epoch: 129/400 | Train Loss: 2.299077641509939e-05 | Time: 5.03s (cuda) | LR: 1.7592186044416019e-06


 32%|███▎      | 130/400 [10:31<23:08,  5.14s/it]

Epoch: 130/400 | Train Loss: 2.2959133275435306e-05 | Time: 5.17s (cuda) | LR: 1.7592186044416019e-06


 33%|███▎      | 131/400 [10:37<24:31,  5.47s/it]

Epoch: 131/400 | Train Loss: 2.284921720274724e-05 | Time: 6.23s (cuda) | LR: 1.7592186044416019e-06


 33%|███▎      | 132/400 [10:42<22:54,  5.13s/it]

Epoch: 132/400 | Train Loss: 2.2774260287405923e-05 | Time: 4.33s (cuda) | LR: 1.7592186044416019e-06


 33%|███▎      | 133/400 [10:47<23:18,  5.24s/it]

Epoch: 133/400 | Train Loss: 2.2696636733599007e-05 | Time: 5.49s (cuda) | LR: 1.7592186044416019e-06


 34%|███▎      | 134/400 [10:52<22:35,  5.09s/it]

Epoch: 134/400 | Train Loss: 2.2601423552259803e-05 | Time: 4.76s (cuda) | LR: 1.7592186044416019e-06


 34%|███▍      | 135/400 [10:58<23:25,  5.30s/it]

Epoch: 135/400 | Train Loss: 2.254030368931126e-05 | Time: 5.79s (cuda) | LR: 1.7592186044416019e-06


 34%|███▍      | 136/400 [11:03<23:01,  5.23s/it]

Epoch: 136/400 | Train Loss: 2.244394454464782e-05 | Time: 5.06s (cuda) | LR: 1.7592186044416019e-06


 34%|███▍      | 137/400 [11:08<22:54,  5.23s/it]

Epoch: 137/400 | Train Loss: 2.2393658582586795e-05 | Time: 5.21s (cuda) | LR: 1.7592186044416019e-06


 34%|███▍      | 138/400 [11:14<23:43,  5.43s/it]

Epoch: 138/400 | Train Loss: 2.229142774012871e-05 | Time: 5.91s (cuda) | LR: 1.7592186044416019e-06


 35%|███▍      | 139/400 [11:19<23:09,  5.33s/it]

Epoch: 139/400 | Train Loss: 2.2225756765692495e-05 | Time: 5.08s (cuda) | LR: 1.7592186044416019e-06


 35%|███▌      | 140/400 [11:25<23:28,  5.42s/it]

Epoch: 140/400 | Train Loss: 2.2138126951176673e-05 | Time: 5.63s (cuda) | LR: 1.7592186044416019e-06


 35%|███▌      | 141/400 [11:30<22:59,  5.33s/it]

Epoch: 141/400 | Train Loss: 2.2068787075113505e-05 | Time: 5.12s (cuda) | LR: 1.4073748835532816e-06


 36%|███▌      | 142/400 [11:35<22:28,  5.23s/it]

Epoch: 142/400 | Train Loss: 2.093835428240709e-05 | Time: 4.99s (cuda) | LR: 1.4073748835532816e-06


 36%|███▌      | 143/400 [11:41<23:10,  5.41s/it]

Epoch: 143/400 | Train Loss: 2.0752388081746176e-05 | Time: 5.83s (cuda) | LR: 1.4073748835532816e-06


 36%|███▌      | 144/400 [11:45<22:03,  5.17s/it]

Epoch: 144/400 | Train Loss: 2.0748540919157676e-05 | Time: 4.6s (cuda) | LR: 1.4073748835532816e-06


 36%|███▋      | 145/400 [11:51<22:55,  5.39s/it]

Epoch: 145/400 | Train Loss: 2.069990659947507e-05 | Time: 5.92s (cuda) | LR: 1.4073748835532816e-06


 36%|███▋      | 146/400 [11:56<22:12,  5.24s/it]

Epoch: 146/400 | Train Loss: 2.065049557131715e-05 | Time: 4.9s (cuda) | LR: 1.4073748835532816e-06


 37%|███▋      | 147/400 [12:02<23:04,  5.47s/it]

Epoch: 147/400 | Train Loss: 2.059682628896553e-05 | Time: 6.0s (cuda) | LR: 1.4073748835532816e-06


 37%|███▋      | 148/400 [12:07<22:48,  5.43s/it]

Epoch: 148/400 | Train Loss: 2.0553354261210188e-05 | Time: 5.34s (cuda) | LR: 1.4073748835532816e-06


 37%|███▋      | 149/400 [12:13<23:10,  5.54s/it]

Epoch: 149/400 | Train Loss: 2.052169293165207e-05 | Time: 5.79s (cuda) | LR: 1.4073748835532816e-06


 38%|███▊      | 150/400 [12:18<22:45,  5.46s/it]

Epoch: 150/400 | Train Loss: 2.042116830125451e-05 | Time: 5.27s (cuda) | LR: 1.4073748835532816e-06


 38%|███▊      | 151/400 [12:24<22:12,  5.35s/it]

Epoch: 151/400 | Train Loss: 2.0385317839100026e-05 | Time: 5.09s (cuda) | LR: 1.4073748835532816e-06


 38%|███▊      | 152/400 [12:29<22:10,  5.37s/it]

Epoch: 152/400 | Train Loss: 2.0319788745837286e-05 | Time: 5.4s (cuda) | LR: 1.4073748835532816e-06


 38%|███▊      | 153/400 [12:34<21:39,  5.26s/it]

Epoch: 153/400 | Train Loss: 2.0284942365833558e-05 | Time: 5.01s (cuda) | LR: 1.4073748835532816e-06


 38%|███▊      | 154/400 [12:40<22:00,  5.37s/it]

Epoch: 154/400 | Train Loss: 2.0223258616169915e-05 | Time: 5.62s (cuda) | LR: 1.4073748835532816e-06


 39%|███▉      | 155/400 [12:45<21:26,  5.25s/it]

Epoch: 155/400 | Train Loss: 2.0160954591119662e-05 | Time: 4.98s (cuda) | LR: 1.4073748835532816e-06


 39%|███▉      | 156/400 [12:49<20:39,  5.08s/it]

Epoch: 156/400 | Train Loss: 2.0099734683753923e-05 | Time: 4.68s (cuda) | LR: 1.4073748835532816e-06


 39%|███▉      | 157/400 [12:56<22:11,  5.48s/it]

Epoch: 157/400 | Train Loss: 2.0057537767570466e-05 | Time: 6.41s (cuda) | LR: 1.4073748835532816e-06


 40%|███▉      | 158/400 [13:00<21:19,  5.29s/it]

Epoch: 158/400 | Train Loss: 2.0026707716169767e-05 | Time: 4.83s (cuda) | LR: 1.4073748835532816e-06


 40%|███▉      | 159/400 [13:07<22:17,  5.55s/it]

Epoch: 159/400 | Train Loss: 1.9924809748772532e-05 | Time: 6.16s (cuda) | LR: 1.4073748835532816e-06


 40%|████      | 160/400 [13:11<21:11,  5.30s/it]

Epoch: 160/400 | Train Loss: 1.9888751921826042e-05 | Time: 4.72s (cuda) | LR: 1.4073748835532816e-06


 40%|████      | 161/400 [13:18<22:11,  5.57s/it]

Epoch: 161/400 | Train Loss: 1.985332892218139e-05 | Time: 6.2s (cuda) | LR: 1.1258999068426254e-06


 40%|████      | 162/400 [13:22<20:39,  5.21s/it]

Epoch: 162/400 | Train Loss: 1.898014852486085e-05 | Time: 4.36s (cuda) | LR: 1.1258999068426254e-06


 41%|████      | 163/400 [13:27<20:25,  5.17s/it]

Epoch: 163/400 | Train Loss: 1.8857215764001012e-05 | Time: 5.08s (cuda) | LR: 1.1258999068426254e-06


 41%|████      | 164/400 [13:32<20:17,  5.16s/it]

Epoch: 164/400 | Train Loss: 1.8835880837286822e-05 | Time: 5.13s (cuda) | LR: 1.1258999068426254e-06


 41%|████▏     | 165/400 [13:37<19:49,  5.06s/it]

Epoch: 165/400 | Train Loss: 1.880561831058003e-05 | Time: 4.83s (cuda) | LR: 1.1258999068426254e-06


 42%|████▏     | 166/400 [13:42<20:15,  5.19s/it]

Epoch: 166/400 | Train Loss: 1.878335933724884e-05 | Time: 5.5s (cuda) | LR: 1.1258999068426254e-06


 42%|████▏     | 167/400 [13:47<19:40,  5.07s/it]

Epoch: 167/400 | Train Loss: 1.8746561181615107e-05 | Time: 4.78s (cuda) | LR: 1.1258999068426254e-06


 42%|████▏     | 168/400 [13:52<19:02,  4.92s/it]

Epoch: 168/400 | Train Loss: 1.8727096176007763e-05 | Time: 4.59s (cuda) | LR: 1.1258999068426254e-06


 42%|████▏     | 169/400 [13:58<19:50,  5.15s/it]

Epoch: 169/400 | Train Loss: 1.86821780516766e-05 | Time: 5.69s (cuda) | LR: 1.1258999068426254e-06


 42%|████▎     | 170/400 [14:02<19:15,  5.03s/it]

Epoch: 170/400 | Train Loss: 1.8640021153260022e-05 | Time: 4.72s (cuda) | LR: 1.1258999068426254e-06


 43%|████▎     | 171/400 [14:08<20:27,  5.36s/it]

Epoch: 171/400 | Train Loss: 1.8604416254675016e-05 | Time: 6.14s (cuda) | LR: 1.1258999068426254e-06


 43%|████▎     | 172/400 [14:13<19:17,  5.08s/it]

Epoch: 172/400 | Train Loss: 1.8549317246652208e-05 | Time: 4.42s (cuda) | LR: 1.1258999068426254e-06


 43%|████▎     | 173/400 [14:19<20:09,  5.33s/it]

Epoch: 173/400 | Train Loss: 1.8507349523133598e-05 | Time: 5.9s (cuda) | LR: 1.1258999068426254e-06


 44%|████▎     | 174/400 [14:23<19:15,  5.11s/it]

Epoch: 174/400 | Train Loss: 1.847424755396787e-05 | Time: 4.61s (cuda) | LR: 1.1258999068426254e-06


 44%|████▍     | 175/400 [14:29<19:16,  5.14s/it]

Epoch: 175/400 | Train Loss: 1.8422626453684643e-05 | Time: 5.2s (cuda) | LR: 1.1258999068426254e-06


 44%|████▍     | 176/400 [14:35<20:15,  5.43s/it]

Epoch: 176/400 | Train Loss: 1.8381906556896865e-05 | Time: 6.09s (cuda) | LR: 1.1258999068426254e-06


 44%|████▍     | 177/400 [14:40<19:56,  5.36s/it]

Epoch: 177/400 | Train Loss: 1.8352595361648127e-05 | Time: 5.22s (cuda) | LR: 1.1258999068426254e-06


 44%|████▍     | 178/400 [14:46<20:45,  5.61s/it]

Epoch: 178/400 | Train Loss: 1.830416476877872e-05 | Time: 6.18s (cuda) | LR: 1.1258999068426254e-06


 45%|████▍     | 179/400 [14:51<20:24,  5.54s/it]

Epoch: 179/400 | Train Loss: 1.8285381884197704e-05 | Time: 5.38s (cuda) | LR: 1.1258999068426254e-06


 45%|████▌     | 180/400 [14:57<20:03,  5.47s/it]

Epoch: 180/400 | Train Loss: 1.824618266255129e-05 | Time: 5.3s (cuda) | LR: 1.1258999068426254e-06


 45%|████▌     | 181/400 [15:02<19:53,  5.45s/it]

Epoch: 181/400 | Train Loss: 1.8185275621362962e-05 | Time: 5.4s (cuda) | LR: 9.007199254741003e-07


 46%|████▌     | 182/400 [15:08<19:57,  5.49s/it]

Epoch: 182/400 | Train Loss: 1.7568847397342324e-05 | Time: 5.6s (cuda) | LR: 9.007199254741003e-07


 46%|████▌     | 183/400 [15:14<20:17,  5.61s/it]

Epoch: 183/400 | Train Loss: 1.7456683053751476e-05 | Time: 5.87s (cuda) | LR: 9.007199254741003e-07


 46%|████▌     | 184/400 [15:19<19:44,  5.49s/it]

Epoch: 184/400 | Train Loss: 1.7435741028748453e-05 | Time: 5.2s (cuda) | LR: 9.007199254741003e-07


 46%|████▋     | 185/400 [15:25<20:27,  5.71s/it]

Epoch: 185/400 | Train Loss: 1.7435188055969775e-05 | Time: 6.24s (cuda) | LR: 9.007199254741003e-07


 46%|████▋     | 186/400 [15:30<19:54,  5.58s/it]

Epoch: 186/400 | Train Loss: 1.7406615370418876e-05 | Time: 5.27s (cuda) | LR: 9.007199254741003e-07


 47%|████▋     | 187/400 [15:37<20:28,  5.77s/it]

Epoch: 187/400 | Train Loss: 1.738277933327481e-05 | Time: 6.2s (cuda) | LR: 9.007199254741003e-07


 47%|████▋     | 188/400 [15:42<19:49,  5.61s/it]

Epoch: 188/400 | Train Loss: 1.735585647111293e-05 | Time: 5.24s (cuda) | LR: 9.007199254741003e-07


 47%|████▋     | 189/400 [15:48<20:27,  5.82s/it]

Epoch: 189/400 | Train Loss: 1.731904012558516e-05 | Time: 6.3s (cuda) | LR: 9.007199254741003e-07


 48%|████▊     | 190/400 [15:53<19:45,  5.65s/it]

Epoch: 190/400 | Train Loss: 1.728749521134887e-05 | Time: 5.24s (cuda) | LR: 9.007199254741003e-07


 48%|████▊     | 191/400 [16:00<20:21,  5.84s/it]

Epoch: 191/400 | Train Loss: 1.7272610421059653e-05 | Time: 6.3s (cuda) | LR: 9.007199254741003e-07


 48%|████▊     | 192/400 [16:05<19:45,  5.70s/it]

Epoch: 192/400 | Train Loss: 1.7226189811481163e-05 | Time: 5.36s (cuda) | LR: 9.007199254741003e-07


 48%|████▊     | 193/400 [16:11<19:56,  5.78s/it]

Epoch: 193/400 | Train Loss: 1.720390355330892e-05 | Time: 5.96s (cuda) | LR: 9.007199254741003e-07


 48%|████▊     | 194/400 [16:17<19:45,  5.75s/it]

Epoch: 194/400 | Train Loss: 1.716169026622083e-05 | Time: 5.7s (cuda) | LR: 9.007199254741003e-07


 49%|████▉     | 195/400 [16:22<19:34,  5.73s/it]

Epoch: 195/400 | Train Loss: 1.7153448425233364e-05 | Time: 5.67s (cuda) | LR: 9.007199254741003e-07


 49%|████▉     | 196/400 [16:28<19:46,  5.82s/it]

Epoch: 196/400 | Train Loss: 1.711560253170319e-05 | Time: 6.01s (cuda) | LR: 9.007199254741003e-07


 49%|████▉     | 197/400 [16:34<19:27,  5.75s/it]

Epoch: 197/400 | Train Loss: 1.7093812857638113e-05 | Time: 5.59s (cuda) | LR: 9.007199254741003e-07


 50%|████▉     | 198/400 [16:40<19:50,  5.89s/it]

Epoch: 198/400 | Train Loss: 1.705670729279518e-05 | Time: 6.22s (cuda) | LR: 9.007199254741003e-07


 50%|████▉     | 199/400 [16:46<19:15,  5.75s/it]

Epoch: 199/400 | Train Loss: 1.703320594970137e-05 | Time: 5.41s (cuda) | LR: 9.007199254741003e-07


 50%|█████     | 200/400 [16:52<19:44,  5.92s/it]

Epoch: 200/400 | Train Loss: 1.6994883480947465e-05 | Time: 6.32s (cuda) | LR: 9.007199254741003e-07


 50%|█████     | 201/400 [16:57<19:01,  5.74s/it]

Epoch: 201/400 | Train Loss: 1.698289270279929e-05 | Time: 5.3s (cuda) | LR: 7.205759403792803e-07


 50%|█████     | 202/400 [17:04<19:36,  5.94s/it]

Epoch: 202/400 | Train Loss: 1.6477326425956562e-05 | Time: 6.43s (cuda) | LR: 7.205759403792803e-07


 51%|█████     | 203/400 [17:09<18:56,  5.77s/it]

Epoch: 203/400 | Train Loss: 1.640461960050743e-05 | Time: 5.37s (cuda) | LR: 7.205759403792803e-07


 51%|█████     | 204/400 [17:15<19:26,  5.95s/it]

Epoch: 204/400 | Train Loss: 1.639921174501069e-05 | Time: 6.37s (cuda) | LR: 7.205759403792803e-07


 51%|█████▏    | 205/400 [17:21<18:45,  5.77s/it]

Epoch: 205/400 | Train Loss: 1.6379868611693382e-05 | Time: 5.35s (cuda) | LR: 7.205759403792803e-07


 52%|█████▏    | 206/400 [17:27<19:00,  5.88s/it]

Epoch: 206/400 | Train Loss: 1.6367968783015385e-05 | Time: 6.13s (cuda) | LR: 7.205759403792803e-07


 52%|█████▏    | 207/400 [17:32<18:35,  5.78s/it]

Epoch: 207/400 | Train Loss: 1.6344598407158628e-05 | Time: 5.54s (cuda) | LR: 7.205759403792803e-07


 52%|█████▏    | 208/400 [17:38<18:04,  5.65s/it]

Epoch: 208/400 | Train Loss: 1.632122257433366e-05 | Time: 5.35s (cuda) | LR: 7.205759403792803e-07


 52%|█████▏    | 209/400 [17:44<18:24,  5.78s/it]

Epoch: 209/400 | Train Loss: 1.6296458852593787e-05 | Time: 6.1s (cuda) | LR: 7.205759403792803e-07


 52%|█████▎    | 210/400 [17:49<17:58,  5.68s/it]

Epoch: 210/400 | Train Loss: 1.6283422155538574e-05 | Time: 5.42s (cuda) | LR: 7.205759403792803e-07


 53%|█████▎    | 211/400 [17:56<18:27,  5.86s/it]

Epoch: 211/400 | Train Loss: 1.6261270502582192e-05 | Time: 6.28s (cuda) | LR: 7.205759403792803e-07


 53%|█████▎    | 212/400 [18:01<17:44,  5.66s/it]

Epoch: 212/400 | Train Loss: 1.6226018487941474e-05 | Time: 5.2s (cuda) | LR: 7.205759403792803e-07


 53%|█████▎    | 213/400 [18:07<18:15,  5.86s/it]

Epoch: 213/400 | Train Loss: 1.6216363292187452e-05 | Time: 6.32s (cuda) | LR: 7.205759403792803e-07


 54%|█████▎    | 214/400 [18:12<17:37,  5.68s/it]

Epoch: 214/400 | Train Loss: 1.618761416466441e-05 | Time: 5.28s (cuda) | LR: 7.205759403792803e-07


 54%|█████▍    | 215/400 [18:19<17:59,  5.84s/it]

Epoch: 215/400 | Train Loss: 1.6171448805835098e-05 | Time: 6.19s (cuda) | LR: 7.205759403792803e-07


 54%|█████▍    | 216/400 [18:24<17:21,  5.66s/it]

Epoch: 216/400 | Train Loss: 1.6150384908542037e-05 | Time: 5.24s (cuda) | LR: 7.205759403792803e-07


 54%|█████▍    | 217/400 [18:30<17:45,  5.82s/it]

Epoch: 217/400 | Train Loss: 1.6122909073601477e-05 | Time: 6.2s (cuda) | LR: 7.205759403792803e-07


 55%|█████▍    | 218/400 [18:36<17:25,  5.74s/it]

Epoch: 218/400 | Train Loss: 1.6100289940368384e-05 | Time: 5.56s (cuda) | LR: 7.205759403792803e-07


 55%|█████▍    | 219/400 [18:41<17:19,  5.74s/it]

Epoch: 219/400 | Train Loss: 1.6083329683169723e-05 | Time: 5.74s (cuda) | LR: 7.205759403792803e-07


 55%|█████▌    | 220/400 [18:47<17:18,  5.77s/it]

Epoch: 220/400 | Train Loss: 1.6072597645688802e-05 | Time: 5.82s (cuda) | LR: 7.205759403792803e-07


 55%|█████▌    | 221/400 [18:52<16:40,  5.59s/it]

Epoch: 221/400 | Train Loss: 1.60305826284457e-05 | Time: 5.17s (cuda) | LR: 5.764607523034243e-07


 56%|█████▌    | 222/400 [18:58<17:04,  5.75s/it]

Epoch: 222/400 | Train Loss: 1.5653013178962283e-05 | Time: 6.14s (cuda) | LR: 5.764607523034243e-07


 56%|█████▌    | 223/400 [19:04<16:43,  5.67s/it]

Epoch: 223/400 | Train Loss: 1.5600406186422333e-05 | Time: 5.46s (cuda) | LR: 5.764607523034243e-07


 56%|█████▌    | 224/400 [19:10<17:01,  5.80s/it]

Epoch: 224/400 | Train Loss: 1.5598052414134145e-05 | Time: 6.12s (cuda) | LR: 5.764607523034243e-07


 56%|█████▋    | 225/400 [19:15<16:22,  5.61s/it]

Epoch: 225/400 | Train Loss: 1.5581827028654516e-05 | Time: 5.17s (cuda) | LR: 5.764607523034243e-07


 56%|█████▋    | 226/400 [19:21<16:45,  5.78s/it]

Epoch: 226/400 | Train Loss: 1.55629913933808e-05 | Time: 6.16s (cuda) | LR: 5.764607523034243e-07


 57%|█████▋    | 227/400 [19:26<16:07,  5.59s/it]

Epoch: 227/400 | Train Loss: 1.5559653547825292e-05 | Time: 5.15s (cuda) | LR: 5.764607523034243e-07


 57%|█████▋    | 228/400 [19:33<16:37,  5.80s/it]

Epoch: 228/400 | Train Loss: 1.5536395949311554e-05 | Time: 6.28s (cuda) | LR: 5.764607523034243e-07


 57%|█████▋    | 229/400 [19:38<16:10,  5.68s/it]

Epoch: 229/400 | Train Loss: 1.551911554997787e-05 | Time: 5.38s (cuda) | LR: 5.764607523034243e-07


 57%|█████▊    | 230/400 [19:44<15:55,  5.62s/it]

Epoch: 230/400 | Train Loss: 1.550564957142342e-05 | Time: 5.48s (cuda) | LR: 5.764607523034243e-07


 58%|█████▊    | 231/400 [19:50<16:08,  5.73s/it]

Epoch: 231/400 | Train Loss: 1.548731052025687e-05 | Time: 5.99s (cuda) | LR: 5.764607523034243e-07


 58%|█████▊    | 232/400 [19:55<15:39,  5.59s/it]

Epoch: 232/400 | Train Loss: 1.546035855426453e-05 | Time: 5.26s (cuda) | LR: 5.764607523034243e-07


 58%|█████▊    | 233/400 [20:01<16:02,  5.77s/it]

Epoch: 233/400 | Train Loss: 1.5454821550520137e-05 | Time: 6.17s (cuda) | LR: 5.764607523034243e-07


 58%|█████▊    | 234/400 [20:06<15:34,  5.63s/it]

Epoch: 234/400 | Train Loss: 1.5439491107827052e-05 | Time: 5.31s (cuda) | LR: 5.764607523034243e-07


 59%|█████▉    | 235/400 [20:13<16:02,  5.83s/it]

Epoch: 235/400 | Train Loss: 1.5422703654621728e-05 | Time: 6.3s (cuda) | LR: 5.764607523034243e-07


 59%|█████▉    | 236/400 [20:18<15:25,  5.65s/it]

Epoch: 236/400 | Train Loss: 1.5403767974930815e-05 | Time: 5.21s (cuda) | LR: 5.764607523034243e-07


 59%|█████▉    | 237/400 [20:24<15:55,  5.86s/it]

Epoch: 237/400 | Train Loss: 1.538057767902501e-05 | Time: 6.37s (cuda) | LR: 5.764607523034243e-07


 60%|█████▉    | 238/400 [20:30<15:32,  5.76s/it]

Epoch: 238/400 | Train Loss: 1.536516356281936e-05 | Time: 5.51s (cuda) | LR: 5.764607523034243e-07


 60%|█████▉    | 239/400 [20:36<16:00,  5.96s/it]

Epoch: 239/400 | Train Loss: 1.5356450603576377e-05 | Time: 6.44s (cuda) | LR: 5.764607523034243e-07


 60%|██████    | 240/400 [20:42<15:25,  5.79s/it]

Epoch: 240/400 | Train Loss: 1.5345985957537778e-05 | Time: 5.37s (cuda) | LR: 5.764607523034243e-07


 60%|██████    | 241/400 [20:48<15:26,  5.82s/it]

Epoch: 241/400 | Train Loss: 1.5321200407925062e-05 | Time: 5.91s (cuda) | LR: 4.6116860184273944e-07


 60%|██████    | 242/400 [20:53<15:09,  5.75s/it]

Epoch: 242/400 | Train Loss: 1.5019240890978836e-05 | Time: 5.59s (cuda) | LR: 4.6116860184273944e-07


 61%|██████    | 243/400 [20:58<14:38,  5.60s/it]

Epoch: 243/400 | Train Loss: 1.4984011613705661e-05 | Time: 5.23s (cuda) | LR: 4.6116860184273944e-07


 61%|██████    | 244/400 [21:04<14:53,  5.73s/it]

Epoch: 244/400 | Train Loss: 1.4972290955483913e-05 | Time: 6.04s (cuda) | LR: 4.6116860184273944e-07


 61%|██████▏   | 245/400 [21:10<14:35,  5.65s/it]

Epoch: 245/400 | Train Loss: 1.4961258784751408e-05 | Time: 5.45s (cuda) | LR: 4.6116860184273944e-07


 62%|██████▏   | 246/400 [21:16<14:55,  5.82s/it]

Epoch: 246/400 | Train Loss: 1.4945248040021397e-05 | Time: 6.22s (cuda) | LR: 4.6116860184273944e-07


 62%|██████▏   | 247/400 [21:21<14:30,  5.69s/it]

Epoch: 247/400 | Train Loss: 1.4939962056814693e-05 | Time: 5.38s (cuda) | LR: 4.6116860184273944e-07


 62%|██████▏   | 248/400 [21:28<14:53,  5.88s/it]

Epoch: 248/400 | Train Loss: 1.4923879462003242e-05 | Time: 6.31s (cuda) | LR: 4.6116860184273944e-07


 62%|██████▏   | 249/400 [21:33<14:25,  5.73s/it]

Epoch: 249/400 | Train Loss: 1.4913288396201096e-05 | Time: 5.4s (cuda) | LR: 4.6116860184273944e-07


 62%|██████▎   | 250/400 [21:40<14:50,  5.94s/it]

Epoch: 250/400 | Train Loss: 1.489996247983072e-05 | Time: 6.41s (cuda) | LR: 4.6116860184273944e-07


 63%|██████▎   | 251/400 [21:45<14:19,  5.77s/it]

Epoch: 251/400 | Train Loss: 1.4888839359628037e-05 | Time: 5.38s (cuda) | LR: 4.6116860184273944e-07


 63%|██████▎   | 252/400 [21:51<14:39,  5.94s/it]

Epoch: 252/400 | Train Loss: 1.4876803106744774e-05 | Time: 6.34s (cuda) | LR: 4.6116860184273944e-07


 63%|██████▎   | 253/400 [21:57<14:13,  5.81s/it]

Epoch: 253/400 | Train Loss: 1.4858919712423813e-05 | Time: 5.49s (cuda) | LR: 4.6116860184273944e-07


 64%|██████▎   | 254/400 [22:03<14:19,  5.89s/it]

Epoch: 254/400 | Train Loss: 1.4849189938104246e-05 | Time: 6.08s (cuda) | LR: 4.6116860184273944e-07


 64%|██████▍   | 255/400 [22:09<14:17,  5.92s/it]

Epoch: 255/400 | Train Loss: 1.484136373619549e-05 | Time: 5.98s (cuda) | LR: 4.6116860184273944e-07


 64%|██████▍   | 256/400 [22:14<13:51,  5.77s/it]

Epoch: 256/400 | Train Loss: 1.4823101992078591e-05 | Time: 5.44s (cuda) | LR: 4.6116860184273944e-07


 64%|██████▍   | 257/400 [22:20<14:00,  5.88s/it]

Epoch: 257/400 | Train Loss: 1.481049730500672e-05 | Time: 6.11s (cuda) | LR: 4.6116860184273944e-07


 64%|██████▍   | 258/400 [22:26<13:34,  5.74s/it]

Epoch: 258/400 | Train Loss: 1.4797284165979363e-05 | Time: 5.41s (cuda) | LR: 4.6116860184273944e-07


 65%|██████▍   | 259/400 [22:32<13:56,  5.93s/it]

Epoch: 259/400 | Train Loss: 1.4789140550419688e-05 | Time: 6.37s (cuda) | LR: 4.6116860184273944e-07


 65%|██████▌   | 260/400 [22:38<13:35,  5.83s/it]

Epoch: 260/400 | Train Loss: 1.4771332644158974e-05 | Time: 5.58s (cuda) | LR: 4.6116860184273944e-07


 65%|██████▌   | 261/400 [22:44<13:48,  5.96s/it]

Epoch: 261/400 | Train Loss: 1.4755624761164654e-05 | Time: 6.29s (cuda) | LR: 3.689348814741916e-07


 66%|██████▌   | 262/400 [22:49<13:20,  5.80s/it]

Epoch: 262/400 | Train Loss: 1.4517327144858427e-05 | Time: 5.42s (cuda) | LR: 3.689348814741916e-07


 66%|██████▌   | 263/400 [22:56<13:35,  5.95s/it]

Epoch: 263/400 | Train Loss: 1.4492721675196663e-05 | Time: 6.29s (cuda) | LR: 3.689348814741916e-07


 66%|██████▌   | 264/400 [23:01<13:04,  5.77s/it]

Epoch: 264/400 | Train Loss: 1.4484924577118363e-05 | Time: 5.34s (cuda) | LR: 3.689348814741916e-07


 66%|██████▋   | 265/400 [23:07<13:17,  5.91s/it]

Epoch: 265/400 | Train Loss: 1.4474664567387663e-05 | Time: 6.24s (cuda) | LR: 3.689348814741916e-07


 66%|██████▋   | 266/400 [23:13<12:53,  5.77s/it]

Epoch: 266/400 | Train Loss: 1.4469263078353833e-05 | Time: 5.45s (cuda) | LR: 3.689348814741916e-07


 67%|██████▋   | 267/400 [23:19<12:58,  5.85s/it]

Epoch: 267/400 | Train Loss: 1.4464976629824378e-05 | Time: 6.04s (cuda) | LR: 3.689348814741916e-07


 67%|██████▋   | 268/400 [23:24<12:34,  5.72s/it]

Epoch: 268/400 | Train Loss: 1.4447194189415313e-05 | Time: 5.39s (cuda) | LR: 3.689348814741916e-07


 67%|██████▋   | 269/400 [23:30<12:15,  5.62s/it]

Epoch: 269/400 | Train Loss: 1.4443614418269135e-05 | Time: 5.38s (cuda) | LR: 3.689348814741916e-07


 68%|██████▊   | 270/400 [23:36<12:26,  5.74s/it]

Epoch: 270/400 | Train Loss: 1.4430660485231783e-05 | Time: 6.03s (cuda) | LR: 3.689348814741916e-07


 68%|██████▊   | 271/400 [23:41<12:04,  5.61s/it]

Epoch: 271/400 | Train Loss: 1.4422270396607928e-05 | Time: 5.32s (cuda) | LR: 3.689348814741916e-07


 68%|██████▊   | 272/400 [23:47<12:20,  5.79s/it]

Epoch: 272/400 | Train Loss: 1.4409599316422828e-05 | Time: 6.19s (cuda) | LR: 3.689348814741916e-07


 68%|██████▊   | 273/400 [23:52<11:56,  5.64s/it]

Epoch: 273/400 | Train Loss: 1.4397745871974621e-05 | Time: 5.31s (cuda) | LR: 3.689348814741916e-07


 68%|██████▊   | 274/400 [23:59<12:14,  5.83s/it]

Epoch: 274/400 | Train Loss: 1.4394002391782124e-05 | Time: 6.27s (cuda) | LR: 3.689348814741916e-07


 69%|██████▉   | 275/400 [24:04<11:45,  5.64s/it]

Epoch: 275/400 | Train Loss: 1.43840898090275e-05 | Time: 5.2s (cuda) | LR: 3.689348814741916e-07


 69%|██████▉   | 276/400 [24:10<12:05,  5.85s/it]

Epoch: 276/400 | Train Loss: 1.4367744370247237e-05 | Time: 6.34s (cuda) | LR: 3.689348814741916e-07


 69%|██████▉   | 277/400 [24:16<11:42,  5.71s/it]

Epoch: 277/400 | Train Loss: 1.4358242879097816e-05 | Time: 5.39s (cuda) | LR: 3.689348814741916e-07


 70%|██████▉   | 278/400 [24:21<11:04,  5.45s/it]

Epoch: 278/400 | Train Loss: 1.4354263839777559e-05 | Time: 4.82s (cuda) | LR: 3.689348814741916e-07


 70%|██████▉   | 279/400 [24:25<10:37,  5.27s/it]

Epoch: 279/400 | Train Loss: 1.4343751900014468e-05 | Time: 4.85s (cuda) | LR: 3.689348814741916e-07


 70%|███████   | 280/400 [24:30<10:00,  5.01s/it]

Epoch: 280/400 | Train Loss: 1.433630950486986e-05 | Time: 4.4s (cuda) | LR: 3.689348814741916e-07


 70%|███████   | 281/400 [24:35<10:08,  5.11s/it]

Epoch: 281/400 | Train Loss: 1.4324113180919085e-05 | Time: 5.35s (cuda) | LR: 2.9514790517935326e-07


 70%|███████   | 282/400 [24:40<09:42,  4.94s/it]

Epoch: 282/400 | Train Loss: 1.4127352187642828e-05 | Time: 4.53s (cuda) | LR: 2.9514790517935326e-07


 71%|███████   | 283/400 [24:44<09:19,  4.78s/it]

Epoch: 283/400 | Train Loss: 1.4109450603427831e-05 | Time: 4.41s (cuda) | LR: 2.9514790517935326e-07


 71%|███████   | 284/400 [24:49<09:32,  4.94s/it]

Epoch: 284/400 | Train Loss: 1.4105437912803609e-05 | Time: 5.31s (cuda) | LR: 2.9514790517935326e-07


 71%|███████▏  | 285/400 [24:54<09:07,  4.76s/it]

Epoch: 285/400 | Train Loss: 1.4094663129071705e-05 | Time: 4.35s (cuda) | LR: 2.9514790517935326e-07


 72%|███████▏  | 286/400 [24:59<09:16,  4.88s/it]

Epoch: 286/400 | Train Loss: 1.4092106539465021e-05 | Time: 5.17s (cuda) | LR: 2.9514790517935326e-07


 72%|███████▏  | 287/400 [25:03<08:58,  4.77s/it]

Epoch: 287/400 | Train Loss: 1.4081232620810624e-05 | Time: 4.49s (cuda) | LR: 2.9514790517935326e-07


 72%|███████▏  | 288/400 [25:08<08:42,  4.66s/it]

Epoch: 288/400 | Train Loss: 1.4076247680350207e-05 | Time: 4.42s (cuda) | LR: 2.9514790517935326e-07


 72%|███████▏  | 289/400 [25:13<09:05,  4.91s/it]

Epoch: 289/400 | Train Loss: 1.406753835908603e-05 | Time: 5.49s (cuda) | LR: 2.9514790517935326e-07


 72%|███████▎  | 290/400 [25:18<08:43,  4.76s/it]

Epoch: 290/400 | Train Loss: 1.4060229659662582e-05 | Time: 4.42s (cuda) | LR: 2.9514790517935326e-07


 73%|███████▎  | 291/400 [25:22<08:29,  4.67s/it]

Epoch: 291/400 | Train Loss: 1.4051034668227658e-05 | Time: 4.46s (cuda) | LR: 2.9514790517935326e-07


 73%|███████▎  | 292/400 [25:27<08:39,  4.81s/it]

Epoch: 292/400 | Train Loss: 1.4047271179151721e-05 | Time: 5.14s (cuda) | LR: 2.9514790517935326e-07


 73%|███████▎  | 293/400 [25:32<08:20,  4.68s/it]

Epoch: 293/400 | Train Loss: 1.403804708388634e-05 | Time: 4.36s (cuda) | LR: 2.9514790517935326e-07


 74%|███████▎  | 294/400 [25:37<08:37,  4.88s/it]

Epoch: 294/400 | Train Loss: 1.4031546925252769e-05 | Time: 5.35s (cuda) | LR: 2.9514790517935326e-07


 74%|███████▍  | 295/400 [25:42<08:20,  4.77s/it]

Epoch: 295/400 | Train Loss: 1.4024826668901369e-05 | Time: 4.51s (cuda) | LR: 2.9514790517935326e-07


 74%|███████▍  | 296/400 [25:46<08:04,  4.66s/it]

Epoch: 296/400 | Train Loss: 1.40099655254744e-05 | Time: 4.4s (cuda) | LR: 2.9514790517935326e-07


 74%|███████▍  | 297/400 [25:51<08:21,  4.86s/it]

Epoch: 297/400 | Train Loss: 1.401130975864362e-05 | Time: 5.34s (cuda) | LR: 2.3611832414348262e-07


 74%|███████▍  | 298/400 [25:56<08:02,  4.73s/it]

Epoch: 298/400 | Train Loss: 1.384855386277195e-05 | Time: 4.42s (cuda) | LR: 2.3611832414348262e-07


 75%|███████▍  | 299/400 [26:00<07:58,  4.74s/it]

Epoch: 299/400 | Train Loss: 1.383549260935979e-05 | Time: 4.74s (cuda) | LR: 2.3611832414348262e-07


 75%|███████▌  | 300/400 [26:05<08:00,  4.80s/it]

Epoch: 300/400 | Train Loss: 1.3830867828801274e-05 | Time: 4.95s (cuda) | LR: 2.3611832414348262e-07


 75%|███████▌  | 301/400 [26:10<07:43,  4.68s/it]

Epoch: 301/400 | Train Loss: 1.3822254004480783e-05 | Time: 4.41s (cuda) | LR: 1.8889465931478612e-07


 76%|███████▌  | 302/400 [26:15<08:00,  4.90s/it]

Epoch: 302/400 | Train Loss: 1.3697268514079042e-05 | Time: 5.41s (cuda) | LR: 1.8889465931478612e-07


 76%|███████▌  | 303/400 [26:20<07:42,  4.77s/it]

Epoch: 303/400 | Train Loss: 1.368712037219666e-05 | Time: 4.46s (cuda) | LR: 1.8889465931478612e-07


 76%|███████▌  | 304/400 [26:24<07:26,  4.65s/it]

Epoch: 304/400 | Train Loss: 1.3684229998034425e-05 | Time: 4.37s (cuda) | LR: 1.8889465931478612e-07


 76%|███████▋  | 305/400 [26:29<07:39,  4.83s/it]

Epoch: 305/400 | Train Loss: 1.3680317351827398e-05 | Time: 5.26s (cuda) | LR: 1.8889465931478612e-07


 76%|███████▋  | 306/400 [26:34<07:22,  4.71s/it]

Epoch: 306/400 | Train Loss: 1.3673301509697922e-05 | Time: 4.41s (cuda) | LR: 1.8889465931478612e-07


 77%|███████▋  | 307/400 [26:39<07:33,  4.88s/it]

Epoch: 307/400 | Train Loss: 1.3668495739693753e-05 | Time: 5.28s (cuda) | LR: 1.8889465931478612e-07


 77%|███████▋  | 308/400 [26:43<07:18,  4.77s/it]

Epoch: 308/400 | Train Loss: 1.366355354548432e-05 | Time: 4.49s (cuda) | LR: 1.8889465931478612e-07


 77%|███████▋  | 309/400 [26:48<07:02,  4.64s/it]

Epoch: 309/400 | Train Loss: 1.3661232515005395e-05 | Time: 4.35s (cuda) | LR: 1.8889465931478612e-07


 78%|███████▊  | 310/400 [26:53<07:19,  4.89s/it]

Epoch: 310/400 | Train Loss: 1.3655042494065128e-05 | Time: 5.46s (cuda) | LR: 1.8889465931478612e-07


 78%|███████▊  | 311/400 [26:58<07:01,  4.74s/it]

Epoch: 311/400 | Train Loss: 1.3651213521370664e-05 | Time: 4.38s (cuda) | LR: 1.8889465931478612e-07


 78%|███████▊  | 312/400 [27:03<07:00,  4.77s/it]

Epoch: 312/400 | Train Loss: 1.3646793377120048e-05 | Time: 4.85s (cuda) | LR: 1.8889465931478612e-07


 78%|███████▊  | 313/400 [27:08<07:00,  4.83s/it]

Epoch: 313/400 | Train Loss: 1.3643195416079834e-05 | Time: 4.96s (cuda) | LR: 1.8889465931478612e-07


 78%|███████▊  | 314/400 [27:12<06:44,  4.70s/it]

Epoch: 314/400 | Train Loss: 1.3636903531732969e-05 | Time: 4.41s (cuda) | LR: 1.8889465931478612e-07


 79%|███████▉  | 315/400 [27:17<06:58,  4.92s/it]

Epoch: 315/400 | Train Loss: 1.3630782632390037e-05 | Time: 5.43s (cuda) | LR: 1.8889465931478612e-07


 79%|███████▉  | 316/400 [27:22<06:40,  4.77s/it]

Epoch: 316/400 | Train Loss: 1.3626215149997734e-05 | Time: 4.42s (cuda) | LR: 1.8889465931478612e-07


 79%|███████▉  | 317/400 [27:26<06:26,  4.65s/it]

Epoch: 317/400 | Train Loss: 1.361966224067146e-05 | Time: 4.37s (cuda) | LR: 1.8889465931478612e-07


 80%|███████▉  | 318/400 [27:31<06:37,  4.85s/it]

Epoch: 318/400 | Train Loss: 1.361406611977145e-05 | Time: 5.32s (cuda) | LR: 1.8889465931478612e-07


 80%|███████▉  | 319/400 [27:36<06:22,  4.72s/it]

Epoch: 319/400 | Train Loss: 1.3614780073112343e-05 | Time: 4.41s (cuda) | LR: 1.511157274518289e-07


 80%|████████  | 320/400 [27:41<06:29,  4.87s/it]

Epoch: 320/400 | Train Loss: 1.3509987184079364e-05 | Time: 5.23s (cuda) | LR: 1.511157274518289e-07


 80%|████████  | 321/400 [27:46<06:20,  4.81s/it]

Epoch: 321/400 | Train Loss: 1.350079764961265e-05 | Time: 4.67s (cuda) | LR: 1.208925819614631e-07


 80%|████████  | 322/400 [27:50<06:05,  4.69s/it]

Epoch: 322/400 | Train Loss: 1.3419981769402511e-05 | Time: 4.39s (cuda) | LR: 1.208925819614631e-07


 81%|████████  | 323/400 [27:55<06:14,  4.87s/it]

Epoch: 323/400 | Train Loss: 1.3412685802904889e-05 | Time: 5.28s (cuda) | LR: 1.208925819614631e-07


 81%|████████  | 324/400 [28:00<05:58,  4.72s/it]

Epoch: 324/400 | Train Loss: 1.340821290796157e-05 | Time: 4.38s (cuda) | LR: 1.208925819614631e-07


 81%|████████▏ | 325/400 [28:04<05:50,  4.68s/it]

Epoch: 325/400 | Train Loss: 1.3406170182861388e-05 | Time: 4.58s (cuda) | LR: 1.208925819614631e-07


 82%|████████▏ | 326/400 [28:10<05:55,  4.80s/it]

Epoch: 326/400 | Train Loss: 1.3404265700955875e-05 | Time: 5.08s (cuda) | LR: 1.208925819614631e-07


 82%|████████▏ | 327/400 [28:14<05:43,  4.71s/it]

Epoch: 327/400 | Train Loss: 1.3401471733232029e-05 | Time: 4.5s (cuda) | LR: 1.208925819614631e-07


 82%|████████▏ | 328/400 [28:19<05:52,  4.89s/it]

Epoch: 328/400 | Train Loss: 1.3398956070886925e-05 | Time: 5.32s (cuda) | LR: 1.208925819614631e-07


 82%|████████▏ | 329/400 [28:24<05:36,  4.74s/it]

Epoch: 329/400 | Train Loss: 1.3395871974353213e-05 | Time: 4.39s (cuda) | LR: 1.208925819614631e-07


 82%|████████▎ | 330/400 [28:28<05:24,  4.63s/it]

Epoch: 330/400 | Train Loss: 1.3392709661275148e-05 | Time: 4.37s (cuda) | LR: 1.208925819614631e-07


 83%|████████▎ | 331/400 [28:33<05:34,  4.84s/it]

Epoch: 331/400 | Train Loss: 1.3388405932346359e-05 | Time: 5.34s (cuda) | LR: 1.208925819614631e-07


 83%|████████▎ | 332/400 [28:38<05:20,  4.71s/it]

Epoch: 332/400 | Train Loss: 1.3386905266088434e-05 | Time: 4.39s (cuda) | LR: 1.208925819614631e-07


 83%|████████▎ | 333/400 [28:43<05:20,  4.78s/it]

Epoch: 333/400 | Train Loss: 1.3382210454437882e-05 | Time: 4.94s (cuda) | LR: 1.208925819614631e-07


 84%|████████▎ | 334/400 [28:48<05:17,  4.82s/it]

Epoch: 334/400 | Train Loss: 1.3379920346778817e-05 | Time: 4.91s (cuda) | LR: 1.208925819614631e-07


 84%|████████▍ | 335/400 [28:52<05:07,  4.73s/it]

Epoch: 335/400 | Train Loss: 1.3376793503994122e-05 | Time: 4.51s (cuda) | LR: 1.208925819614631e-07


 84%|████████▍ | 336/400 [28:58<05:16,  4.95s/it]

Epoch: 336/400 | Train Loss: 1.3373282854445279e-05 | Time: 5.47s (cuda) | LR: 1.208925819614631e-07


 84%|████████▍ | 337/400 [29:02<05:01,  4.79s/it]

Epoch: 337/400 | Train Loss: 1.3372472494665999e-05 | Time: 4.4s (cuda) | LR: 1.208925819614631e-07


 84%|████████▍ | 338/400 [29:07<04:50,  4.68s/it]

Epoch: 338/400 | Train Loss: 1.3366201528697275e-05 | Time: 4.45s (cuda) | LR: 1.208925819614631e-07


 85%|████████▍ | 339/400 [29:12<04:59,  4.92s/it]

Epoch: 339/400 | Train Loss: 1.3365237464313395e-05 | Time: 5.46s (cuda) | LR: 1.208925819614631e-07


 85%|████████▌ | 340/400 [29:17<04:48,  4.81s/it]

Epoch: 340/400 | Train Loss: 1.3361672245082445e-05 | Time: 4.57s (cuda) | LR: 1.208925819614631e-07


 85%|████████▌ | 341/400 [29:22<04:57,  5.03s/it]

Epoch: 341/400 | Train Loss: 1.3359557669900823e-05 | Time: 5.55s (cuda) | LR: 9.67140655691705e-08


 86%|████████▌ | 342/400 [29:27<04:43,  4.89s/it]

Epoch: 342/400 | Train Loss: 1.3289914932101965e-05 | Time: 4.56s (cuda) | LR: 9.67140655691705e-08


 86%|████████▌ | 343/400 [29:31<04:30,  4.75s/it]

Epoch: 343/400 | Train Loss: 1.3287921319715679e-05 | Time: 4.41s (cuda) | LR: 9.67140655691705e-08


 86%|████████▌ | 344/400 [29:37<04:37,  4.96s/it]

Epoch: 344/400 | Train Loss: 1.3283965927257668e-05 | Time: 5.45s (cuda) | LR: 9.67140655691705e-08


 86%|████████▋ | 345/400 [29:41<04:25,  4.82s/it]

Epoch: 345/400 | Train Loss: 1.3281951396493241e-05 | Time: 4.51s (cuda) | LR: 9.67140655691705e-08


 86%|████████▋ | 346/400 [29:47<04:35,  5.11s/it]

Epoch: 346/400 | Train Loss: 1.3280460734677035e-05 | Time: 5.77s (cuda) | LR: 9.67140655691705e-08


 87%|████████▋ | 347/400 [29:51<04:22,  4.95s/it]

Epoch: 347/400 | Train Loss: 1.3278311598696746e-05 | Time: 4.59s (cuda) | LR: 9.67140655691705e-08


 87%|████████▋ | 348/400 [29:56<04:10,  4.82s/it]

Epoch: 348/400 | Train Loss: 1.327427617070498e-05 | Time: 4.52s (cuda) | LR: 9.67140655691705e-08


 87%|████████▋ | 349/400 [30:01<04:15,  5.01s/it]

Epoch: 349/400 | Train Loss: 1.3272289834276307e-05 | Time: 5.45s (cuda) | LR: 9.67140655691705e-08


 88%|████████▊ | 350/400 [30:06<04:01,  4.83s/it]

Epoch: 350/400 | Train Loss: 1.327127029071562e-05 | Time: 4.39s (cuda) | LR: 9.67140655691705e-08


 88%|████████▊ | 351/400 [30:11<03:58,  4.88s/it]

Epoch: 351/400 | Train Loss: 1.3267770555103198e-05 | Time: 4.98s (cuda) | LR: 9.67140655691705e-08


 88%|████████▊ | 352/400 [30:16<03:54,  4.88s/it]

Epoch: 352/400 | Train Loss: 1.3266570022096857e-05 | Time: 4.87s (cuda) | LR: 9.67140655691705e-08


 88%|████████▊ | 353/400 [30:20<03:42,  4.73s/it]

Epoch: 353/400 | Train Loss: 1.3262088941701222e-05 | Time: 4.39s (cuda) | LR: 9.67140655691705e-08


 88%|████████▊ | 354/400 [30:25<03:47,  4.94s/it]

Epoch: 354/400 | Train Loss: 1.3258615581435151e-05 | Time: 5.43s (cuda) | LR: 9.67140655691705e-08


 89%|████████▉ | 355/400 [30:30<03:35,  4.79s/it]

Epoch: 355/400 | Train Loss: 1.3257325008453336e-05 | Time: 4.44s (cuda) | LR: 9.67140655691705e-08


 89%|████████▉ | 356/400 [30:34<03:28,  4.73s/it]

Epoch: 356/400 | Train Loss: 1.3254212717583869e-05 | Time: 4.58s (cuda) | LR: 9.67140655691705e-08


 89%|████████▉ | 357/400 [30:40<03:31,  4.92s/it]

Epoch: 357/400 | Train Loss: 1.3253036740934476e-05 | Time: 5.35s (cuda) | LR: 9.67140655691705e-08


 90%|████████▉ | 358/400 [30:44<03:20,  4.77s/it]

Epoch: 358/400 | Train Loss: 1.3251232303446159e-05 | Time: 4.42s (cuda) | LR: 9.67140655691705e-08


 90%|████████▉ | 359/400 [30:50<03:24,  4.99s/it]

Epoch: 359/400 | Train Loss: 1.324923778156517e-05 | Time: 5.5s (cuda) | LR: 9.67140655691705e-08


 90%|█████████ | 360/400 [30:54<03:13,  4.83s/it]

Epoch: 360/400 | Train Loss: 1.3245991794974543e-05 | Time: 4.45s (cuda) | LR: 9.67140655691705e-08


 90%|█████████ | 361/400 [30:59<03:03,  4.70s/it]

Epoch: 361/400 | Train Loss: 1.3242398381407838e-05 | Time: 4.39s (cuda) | LR: 7.73712524553364e-08


 90%|█████████ | 362/400 [31:04<03:06,  4.92s/it]

Epoch: 362/400 | Train Loss: 1.3188227967475541e-05 | Time: 5.43s (cuda) | LR: 7.73712524553364e-08


 91%|█████████ | 363/400 [31:08<02:56,  4.78s/it]

Epoch: 363/400 | Train Loss: 1.3186877367843408e-05 | Time: 4.46s (cuda) | LR: 7.73712524553364e-08


 91%|█████████ | 364/400 [31:13<02:54,  4.85s/it]

Epoch: 364/400 | Train Loss: 1.3183587725507095e-05 | Time: 5.0s (cuda) | LR: 7.73712524553364e-08


 91%|█████████▏| 365/400 [31:18<02:50,  4.86s/it]

Epoch: 365/400 | Train Loss: 1.3180052519601304e-05 | Time: 4.89s (cuda) | LR: 7.73712524553364e-08


 92%|█████████▏| 366/400 [31:23<02:39,  4.70s/it]

Epoch: 366/400 | Train Loss: 1.3180902897147462e-05 | Time: 4.34s (cuda) | LR: 6.189700196426913e-08


 92%|█████████▏| 367/400 [31:28<02:41,  4.88s/it]

Epoch: 367/400 | Train Loss: 1.3134911569068208e-05 | Time: 5.3s (cuda) | LR: 6.189700196426913e-08


 92%|█████████▏| 368/400 [31:32<02:31,  4.74s/it]

Epoch: 368/400 | Train Loss: 1.313233224209398e-05 | Time: 4.39s (cuda) | LR: 6.189700196426913e-08


 92%|█████████▏| 369/400 [31:37<02:24,  4.65s/it]

Epoch: 369/400 | Train Loss: 1.3131270861777011e-05 | Time: 4.44s (cuda) | LR: 6.189700196426913e-08


 92%|█████████▎| 370/400 [31:42<02:25,  4.85s/it]

Epoch: 370/400 | Train Loss: 1.3128430509823374e-05 | Time: 5.33s (cuda) | LR: 6.189700196426913e-08


 93%|█████████▎| 371/400 [31:47<02:16,  4.71s/it]

Epoch: 371/400 | Train Loss: 1.3127450984029565e-05 | Time: 4.38s (cuda) | LR: 6.189700196426913e-08


 93%|█████████▎| 372/400 [31:52<02:18,  4.96s/it]

Epoch: 372/400 | Train Loss: 1.312567837885581e-05 | Time: 5.54s (cuda) | LR: 6.189700196426913e-08


 93%|█████████▎| 373/400 [31:57<02:09,  4.80s/it]

Epoch: 373/400 | Train Loss: 1.3125646546541248e-05 | Time: 4.41s (cuda) | LR: 6.189700196426913e-08


 94%|█████████▎| 374/400 [32:01<02:01,  4.67s/it]

Epoch: 374/400 | Train Loss: 1.3123932149028406e-05 | Time: 4.38s (cuda) | LR: 6.189700196426913e-08


 94%|█████████▍| 375/400 [32:06<02:01,  4.85s/it]

Epoch: 375/400 | Train Loss: 1.3121545634930953e-05 | Time: 5.26s (cuda) | LR: 6.189700196426913e-08


 94%|█████████▍| 376/400 [32:11<01:53,  4.72s/it]

Epoch: 376/400 | Train Loss: 1.3120425137458369e-05 | Time: 4.42s (cuda) | LR: 6.189700196426913e-08


 94%|█████████▍| 377/400 [32:16<01:50,  4.80s/it]

Epoch: 377/400 | Train Loss: 1.3119139111950062e-05 | Time: 5.0s (cuda) | LR: 6.189700196426913e-08


 94%|█████████▍| 378/400 [32:21<01:47,  4.87s/it]

Epoch: 378/400 | Train Loss: 1.3117432899889536e-05 | Time: 5.0s (cuda) | LR: 6.189700196426913e-08


 95%|█████████▍| 379/400 [32:25<01:39,  4.75s/it]

Epoch: 379/400 | Train Loss: 1.3115542969899252e-05 | Time: 4.49s (cuda) | LR: 6.189700196426913e-08


 95%|█████████▌| 380/400 [32:30<01:38,  4.93s/it]

Epoch: 380/400 | Train Loss: 1.31146953208372e-05 | Time: 5.32s (cuda) | LR: 6.189700196426913e-08


 95%|█████████▌| 381/400 [32:35<01:30,  4.77s/it]

Epoch: 381/400 | Train Loss: 1.3112425222061574e-05 | Time: 4.42s (cuda) | LR: 4.9517601571415304e-08


 96%|█████████▌| 382/400 [32:39<01:24,  4.67s/it]

Epoch: 382/400 | Train Loss: 1.3077489711577073e-05 | Time: 4.44s (cuda) | LR: 4.9517601571415304e-08


 96%|█████████▌| 383/400 [32:45<01:22,  4.87s/it]

Epoch: 383/400 | Train Loss: 1.307413367612753e-05 | Time: 5.33s (cuda) | LR: 4.9517601571415304e-08


 96%|█████████▌| 384/400 [32:49<01:16,  4.79s/it]

Epoch: 384/400 | Train Loss: 1.3072325600660406e-05 | Time: 4.59s (cuda) | LR: 4.9517601571415304e-08


 96%|█████████▋| 385/400 [32:55<01:14,  4.97s/it]

Epoch: 385/400 | Train Loss: 1.3072407455183566e-05 | Time: 5.39s (cuda) | LR: 3.9614081257132246e-08


 96%|█████████▋| 386/400 [32:59<01:07,  4.82s/it]

Epoch: 386/400 | Train Loss: 1.3041919373790734e-05 | Time: 4.46s (cuda) | LR: 3.9614081257132246e-08


 97%|█████████▋| 387/400 [33:03<01:00,  4.68s/it]

Epoch: 387/400 | Train Loss: 1.3041277270531282e-05 | Time: 4.35s (cuda) | LR: 3.9614081257132246e-08


 97%|█████████▋| 388/400 [33:09<00:58,  4.90s/it]

Epoch: 388/400 | Train Loss: 1.3040287740295753e-05 | Time: 5.41s (cuda) | LR: 3.9614081257132246e-08


 97%|█████████▋| 389/400 [33:13<00:52,  4.74s/it]

Epoch: 389/400 | Train Loss: 1.3038544238952454e-05 | Time: 4.38s (cuda) | LR: 3.9614081257132246e-08


 98%|█████████▊| 390/400 [33:18<00:47,  4.77s/it]

Epoch: 390/400 | Train Loss: 1.3038204087933991e-05 | Time: 4.82s (cuda) | LR: 3.9614081257132246e-08


 98%|█████████▊| 391/400 [33:23<00:43,  4.87s/it]

Epoch: 391/400 | Train Loss: 1.303718818235211e-05 | Time: 5.11s (cuda) | LR: 3.9614081257132246e-08


 98%|█████████▊| 392/400 [33:28<00:37,  4.72s/it]

Epoch: 392/400 | Train Loss: 1.3036416930845007e-05 | Time: 4.36s (cuda) | LR: 3.9614081257132246e-08


 98%|█████████▊| 393/400 [33:33<00:34,  4.93s/it]

Epoch: 393/400 | Train Loss: 1.303506905969698e-05 | Time: 5.42s (cuda) | LR: 3.9614081257132246e-08


 98%|█████████▊| 394/400 [33:37<00:28,  4.77s/it]

Epoch: 394/400 | Train Loss: 1.3033815775997937e-05 | Time: 4.41s (cuda) | LR: 3.9614081257132246e-08


 99%|█████████▉| 395/400 [33:42<00:23,  4.68s/it]

Epoch: 395/400 | Train Loss: 1.3033431059739087e-05 | Time: 4.45s (cuda) | LR: 3.9614081257132246e-08


 99%|█████████▉| 396/400 [33:47<00:19,  4.91s/it]

Epoch: 396/400 | Train Loss: 1.3032112292421516e-05 | Time: 5.45s (cuda) | LR: 3.9614081257132246e-08


 99%|█████████▉| 397/400 [33:52<00:14,  4.81s/it]

Epoch: 397/400 | Train Loss: 1.3031471098656766e-05 | Time: 4.58s (cuda) | LR: 3.9614081257132246e-08


100%|█████████▉| 398/400 [33:57<00:09,  5.00s/it]

Epoch: 398/400 | Train Loss: 1.3030092304688878e-05 | Time: 5.44s (cuda) | LR: 3.9614081257132246e-08


100%|█████████▉| 399/400 [34:02<00:04,  4.84s/it]

Epoch: 399/400 | Train Loss: 1.3028922694502398e-05 | Time: 4.47s (cuda) | LR: 3.9614081257132246e-08


100%|██████████| 400/400 [34:06<00:00,  5.12s/it]

Epoch: 400/400 | Train Loss: 1.3027721252001356e-05 | Time: 4.5s (cuda) | LR: 3.9614081257132246e-08

Epoch with Least Loss: 400 | Loss: 1.3027721e-05 






In [34]:
# Initialize prediction lists
prediction_list = [[] for _ in range(total_vars)]

# Inference loop
with torch.no_grad():
    for i in range(0, torch_coords.shape[0], group_size):
        coords = torch_coords[i:min(i + group_size, torch_coords.shape[0])].type(torch.float32).to(device)
        vals = model(coords)
        vals = vals.to('cpu')

        for j in range(total_vars):
            prediction_list[j].append(vals[:, j])

# Extract and concatenate predictions
extracted_list = [[] for _ in range(total_vars)]
for i in range(len(prediction_list[0])):
    for j in range(total_vars):
        el = prediction_list[j][i].detach().numpy()
        extracted_list[j].append(el)

for j in range(total_vars):
    extracted_list[j] = np.concatenate(extracted_list[j], dtype='float32')

# Final prediction (normalized)
n_predictions = np.array(extracted_list).T

# Compute PSNR
findMultiVariatePSNR(var_name, total_vars, n_val, n_predictions)

# Compute RMSE
rmse = compute_rmse(n_val, n_predictions)
print("RMSE:", rmse)

Average  PSNR: 53.970339193952
Standard_Deviation  PSNR: 56.11223231085033

Average psnr :  55.041285752401166
RMSE: 0.003593371349644404


In [22]:
# !rm -rf /kaggle/working/*

In [37]:
print(os.path.getsize('/content/models/train_3d_data_200ep_6rb_150n_512bs_5e-05lr_Truedecay_0.8dr_decayingAtInterval20.pth') / (1024 ** 2), 'MB')

1.0534076690673828 MB


In [40]:
# vti saving path
vti_path = args.vti_path
if not os.path.exists(vti_path):
    os.makedirs(vti_path)
# vti name
isMaskPresent=False
mask_arr=[]
vti_name = "teardrop_150_week_8_predicted.vti"
vti_path= ""
makeVTI(data, real_data, n_predictions, n_pts, total_vars, var_name, dim, isMaskPresent, mask_arr, vti_path, vti_name)

Vti File written successfully at teardrop_150_week_8_predicted.vti
