In [1]:
!pip install vtk



In [1]:
import numpy as np
import torch
from torch import nn, optim
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
import seaborn as sns
import matplotlib.pyplot as plt
import vtk
from vtk import *
from vtk.util.numpy_support import vtk_to_numpy
import random
import os
import sys
import time

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device running:', device)

Device running: cuda


In [3]:
class SineLayer(nn.Module):
    """Linear layer followed by a scaled sine activation: ``sin(omega_0 * (W x + b))``.

    Weight initialisation (biases keep ``nn.Linear``'s default):
      * first layer (``is_first=True``): U(-1/in_features, 1/in_features)
      * other layers: U(-sqrt(6/in_features)/omega_0, sqrt(6/in_features)/omega_0)

    Args:
        in_features: input width.
        out_features: output width.
        bias: whether the linear layer has a bias.
        is_first: selects the first-layer initialisation range.
        omega_0: frequency factor applied before the sine.
    """

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights from the range chosen by ``is_first``."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        """Return ``sin(omega_0 * linear(x))``."""
        pre_activation = self.linear(x)
        return torch.sin(self.omega_0 * pre_activation)

In [4]:
class ResidualSineLayer(nn.Module):
    """Two sine-activated linear layers of equal width wrapped in a skip connection.

    ``forward(x) = weight_2 * (x + sin(omega_0 * L2(sin(omega_0 * L1(weight_1 * x)))))``
    where ``weight_1`` is 0.5 when ``ave_first`` is set (else 1) and ``weight_2``
    is 0.5 when ``ave_second`` is set (else 1).

    Both linear layers map ``features -> features``; their weights are drawn from
    U(-sqrt(6/features)/omega_0, sqrt(6/features)/omega_0).
    """

    def __init__(self, features, bias=True, ave_first=False, ave_second=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.features = features
        self.linear_1 = nn.Linear(features, features, bias=bias)
        self.linear_2 = nn.Linear(features, features, bias=bias)
        self.weight_1 = .5 if ave_first else 1
        self.weight_2 = .5 if ave_second else 1
        self.init_weights()

    def init_weights(self):
        """Re-draw both linear weight matrices (linear_1 first, then linear_2)."""
        bound = np.sqrt(6 / self.features) / self.omega_0
        with torch.no_grad():
            for layer in (self.linear_1, self.linear_2):
                layer.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Apply the two sine layers and add the (optionally averaged) skip path."""
        scaled_input = self.weight_1 * input
        hidden = torch.sin(self.omega_0 * self.linear_1(scaled_input))
        branch = torch.sin(self.omega_0 * self.linear_2(hidden))
        return self.weight_2 * (input + branch)

In [5]:
class MyResidualSirenNet(nn.Module):
    """Residual SIREN: SineLayer -> (n_layers - 2) ResidualSineLayers -> Linear.

    ``net_layers`` holds exactly ``obj['n_layers']`` modules (for n_layers >= 2):
    a first SineLayer, residual sine blocks, and a final linear layer. With
    n_layers == 1 it holds only the final linear layer.

    ``obj`` keys:
      * 'n_layers'   - number of modules in ``net_layers``
      * 'dim'        - input dimensionality (number of coordinates)
      * 'total_vars' - number of outputs
      * 'n_neurons'  - width of every hidden layer

    Residual blocks after the first one halve their input (``ave_first``), and
    the last residual block halves its output (``ave_second``).
    """

    def __init__(self, obj):
        super().__init__()
        self.Omega_0 = 30
        self.n_layers = obj['n_layers']
        self.input_dim = obj['dim']
        self.output_dim = obj['total_vars']
        self.neurons_per_layer = obj['n_neurons']
        # Activation widths between consecutive modules: in, hidden..., out.
        self.layers = ([self.input_dim]
                       + [self.neurons_per_layer] * (self.n_layers - 1)
                       + [self.output_dim])

        self.net_layers = nn.ModuleList()
        last_idx = self.n_layers - 1
        for idx in range(self.n_layers):
            width_in, width_out = self.layers[idx], self.layers[idx + 1]
            if idx == last_idx:
                # Output head: plain linear layer with SIREN-style weight range.
                head = nn.Linear(width_in, width_out)
                bound = np.sqrt(6 / width_in) / self.Omega_0
                with torch.no_grad():
                    head.weight.uniform_(-bound, bound)
                self.net_layers.append(head)
            elif idx == 0:
                self.net_layers.append(SineLayer(width_in, width_out, bias=True, is_first=True))
            else:
                self.net_layers.append(
                    ResidualSineLayer(width_in, bias=True,
                                      ave_first=idx > 1,
                                      ave_second=idx == (self.n_layers - 2))
                )

    def forward(self, x):
        """Run ``x`` through every module in ``net_layers`` in order."""
        out = x
        for module in self.net_layers:
            out = module(out)
        return out

In [6]:
def size_of_network(n_layers, n_neurons, d_in, d_out, is_residual = True):
    """Count the trainable parameters of a (residual) sine MLP.

    The network has ``n_layers`` hidden widths of ``n_neurons`` between an input of
    width ``d_in`` and an output of width ``d_out``, i.e. ``n_layers + 1`` stages.
    The first and last stages are single linear layers (weights + bias). Each
    intermediate stage counts as:
      * residual: two linear layers through a hidden width equal to
        max(stage input width, stage output width), each with weights + bias;
      * non-residual: one linear layer (weights + bias).

    NOTE(review): MyResidualSirenNet built with obj['n_layers'] = L has L - 2
    residual blocks, which corresponds to ``size_of_network(L - 1, ...)``.

    Returns:
        int: total number of parameters.
    """
    widths = [d_in] + [n_neurons] * n_layers + [d_out]
    n_stages = len(widths) - 1

    total = 0
    for stage, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:])):
        is_boundary = stage == 0 or stage == n_stages - 1
        if is_boundary or not is_residual:
            total += (w_in + 1) * w_out
        else:
            hidden = max(w_in, w_out)
            total += w_in * hidden + hidden
            total += hidden * w_out + w_out
    return total

In [7]:
def compute_PSNR(arrgt, arr_recon):
    """Peak signal-to-noise ratio (dB) of a reconstruction against ground truth.

    The peak is the value range ``max(arrgt) - min(arrgt)``; the noise is the
    mean squared difference between ``arrgt`` and ``arr_recon``. Identical
    inputs give an infinite PSNR (numpy division by zero).
    """
    peak = np.max(arrgt) - np.min(arrgt)
    mse = np.mean((arrgt - arr_recon) ** 2)
    return 10 * np.log10(peak ** 2 / mse)

In [8]:
def srs(numOfPoints, valid_pts, percentage, isMaskPresent, mask_array):
    """Simple random sampling of distinct point indices.

    Draws ``int(valid_pts / 100 * percentage)`` distinct indices from
    ``range(numOfPoints)``. When ``isMaskPresent`` is true, indices whose
    ``mask_array`` entry is 0 are never chosen.

    Args:
        numOfPoints: size of the index space.
        valid_pts: number of points the percentage refers to.
        percentage: share of ``valid_pts`` to sample (0-100).
        isMaskPresent: whether ``mask_array`` restricts eligible indices.
        mask_array: per-index mask, read only when ``isMaskPresent`` is true.

    Returns:
        set[int]: the sampled indices (empty when the requested count is <= 0).

    Raises:
        ValueError: if more points are requested than are eligible. The previous
            rejection-sampling loop never terminated in that case.
    """
    numberOfSampledPoints = int((valid_pts / 100) * percentage)
    if numberOfSampledPoints <= 0:
        return set()

    if isMaskPresent:
        candidates = [idx for idx in range(numOfPoints) if mask_array[idx] != 0]
    else:
        candidates = range(numOfPoints)

    if numberOfSampledPoints > len(candidates):
        raise ValueError(
            f'Requested {numberOfSampledPoints} samples but only '
            f'{len(candidates)} eligible points are available'
        )

    # One pass without replacement instead of retrying random draws until enough
    # distinct indices have been seen (which becomes very slow near 100%).
    return set(random.sample(candidates, numberOfSampledPoints))

In [9]:
def findMultiVariatePSNR(var_name, total_vars, actual, pred):
    """Print and return the PSNR of each variable and their average.

    Scores whatever arrays are passed in (e.g. the final model's predictions);
    it does not track PSNR across training epochs.

    Args:
        var_name: sequence of variable names, one per column. A single string is
            also accepted and used as the label for every column.
        total_vars: number of leading columns of ``actual``/``pred`` to score.
        actual: 2-D array of ground-truth values, one column per variable.
        pred: 2-D array of predicted values with the same layout as ``actual``.

    Returns:
        tuple: (list of per-variable PSNRs, average PSNR).
    """
    psnr_list = []
    for j in range(total_vars):
        # Previously the whole var_name list was printed on every line, so the
        # per-variable PSNRs could not be told apart.
        label = var_name if isinstance(var_name, str) else var_name[j]
        psnr = compute_PSNR(actual[:, j], pred[:, j])
        psnr_list.append(psnr)
        print(label, ' PSNR:', psnr)

    avg_psnr = sum(psnr_list) / total_vars
    print('\nAverage psnr : ', avg_psnr)
    return psnr_list, avg_psnr

In [10]:
def compute_rmse(actual, predicted):
    """Root-mean-square error between ``actual`` and ``predicted``."""
    residual = actual - predicted
    return np.sqrt(np.mean(np.square(residual)))

def denormalizeValue(total_vars, to, ref):
    """Map columns normalised to [-1, 1] back to the value range of ``ref``.

    Column ``c`` (for ``c < total_vars``) of ``to`` is rescaled to
    ``[min(ref[:, c]), max(ref[:, c])]``. Any remaining columns are copied
    unchanged. ``to`` itself is not modified; a new array is returned.
    """
    result = np.array(to)
    for col in range(total_vars):
        lo = np.min(ref[:, col])
        hi = np.max(ref[:, col])
        unit = (to[:, col] * 0.5) + 0.5  # [-1, 1] -> [0, 1]
        result[:, col] = (unit * (hi - lo)) + lo
    return result

In [11]:
def makeVTI(data, val, n_predictions, n_pts, total_vars, var_name, dim, isMaskPresent, mask_arr, vti_path, vti_name, normalizedVersion = False):
    """Write network predictions to a .vti file on the grid of ``data``.

    Args:
        data: vtkImageData providing the grid (its structure is copied via
            ``CopyStructure`` when no mask is used).
        val: reference values; their per-column min/max are used to de-normalise
            ``n_predictions`` (via ``denormalizeValue``).
        n_predictions: predictions, one column per variable. They are
            de-normalised first unless ``normalizedVersion`` is True, in which
            case they are written as given.
        n_pts: number of grid points to write.
        total_vars: number of prediction columns to write.
        var_name: variable names indexed by column.
        dim: unused; kept for call compatibility.
        isMaskPresent: when True, predictions are consumed in order only at points
            with ``mask_arr[j] == 1``; all other points get 0.0.
        mask_arr: per-point mask, read only when ``isMaskPresent`` is True.
        vti_path: output directory.
        vti_name: output file name.
        normalizedVersion: skip de-normalisation when True.

    Side effects:
        Without a mask, a new vtkImageData that holds only the predicted arrays is
        written. With a mask, arrays named ``'p_' + name`` are added to ``data``
        itself (mutating it), and ``data`` is written, so the file also keeps the
        arrays already present on ``data``.

    Raises:
        OSError: if the VTK writer reports a failure.
    """
    nn_predictions = n_predictions if normalizedVersion else denormalizeValue(total_vars, n_predictions, val)
    # os.path.join also copes with a vti_path that lacks a trailing separator.
    out_file = os.path.join(vti_path, vti_name)

    if isMaskPresent:
        target = data
        prefix = 'p_'
    else:
        target = vtkImageData()
        target.CopyStructure(data)
        prefix = ''

    for i in range(total_vars):
        column = nn_predictions[:, i]
        arr = vtkFloatArray()
        if isMaskPresent:
            idx = 0
            for j in range(n_pts):
                if mask_arr[j] == 1:
                    arr.InsertNextValue(column[idx])
                    idx += 1
                else:
                    arr.InsertNextValue(0.0)
        else:
            for j in range(n_pts):
                arr.InsertNextValue(column[j])
        arr.SetName(prefix + var_name[i])
        target.GetPointData().AddArray(arr)

    writer = vtkXMLImageDataWriter()
    writer.SetFileName(out_file)
    writer.SetInputData(target)
    # Write() reports success as 1 and failure as 0. The old code announced
    # success unconditionally.
    if not writer.Write():
        raise OSError(f'Failed to write Vti file at {out_file}')
    print(f'Vti File written successfully at {out_file}')

In [12]:
def getImageData(actual_img, val, n_pts, var_name, isMaskPresent, mask_arr):
    """Build a vtkImageData on ``actual_img``'s grid holding one scalar array.

    The array is named ``var_name`` and installed as the point-data scalars.
    Without a mask, point ``j`` receives ``val[j]``. With a mask, values from
    ``val`` are consumed in order only at points where ``mask_arr[j] == 1``; all
    other points receive 0.0.
    """
    img = vtkImageData()
    img.CopyStructure(actual_img)

    scalars = vtkFloatArray()
    if isMaskPresent:
        next_value = 0
        for point_id in range(n_pts):
            if mask_arr[point_id] == 1:
                scalars.InsertNextValue(val[next_value])
                next_value += 1
            else:
                scalars.InsertNextValue(0.0)
    else:
        for point_id in range(n_pts):
            scalars.InsertNextValue(val[point_id])

    scalars.SetName(var_name)
    img.GetPointData().SetScalars(scalars)
    return img

In [26]:
from argparse import Namespace

# Run configuration. A Namespace stands in for argparse so the same names work
# whether this runs as a script or inside the notebook.
args = Namespace(
    n_neurons=150,    # width of every hidden layer
    n_layers=6,       # number of residual sine blocks (first and output layers are added below)
    epochs=200,       # number of training epochs
    batchsize=512,
    lr=0.00005,
    no_decay=False,   # True disables learning-rate decay
    decay_rate=0.8,   # multiplicative learning-rate factor applied at each decay
    decay_at_interval=True,
    decay_interval=15,  # epochs between interval-based decays
    datapath='/content/Teardrop_Gaussian.vti',  # input volume (Colab path) -- adjust for your environment
    outpath='./models/',          # where checkpoints are written
    exp_path='../logs/',
    modified_data_path='./data/',
    dataset_name='3d_data',       # used in log and checkpoint names
    vti_name='predicted.vti',     # file name of the reconstructed .vti output
    vti_path='./data/',
    seed=42,                      # RNG seed: weight init, DataLoader shuffling, point sampling
)

print(args, end='\n\n')

# Assigning parameters to variables
LR = args.lr
BATCH_SIZE = args.batchsize
decay_rate = args.decay_rate
decay_at_equal_interval = args.decay_at_interval

decay = not args.no_decay
MAX_EPOCH = args.epochs

n_neurons = args.n_neurons
# Total modules in the network: first sine layer + residual blocks + output layer.
n_layers = args.n_layers + 2
decay_interval = args.decay_interval
outpath = args.outpath
exp_path = args.exp_path
datapath = args.datapath
modified_data_path = args.modified_data_path
dataset_name = args.dataset_name
vti_name = args.vti_name
vti_path = args.vti_path

# Seed every RNG the notebook uses so that runs are reproducible.
SEED = args.seed
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Displaying the final configuration
print(f"Learning Rate: {LR}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Decay Rate: {decay_rate}")
print(f"Max Epochs: {MAX_EPOCH}")
print(f"Number of Neurons per Layer: {n_neurons}")
print(f"Number of Layers (including input/output): {n_layers}")
print(f"Data Path: {datapath}")
print(f"Output Path: {outpath}")
print(f"Dataset Name: {dataset_name}")
print(f"Vti Name: {vti_name}")

Namespace(n_neurons=150, n_layers=6, epochs=200, batchsize=512, lr=5e-05, no_decay=False, decay_rate=0.8, decay_at_interval=True, decay_interval=15, datapath='/content/Teardrop_Gaussian.vti', outpath='./models/', exp_path='../logs/', modified_data_path='./data/', dataset_name='3d_data', vti_name='predicted.vti', vti_path='./data/')

Learning Rate: 5e-05
Batch Size: 512
Decay Rate: 0.8
Max Epochs: 200
Number of Neurons per Layer: 150
Number of Layers (including input/output): 8
Data Path: /content/Teardrop_Gaussian.vti
Output Path: ./models/
Dataset Name: 3d_data
Vti Name: predicted.vti


In [14]:
# Dataset-level variables; the data-loading cell fills these in.
var_name = []
total_vars = None   # number of variables in the dataset
univariate = None   # True when the dataset holds a single variable
group_size = 5000   # number of points per chunk during inference

# Build the log-file name from the run's hyper-parameters, joined by underscores.
log_file = "_".join([
    "train",
    f"{dataset_name}",
    f"{n_layers-2}rb",
    f"{n_neurons}n",
    f"{BATCH_SIZE}bs",
    f"{LR}lr",
    f"{decay}decay",
    f"{decay_rate}dr",
    ("decayingAtInterval" + str(decay_interval)) if decay_at_equal_interval else "decayingWhenLossIncr",
])

print(log_file)

train_3d_data_6rb_150n_512bs_5e-05lr_Truedecay_0.8dr_decayingAtInterval15


In [15]:
# Geometry of the dataset (number of points, dimensionality, grid dimensions);
# the loading cell below assigns the real values.
n_pts, n_dim, dim = None, None, None

print("Decay:", decay)
print(f'Extracting variables from path: {datapath}', end="\n\n")

# Containers for the raw per-variable arrays and the loaded image.
data_array = []
scalar_data = None

Decay: True
Extracting variables from path: /content/Teardrop_Gaussian.vti



In [16]:
# Read the .vti volume and flatten every point-data array into per-variable columns.
# NOTE: an earlier version of this cell skipped the 'vtkValidPointMask' and 'Swirl'
# arrays; every array is kept here.
reader = vtk.vtkXMLImageDataReader()
reader.SetFileName(datapath)
reader.Update()

data = reader.GetOutput()
scalar_data = data
pdata = data.GetPointData()
n_pts = data.GetNumberOfPoints()
dim = data.GetDimensions()
n_dim = len(dim)
total_arr = pdata.GetNumberOfArrays()

print("n_pts:", n_pts, "dim:", dim, "n_dim:", n_dim, "total_arr:", total_arr)

var_name = []
data_array = []

# One column per scalar array; multi-component arrays contribute one column per component.
for i in range(total_arr):
    a_name = pdata.GetArrayName(i)
    cur_arr = pdata.GetArray(a_name)
    n_components = cur_arr.GetNumberOfComponents()

    if n_components == 1:
        var_name.append(a_name)
        data_array.append(vtk_to_numpy(cur_arr))
    else:
        # x/y/z suffixes for up to 3 components, numeric suffixes beyond that. The old
        # ['x', 'y', 'z'][:n] slice named only 3 of e.g. 9 tensor components, so
        # var_name and data_array went out of sync and filling `val` failed.
        suffixes = ['x', 'y', 'z'] if n_components <= 3 else [str(c) for c in range(n_components)]
        var_name.extend(f"{a_name}_{s}" for s in suffixes[:n_components])
        components = vtk_to_numpy(cur_arr).reshape(n_pts, n_components)
        for c in range(n_components):
            data_array.append(components[:, c])

total_vars = len(var_name)
univariate = total_vars == 1

# Point coordinates as reported by vtkImageData.GetPoint, one row per point.
cord = np.array([scalar_data.GetPoint(i) for i in range(n_pts)], dtype=np.float64).reshape(n_pts, n_dim)
# Variable values as float64, one column per variable (vectorised instead of a per-point loop).
val = np.column_stack(data_array).astype(np.float64) if data_array else np.zeros((n_pts, 0))

# Display final information
print("Total Variables:", total_vars)
print("Univariate:", univariate)
print("Coordinates Shape:", cord.shape)
print("Values Shape:", val.shape)

n_pts: 262144 dim: (64, 64, 64) n_dim: 3 total_arr: 2
Total Variables: 2
Univariate: False
Coordinates Shape: (262144, 3)
Values Shape: (262144, 2)


In [17]:
# Keep an un-normalised copy of the values (needed to map predictions back later).
real_data = val.copy()
print(np.max(real_data))

# Min-max normalise every variable to [-1, 1] in place.
# A constant column (max == min) used to divide by zero; it is mapped to 0 instead.
for i in range(total_vars):
    min_data = np.min(val[:, i])
    max_data = np.max(val[:, i])
    span = max_data - min_data
    if span > 0:
        val[:, i] = 2.0 * ((val[:, i] - min_data) / span - 0.5)
    else:
        val[:, i] = 0.0

# Min-max normalise coordinates per axis to [-1, 1] in place.
# For a grid with origin 0 and unit spacing this equals the previous
# c / (dim - 1) mapping. It also handles non-zero origins and non-unit spacing,
# and re-running the cell leaves already-normalised coordinates (up to rounding)
# unchanged instead of shrinking them again.
for i in range(n_dim):
    c_min = np.min(cord[:, i])
    c_max = np.max(cord[:, i])
    c_span = c_max - c_min
    if c_span > 0:
        cord[:, i] = 2.0 * ((cord[:, i] - c_min) / c_span - 0.5)
    else:
        cord[:, i] = 0.0  # flat axis (dim[i] == 1): the old formula divided by zero

n_cord = cord.copy()
n_val = val.copy()

# Column 0 holds the means and column 1 the standard deviations
# (var_name is ['Average', 'Standard_Deviation'] for the Teardrop dataset).
# NOTE(review): confirm this column order for other datasets.
means = n_val[:, 0]
stds = n_val[:, 1]

# Convert normalized data to PyTorch tensors
torch_coords = torch.from_numpy(n_cord)
torch_means = torch.from_numpy(means)
torch_stds = torch.from_numpy(stds)

# Display dataset details
print('Dataset Name:', dataset_name)
print('Total Variables:', total_vars)
print('Variables Name:', var_name, end="\n\n")
print('Total Points in Data:', n_pts)
print('Dimension of the Dataset:', dim)
print('Number of Dimensions:', n_dim)
print('Coordinate Tensor Shape:', torch_coords.shape)
print('Scalar means Values Tensor Shape:', torch_means.shape)
print('Scalar stds Values Tensor Shape:', torch_stds.shape)

print('\n###### Data setup is complete, now starting training ######\n')

161.9956817626953
Dataset Name: 3d_data
Total Variables: 2
Variables Name: ['Average', 'Standard_Deviation']

Total Points in Data: 262144
Dimension of the Dataset: (64, 64, 64)
Number of Dimensions: 3
Coordinate Tensor Shape: torch.Size([262144, 3])
Scalar means Values Tensor Shape: torch.Size([262144])
Scalar stds Values Tensor Shape: torch.Size([262144])

###### Data setup is complete, now starting training ######



In [18]:
# Shuffled mini-batches of (normalised coordinate, normalised mean) pairs.
train_dataloader_mean = DataLoader(
    TensorDataset(torch_coords, torch_means),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Network configuration: coordinates in, a single scalar (the mean) out.
obj = dict(total_vars=1, dim=n_dim, n_neurons=n_neurons, n_layers=n_layers)

# Model, optimiser and loss for the mean field.
model_mean = MyResidualSirenNet(obj).to(device)
print(model_mean)

optimizer = optim.Adam(model_mean.parameters(), lr=LR, betas=(0.9, 0.999))
print(optimizer)

criterion = nn.MSELoss()
print(criterion)

# Training configuration summary
print('\nLearning Rate:', LR)
print('Max Epochs:', MAX_EPOCH)
print('Batch Size:', BATCH_SIZE)
print('Number of Hidden Layers:', obj['n_layers'] - 2)
print('Number of Neurons per Layer:', obj['n_neurons'])

if not decay:
    print('No decay!')
elif decay_at_equal_interval:
    print('Decay Rate:', decay_rate)
    print(f'Rate decays every {decay_interval} epochs.')
else:
    print('Decay Rate:', decay_rate)
    print('Rate decays when the current epoch loss is greater than the previous epoch loss.')
print()



MyResidualSirenNet(
  (net_layers): ModuleList(
    (0): SineLayer(
      (linear): Linear(in_features=3, out_features=150, bias=True)
    )
    (1-6): 6 x ResidualSineLayer(
      (linear_1): Linear(in_features=150, out_features=150, bias=True)
      (linear_2): Linear(in_features=150, out_features=150, bias=True)
    )
    (7): Linear(in_features=150, out_features=1, bias=True)
  )
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 5e-05
    maximize: False
    weight_decay: 0
)
MSELoss()

Learning Rate: 5e-05
Max Epochs: 200
Batch Size: 512
Number of Hidden Layers: 6
Number of Neurons per Layer: 150
Decay Rate: 0.8
Rate decays every 15 epochs.



In [48]:
from tqdm.auto import tqdm

train_loss_list = []
best_epoch = -1
best_loss = 1e8
best_model = ""  # becomes a detached copy of the best epoch's weights

# Ensure the output path exists
if not os.path.exists(outpath):
    os.makedirs(outpath)

# Training loop for the mean-field network
for epoch in tqdm(range(MAX_EPOCH)):
    model_mean.train()
    temp_loss_list = []
    start = time.time()

    # Batch-by-batch training. Targets are already 1-D (one scalar per point).
    for X_train, y_train in train_dataloader_mean:
        X_train = X_train.type(torch.float32).to(device)
        y_train = y_train.type(torch.float32).to(device)

        optimizer.zero_grad()
        # The network has one output unit: drop only that trailing axis so a batch
        # of size 1 keeps its batch dimension and still lines up with y_train.
        predictions = model_mean(X_train).squeeze(-1)
        loss = criterion(predictions, y_train)
        loss.backward()
        optimizer.step()

        temp_loss_list.append(loss.item())

    # Calculate epoch loss
    epoch_loss = np.average(temp_loss_list)

    # Learning-rate decay. With decay enabled, the interval rule fires when
    # decay_at_equal_interval is set, and the loss-increase rule fires on every
    # epoch whose loss went up.
    # NOTE(review): the config printout only mentions the interval rule; confirm
    # this combination is intended.
    if decay:
        if decay_at_equal_interval and epoch >= decay_interval and epoch % decay_interval == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= decay_rate
        if epoch > 0 and epoch_loss > train_loss_list[-1]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= decay_rate

    # Track losses and best model
    train_loss_list.append(epoch_loss)
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_epoch = epoch + 1
        # state_dict() references the live parameter tensors, which keep changing
        # as training continues. Clone them so best_model really holds this epoch's
        # weights (previously the "best" checkpoint contained the final weights).
        best_model = {k: v.detach().clone() for k, v in model_mean.state_dict().items()}

    end = time.time()
    # tqdm.write keeps the progress bar intact instead of interleaving with it.
    tqdm.write(
        f"Epoch: {epoch + 1}/{MAX_EPOCH} | Train Loss: {train_loss_list[-1]} | "
        f"Time: {round(end - start, 2)}s ({device}) | LR: {optimizer.param_groups[0]['lr']}"
    )

    # Save model at intervals
    if (epoch + 1) % 50 == 0:
        model_name = (
            f'train_{dataset_name}_{epoch + 1}ep_{n_layers - 2}rb_{n_neurons}n_'
            f'{BATCH_SIZE}bs_{LR}lr_{decay}decay_{decay_rate}dr_'
            f'{"decayingAtInterval" + str(decay_interval)+"mean" if decay_at_equal_interval else "decayingWhenLossIncr"}'
        )
        torch.save(
            {"epoch": epoch + 1, "model_state_dict": model_mean.state_dict()},
            os.path.join(outpath, f'{model_name}_mean.pth')
        )

# Final summary
print('\nEpoch with Least Loss:', best_epoch, '| Loss:', best_loss, '\n')

# Save the final model and the best-epoch snapshot
model_name = 'siren_compressor'
torch.save(
    {"epoch": MAX_EPOCH, "model_state_dict": model_mean.state_dict()},
    os.path.join(outpath, f'{model_name}_mean.pth')
)
torch.save(
    {"epoch": best_epoch, "model_state_dict": best_model},
    os.path.join(outpath, f'{best_epoch}_mean.pth')
)


  0%|          | 1/200 [00:05<18:37,  5.61s/it]

Epoch: 1/200 | Train Loss: 0.007658324204385281 | Time: 5.61s (cuda) | LR: 5e-05


  1%|          | 2/200 [00:12<21:03,  6.38s/it]

Epoch: 2/200 | Train Loss: 0.003164463210850954 | Time: 6.92s (cuda) | LR: 5e-05


  2%|▏         | 3/200 [00:18<20:03,  6.11s/it]

Epoch: 3/200 | Train Loss: 0.0028839341830462217 | Time: 5.78s (cuda) | LR: 5e-05


  2%|▏         | 4/200 [00:25<21:20,  6.54s/it]

Epoch: 4/200 | Train Loss: 0.00266069732606411 | Time: 7.19s (cuda) | LR: 5e-05


  2%|▏         | 4/200 [00:29<23:55,  7.32s/it]


KeyboardInterrupt: 

In [19]:
# Reconstruct the normalised mean field from a saved checkpoint and score it.
# NOTE: this rebinds the global total_vars (2 for this dataset) to the single
# output of the mean network; later cells may rely on that.
total_vars = 1

# Presumably the best-epoch checkpoint ('{best_epoch}_mean.pth') from a run whose
# best epoch was 200. TODO: confirm before re-running.
mean_ckpt_path = os.path.join(outpath, '200_mean.pth')

model_mean = MyResidualSirenNet(obj).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load(mean_ckpt_path, map_location=device, weights_only=True)
model_mean.load_state_dict(checkpoint['model_state_dict'])
model_mean.eval()

# Inference in chunks of group_size points.
chunks = []
with torch.no_grad():
    for i in range(0, torch_coords.shape[0], group_size):
        coords = torch_coords[i:i + group_size].type(torch.float32).to(device)
        chunks.append(model_mean(coords).cpu().numpy())

# Final prediction (normalized), shape (n_pts, 1)
n_predictions_means = np.concatenate(chunks, axis=0).astype('float32')

# Compute PSNR and RMSE against the normalised ground-truth means
print("mean", compute_PSNR(n_val[:, 0], n_predictions_means[:, 0]))
rmse = compute_rmse(n_val[:, 0], n_predictions_means[:, 0])
print("RMSE:", rmse)

mean 72.7810994165529
RMSE: 0.00045917160629419825


In [24]:
# Shuffled mini-batches of (normalised coordinate, normalised std-dev) pairs.
train_dataloader_std = DataLoader(
    TensorDataset(torch_coords, torch_stds),
    batch_size=BATCH_SIZE,
    pin_memory=True,
    shuffle=True,
    num_workers=4
)
# Model configuration: coordinates in, a single scalar (the std-dev) out.
obj = {
    'total_vars': 1,
    'dim': n_dim,
    'n_neurons': n_neurons,
    'n_layers': n_layers
}

# Initialize the std-dev model, its optimizer, and loss function.
model_std = MyResidualSirenNet(obj).to(device)
print(model_std)

# The optimizer must own model_std's parameters. It previously wrapped
# model_mean.parameters(), so stepping it could never update model_std.
optimizer_std = optim.Adam(model_std.parameters(), lr=LR, betas=(0.9, 0.999))
print(optimizer_std)

criterion_std = nn.MSELoss()
print(criterion_std)

# Training configuration summary
print('\nLearning Rate:', LR)
print('Max Epochs:', MAX_EPOCH)
print('Batch Size:', BATCH_SIZE)
print('Number of Hidden Layers:', obj['n_layers'] - 2)
print('Number of Neurons per Layer:', obj['n_neurons'])

if decay:
    print('Decay Rate:', decay_rate)
    if decay_at_equal_interval:
        print(f'Rate decays every {decay_interval} epochs.')
    else:
        print('Rate decays when the current epoch loss is greater than the previous epoch loss.')
else:
    print('No decay!')
print()

MyResidualSirenNet(
  (net_layers): ModuleList(
    (0): SineLayer(
      (linear): Linear(in_features=3, out_features=150, bias=True)
    )
    (1-6): 6 x ResidualSineLayer(
      (linear_1): Linear(in_features=150, out_features=150, bias=True)
      (linear_2): Linear(in_features=150, out_features=150, bias=True)
    )
    (7): Linear(in_features=150, out_features=1, bias=True)
  )
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 5e-05
    maximize: False
    weight_decay: 0
)
MSELoss()

Learning Rate: 5e-05
Max Epochs: 200
Batch Size: 512
Number of Hidden Layers: 6
Number of Neurons per Layer: 150
Decay Rate: 0.8
Rate decays every 15 epochs.





In [69]:
from tqdm import tqdm

# Book-keeping for the std-dev training run.
best_model = ""
best_loss = 1e8
best_epoch = -1
train_loss_list = []

# Create the checkpoint directory on first use.
if not os.path.exists(outpath):
    os.makedirs(outpath)

# Training loop
for epoch in tqdm(range(MAX_EPOCH)):
    model_std.train()
    temp_loss_list = []
    start = time.time()

    # Batch-by-batch training
    for X_train, y_train in train_dataloader_std:
        X_train = X_train.type(torch.float32).to(device)
        y_train = y_train.type(torch.float32).to(device)

        if univariate:
            y_train = y_train.squeeze()

        optimizer_std.zero_grad()
        predictions = model_mean(X_train)
        predictions = predictions.squeeze()
        loss = criterion_std(predictions, y_train)
        loss.backward()
        optimizer.step()

        # Track batch loss
        temp_loss_list.append(loss.detach().cpu().numpy())

    # Calculate epoch loss
    epoch_loss = np.average(temp_loss_list)

    # Learning rate decay
    if decay:
        if decay_at_equal_interval:
            if epoch >= decay_interval and epoch % decay_interval == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= decay_rate
        # else:
        #     if epoch > 0 and epoch_loss > train_loss_list[-1]:
        #         for param_group in optimizer.param_groups:
        #             param_group['lr'] *= decay_rate
        if epoch > 0 and epoch_loss > train_loss_list[-1]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= decay_rate

    # Track losses and best model
    train_loss_list.append(epoch_loss)
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_epoch = epoch+1
        if(best_model==0):
            best_model=model_mean.state_dict()
        else:
            best_model=model_mean.state_dict()

    end = time.time()
    print(
        f"Epoch: {epoch + 1}/{MAX_EPOCH} | Train Loss: {train_loss_list[-1]} | "
        f"Time: {round(end - start, 2)}s ({device}) | LR: {optimizer.param_groups[0]['lr']}"
    )

    # Save model at intervals
    if (epoch + 1) % 50 == 0:
        model_name = (
            f'train_{dataset_name}_{epoch + 1}ep_{n_layers - 2}rb_{n_neurons}n_'
            f'{BATCH_SIZE}bs_{LR}lr_{decay}decay_{decay_rate}dr_'
            f'{"decayingAtInterval" + str(decay_interval)+"std" if decay_at_equal_interval else "decayingWhenLossIncr"}'
        )
        torch.save(
            {"epoch": epoch + 1, "model_state_dict": model_mean.state_dict()},
            os.path.join(outpath, f'{model_name}_std.pth')
        )

# Final summary
print('\nEpoch with Least Loss:', best_epoch, '| Loss:', best_loss, '\n')

# Save the final model
model_name = f'siren_compressor'
torch.save(
    {"epoch": MAX_EPOCH, "model_state_dict": model_mean.state_dict()},
    os.path.join(outpath, f'{model_name}_std.pth')
)
torch.save(
    {"epoch": best_epoch, "model_state_dict": best_model},
    os.path.join(outpath, f'{best_epoch}_std.pth')
)


  0%|          | 1/200 [00:07<24:17,  7.32s/it]

Epoch: 1/200 | Train Loss: 0.009194845333695412 | Time: 7.32s (cuda) | LR: 5e-05


  1%|          | 2/200 [00:13<22:01,  6.67s/it]

Epoch: 2/200 | Train Loss: 0.0028720130212605 | Time: 6.22s (cuda) | LR: 5e-05


  2%|▏         | 3/200 [00:20<22:06,  6.74s/it]

Epoch: 3/200 | Train Loss: 0.002567930379882455 | Time: 6.81s (cuda) | LR: 5e-05


  2%|▏         | 4/200 [00:26<20:48,  6.37s/it]

Epoch: 4/200 | Train Loss: 0.0023847653064876795 | Time: 5.8s (cuda) | LR: 5e-05


  2%|▎         | 5/200 [00:32<21:11,  6.52s/it]

Epoch: 5/200 | Train Loss: 0.002256781095638871 | Time: 6.8s (cuda) | LR: 5e-05


  3%|▎         | 6/200 [00:38<20:15,  6.26s/it]

Epoch: 6/200 | Train Loss: 0.002154187997803092 | Time: 5.76s (cuda) | LR: 5e-05


  4%|▎         | 7/200 [00:45<20:40,  6.43s/it]

Epoch: 7/200 | Train Loss: 0.002080448204651475 | Time: 6.76s (cuda) | LR: 5e-05


  4%|▍         | 8/200 [00:51<19:51,  6.21s/it]

Epoch: 8/200 | Train Loss: 0.002045626752078533 | Time: 5.74s (cuda) | LR: 5e-05


  4%|▍         | 9/200 [00:59<22:09,  6.96s/it]

Epoch: 9/200 | Train Loss: 0.001972614787518978 | Time: 8.61s (cuda) | LR: 5e-05


  5%|▌         | 10/200 [01:05<20:40,  6.53s/it]

Epoch: 10/200 | Train Loss: 0.0019249295582994819 | Time: 5.56s (cuda) | LR: 5e-05


  6%|▌         | 11/200 [01:12<20:44,  6.58s/it]

Epoch: 11/200 | Train Loss: 0.0018737944774329662 | Time: 6.71s (cuda) | LR: 5e-05


  6%|▌         | 12/200 [01:17<19:43,  6.29s/it]

Epoch: 12/200 | Train Loss: 0.0018399524269625545 | Time: 5.63s (cuda) | LR: 5e-05


  6%|▋         | 13/200 [01:24<20:08,  6.46s/it]

Epoch: 13/200 | Train Loss: 0.001806402113288641 | Time: 6.84s (cuda) | LR: 5e-05


  7%|▋         | 14/200 [01:30<19:17,  6.23s/it]

Epoch: 14/200 | Train Loss: 0.0017818710766732693 | Time: 5.68s (cuda) | LR: 5e-05


  8%|▊         | 15/200 [01:36<19:14,  6.24s/it]

Epoch: 15/200 | Train Loss: 0.0017433357425034046 | Time: 6.28s (cuda) | LR: 5e-05


  8%|▊         | 16/200 [01:42<18:51,  6.15s/it]

Epoch: 16/200 | Train Loss: 0.0016952555160969496 | Time: 5.93s (cuda) | LR: 4e-05


  8%|▊         | 17/200 [01:48<19:01,  6.24s/it]

Epoch: 17/200 | Train Loss: 0.0014271930558606982 | Time: 6.44s (cuda) | LR: 4e-05


  9%|▉         | 18/200 [01:56<19:50,  6.54s/it]

Epoch: 18/200 | Train Loss: 0.001263576908968389 | Time: 7.23s (cuda) | LR: 4e-05


 10%|▉         | 19/200 [02:02<19:27,  6.45s/it]

Epoch: 19/200 | Train Loss: 0.0012481918092817068 | Time: 6.24s (cuda) | LR: 4e-05


 10%|█         | 20/200 [02:08<19:05,  6.36s/it]

Epoch: 20/200 | Train Loss: 0.0012059148866683245 | Time: 6.16s (cuda) | LR: 4e-05


 10%|█         | 21/200 [02:14<18:41,  6.26s/it]

Epoch: 21/200 | Train Loss: 0.0011958950199186802 | Time: 6.03s (cuda) | LR: 4e-05


 11%|█         | 22/200 [02:20<18:38,  6.28s/it]

Epoch: 22/200 | Train Loss: 0.001171864802017808 | Time: 6.32s (cuda) | LR: 4e-05


 12%|█▏        | 23/200 [02:27<18:26,  6.25s/it]

Epoch: 23/200 | Train Loss: 0.0011311592534184456 | Time: 6.18s (cuda) | LR: 4e-05


 12%|█▏        | 24/200 [02:33<18:28,  6.30s/it]

Epoch: 24/200 | Train Loss: 0.001101246802136302 | Time: 6.41s (cuda) | LR: 4e-05


 12%|█▎        | 25/200 [02:39<17:51,  6.12s/it]

Epoch: 25/200 | Train Loss: 0.0010757476557046175 | Time: 5.71s (cuda) | LR: 4e-05


 13%|█▎        | 26/200 [02:45<18:19,  6.32s/it]

Epoch: 26/200 | Train Loss: 0.0010798744624480605 | Time: 6.77s (cuda) | LR: 3.2000000000000005e-05


 14%|█▎        | 27/200 [02:51<17:48,  6.18s/it]

Epoch: 27/200 | Train Loss: 0.0008635143749415874 | Time: 5.84s (cuda) | LR: 3.2000000000000005e-05


 14%|█▍        | 28/200 [02:58<18:27,  6.44s/it]

Epoch: 28/200 | Train Loss: 0.0007373134139925241 | Time: 7.04s (cuda) | LR: 3.2000000000000005e-05


 14%|█▍        | 29/200 [03:04<17:34,  6.17s/it]

Epoch: 29/200 | Train Loss: 0.0007226869929581881 | Time: 5.54s (cuda) | LR: 3.2000000000000005e-05


 15%|█▌        | 30/200 [03:11<18:07,  6.39s/it]

Epoch: 30/200 | Train Loss: 0.0007169503951445222 | Time: 6.92s (cuda) | LR: 3.2000000000000005e-05


 16%|█▌        | 31/200 [03:17<17:24,  6.18s/it]

Epoch: 31/200 | Train Loss: 0.0007171737961471081 | Time: 5.68s (cuda) | LR: 2.0480000000000007e-05


 16%|█▌        | 32/200 [03:23<17:49,  6.36s/it]

Epoch: 32/200 | Train Loss: 0.0005033459165133536 | Time: 6.79s (cuda) | LR: 2.0480000000000007e-05


 16%|█▋        | 33/200 [03:29<17:11,  6.18s/it]

Epoch: 33/200 | Train Loss: 0.00037924444768577814 | Time: 5.75s (cuda) | LR: 2.0480000000000007e-05


 17%|█▋        | 34/200 [03:36<18:00,  6.51s/it]

Epoch: 34/200 | Train Loss: 0.0003605893871281296 | Time: 7.27s (cuda) | LR: 2.0480000000000007e-05


 18%|█▊        | 35/200 [03:43<17:48,  6.48s/it]

Epoch: 35/200 | Train Loss: 0.0003632013394962996 | Time: 6.41s (cuda) | LR: 1.6384000000000008e-05


 18%|█▊        | 36/200 [03:50<18:03,  6.61s/it]

Epoch: 36/200 | Train Loss: 0.0003051134408451617 | Time: 6.9s (cuda) | LR: 1.6384000000000008e-05


 18%|█▊        | 37/200 [03:55<17:14,  6.35s/it]

Epoch: 37/200 | Train Loss: 0.00026360899209976196 | Time: 5.75s (cuda) | LR: 1.6384000000000008e-05


 19%|█▉        | 38/200 [04:02<17:28,  6.47s/it]

Epoch: 38/200 | Train Loss: 0.00025372434174641967 | Time: 6.75s (cuda) | LR: 1.6384000000000008e-05


 20%|█▉        | 39/200 [04:08<16:39,  6.21s/it]

Epoch: 39/200 | Train Loss: 0.0002518624532967806 | Time: 5.59s (cuda) | LR: 1.6384000000000008e-05


 20%|██        | 40/200 [04:15<17:04,  6.40s/it]

Epoch: 40/200 | Train Loss: 0.0002549409109633416 | Time: 6.86s (cuda) | LR: 1.3107200000000007e-05


 20%|██        | 41/200 [04:20<16:22,  6.18s/it]

Epoch: 41/200 | Train Loss: 0.0002009312156587839 | Time: 5.66s (cuda) | LR: 1.3107200000000007e-05


 21%|██        | 42/200 [04:27<16:56,  6.43s/it]

Epoch: 42/200 | Train Loss: 0.00017993434448726475 | Time: 7.02s (cuda) | LR: 1.3107200000000007e-05


 22%|██▏       | 43/200 [04:33<16:17,  6.22s/it]

Epoch: 43/200 | Train Loss: 0.00017502898117527366 | Time: 5.73s (cuda) | LR: 1.3107200000000007e-05


 22%|██▏       | 44/200 [04:40<16:46,  6.45s/it]

Epoch: 44/200 | Train Loss: 0.00017678514996077865 | Time: 6.98s (cuda) | LR: 1.0485760000000006e-05


 22%|██▎       | 45/200 [04:46<16:08,  6.25s/it]

Epoch: 45/200 | Train Loss: 0.00013968873827252537 | Time: 5.77s (cuda) | LR: 1.0485760000000006e-05


 23%|██▎       | 46/200 [04:53<16:26,  6.41s/it]

Epoch: 46/200 | Train Loss: 0.00012119787425035611 | Time: 6.77s (cuda) | LR: 8.388608000000005e-06


 24%|██▎       | 47/200 [04:58<15:51,  6.22s/it]

Epoch: 47/200 | Train Loss: 0.00010115245095221326 | Time: 5.78s (cuda) | LR: 8.388608000000005e-06


 24%|██▍       | 48/200 [05:05<16:17,  6.43s/it]

Epoch: 48/200 | Train Loss: 9.331183537142351e-05 | Time: 6.93s (cuda) | LR: 8.388608000000005e-06


 24%|██▍       | 49/200 [05:11<15:37,  6.21s/it]

Epoch: 49/200 | Train Loss: 9.314632916357368e-05 | Time: 5.69s (cuda) | LR: 8.388608000000005e-06


 25%|██▌       | 50/200 [05:18<15:48,  6.32s/it]

Epoch: 50/200 | Train Loss: 9.266989945899695e-05 | Time: 6.58s (cuda) | LR: 8.388608000000005e-06


 26%|██▌       | 51/200 [05:23<15:21,  6.18s/it]

Epoch: 51/200 | Train Loss: 9.248766582459211e-05 | Time: 5.85s (cuda) | LR: 8.388608000000005e-06


 26%|██▌       | 52/200 [05:30<15:21,  6.23s/it]

Epoch: 52/200 | Train Loss: 9.111529652727768e-05 | Time: 6.33s (cuda) | LR: 8.388608000000005e-06


 26%|██▋       | 53/200 [05:36<15:14,  6.22s/it]

Epoch: 53/200 | Train Loss: 8.882123074727133e-05 | Time: 6.2s (cuda) | LR: 8.388608000000005e-06


 27%|██▋       | 54/200 [05:42<14:53,  6.12s/it]

Epoch: 54/200 | Train Loss: 8.706046355655417e-05 | Time: 5.88s (cuda) | LR: 8.388608000000005e-06


 28%|██▊       | 55/200 [05:48<15:01,  6.22s/it]

Epoch: 55/200 | Train Loss: 8.519968832843006e-05 | Time: 6.45s (cuda) | LR: 8.388608000000005e-06


 28%|██▊       | 56/200 [05:54<14:29,  6.04s/it]

Epoch: 56/200 | Train Loss: 8.34684178698808e-05 | Time: 5.63s (cuda) | LR: 8.388608000000005e-06


 28%|██▊       | 57/200 [06:01<14:58,  6.28s/it]

Epoch: 57/200 | Train Loss: 8.186904597096145e-05 | Time: 6.84s (cuda) | LR: 8.388608000000005e-06


 29%|██▉       | 58/200 [06:06<14:21,  6.06s/it]

Epoch: 58/200 | Train Loss: 8.040765533223748e-05 | Time: 5.55s (cuda) | LR: 8.388608000000005e-06


 30%|██▉       | 59/200 [06:13<14:39,  6.24s/it]

Epoch: 59/200 | Train Loss: 7.892186113167554e-05 | Time: 6.65s (cuda) | LR: 8.388608000000005e-06


 30%|███       | 60/200 [06:18<14:03,  6.02s/it]

Epoch: 60/200 | Train Loss: 7.838320743758231e-05 | Time: 5.51s (cuda) | LR: 8.388608000000005e-06


 30%|███       | 61/200 [06:25<14:26,  6.23s/it]

Epoch: 61/200 | Train Loss: 7.728775381110609e-05 | Time: 6.73s (cuda) | LR: 6.7108864000000044e-06


 31%|███       | 62/200 [06:31<13:59,  6.08s/it]

Epoch: 62/200 | Train Loss: 6.098553421907127e-05 | Time: 5.74s (cuda) | LR: 6.7108864000000044e-06


 32%|███▏      | 63/200 [06:38<14:18,  6.27s/it]

Epoch: 63/200 | Train Loss: 5.2002393204020336e-05 | Time: 6.69s (cuda) | LR: 6.7108864000000044e-06


 32%|███▏      | 64/200 [06:43<13:38,  6.02s/it]

Epoch: 64/200 | Train Loss: 5.253678682493046e-05 | Time: 5.44s (cuda) | LR: 5.368709120000004e-06


 32%|███▎      | 65/200 [06:50<13:55,  6.19s/it]

Epoch: 65/200 | Train Loss: 4.456890019355342e-05 | Time: 6.58s (cuda) | LR: 5.368709120000004e-06


 33%|███▎      | 66/200 [06:55<13:24,  6.00s/it]

Epoch: 66/200 | Train Loss: 4.008863470517099e-05 | Time: 5.57s (cuda) | LR: 5.368709120000004e-06


 34%|███▎      | 67/200 [07:02<13:54,  6.27s/it]

Epoch: 67/200 | Train Loss: 4.040004932903685e-05 | Time: 6.9s (cuda) | LR: 4.294967296000004e-06


 34%|███▍      | 68/200 [07:08<13:16,  6.04s/it]

Epoch: 68/200 | Train Loss: 3.446186019573361e-05 | Time: 5.48s (cuda) | LR: 4.294967296000004e-06


 34%|███▍      | 69/200 [07:14<13:19,  6.11s/it]

Epoch: 69/200 | Train Loss: 3.1503106583841145e-05 | Time: 6.27s (cuda) | LR: 4.294967296000004e-06


 35%|███▌      | 70/200 [07:20<13:00,  6.00s/it]

Epoch: 70/200 | Train Loss: 3.1692306947661564e-05 | Time: 5.76s (cuda) | LR: 3.4359738368000033e-06


 36%|███▌      | 71/200 [07:26<12:51,  5.98s/it]

Epoch: 71/200 | Train Loss: 2.7369529561838135e-05 | Time: 5.94s (cuda) | LR: 3.4359738368000033e-06


 36%|███▌      | 72/200 [07:32<12:58,  6.08s/it]

Epoch: 72/200 | Train Loss: 2.5489687686786056e-05 | Time: 6.31s (cuda) | LR: 3.4359738368000033e-06


 36%|███▋      | 73/200 [07:38<12:41,  6.00s/it]

Epoch: 73/200 | Train Loss: 2.529530320316553e-05 | Time: 5.79s (cuda) | LR: 3.4359738368000033e-06


 37%|███▋      | 74/200 [07:44<12:56,  6.16s/it]

Epoch: 74/200 | Train Loss: 2.5374876713613048e-05 | Time: 6.54s (cuda) | LR: 2.7487790694400027e-06


 38%|███▊      | 75/200 [07:50<12:27,  5.98s/it]

Epoch: 75/200 | Train Loss: 2.1933414245722815e-05 | Time: 5.54s (cuda) | LR: 2.7487790694400027e-06


 38%|███▊      | 76/200 [07:56<12:48,  6.19s/it]

Epoch: 76/200 | Train Loss: 2.046031477220822e-05 | Time: 6.7s (cuda) | LR: 2.1990232555520023e-06


 38%|███▊      | 77/200 [08:02<12:22,  6.03s/it]

Epoch: 77/200 | Train Loss: 1.8081935195368715e-05 | Time: 5.66s (cuda) | LR: 2.1990232555520023e-06


 39%|███▉      | 78/200 [08:09<12:36,  6.20s/it]

Epoch: 78/200 | Train Loss: 1.74171946127899e-05 | Time: 6.6s (cuda) | LR: 2.1990232555520023e-06


 40%|███▉      | 79/200 [08:14<12:05,  5.99s/it]

Epoch: 79/200 | Train Loss: 1.7379083146806806e-05 | Time: 5.5s (cuda) | LR: 2.1990232555520023e-06


 40%|████      | 80/200 [08:21<12:23,  6.19s/it]

Epoch: 80/200 | Train Loss: 1.726246773614548e-05 | Time: 6.66s (cuda) | LR: 2.1990232555520023e-06


 40%|████      | 81/200 [08:26<11:53,  6.00s/it]

Epoch: 81/200 | Train Loss: 1.7056248907465488e-05 | Time: 5.53s (cuda) | LR: 2.1990232555520023e-06


 41%|████      | 82/200 [08:33<12:11,  6.20s/it]

Epoch: 82/200 | Train Loss: 1.6797839634818956e-05 | Time: 6.67s (cuda) | LR: 2.1990232555520023e-06


 42%|████▏     | 83/200 [08:39<11:40,  5.99s/it]

Epoch: 83/200 | Train Loss: 1.6604648408247158e-05 | Time: 5.49s (cuda) | LR: 2.1990232555520023e-06


 42%|████▏     | 84/200 [08:45<11:56,  6.17s/it]

Epoch: 84/200 | Train Loss: 1.6378024156438187e-05 | Time: 6.61s (cuda) | LR: 2.1990232555520023e-06


 42%|████▎     | 85/200 [08:51<11:24,  5.96s/it]

Epoch: 85/200 | Train Loss: 1.6081434296211228e-05 | Time: 5.44s (cuda) | LR: 2.1990232555520023e-06


 43%|████▎     | 86/200 [08:57<11:26,  6.02s/it]

Epoch: 86/200 | Train Loss: 1.582074946782086e-05 | Time: 6.18s (cuda) | LR: 2.1990232555520023e-06


 44%|████▎     | 87/200 [09:03<11:16,  5.99s/it]

Epoch: 87/200 | Train Loss: 1.560833516123239e-05 | Time: 5.9s (cuda) | LR: 2.1990232555520023e-06


 44%|████▍     | 88/200 [09:08<10:54,  5.85s/it]

Epoch: 88/200 | Train Loss: 1.5363335478468798e-05 | Time: 5.51s (cuda) | LR: 2.1990232555520023e-06


 44%|████▍     | 89/200 [09:15<11:03,  5.98s/it]

Epoch: 89/200 | Train Loss: 1.5172907296800986e-05 | Time: 6.29s (cuda) | LR: 2.1990232555520023e-06


 45%|████▌     | 90/200 [09:20<10:36,  5.78s/it]

Epoch: 90/200 | Train Loss: 1.4965848095016554e-05 | Time: 5.32s (cuda) | LR: 2.1990232555520023e-06


 46%|████▌     | 91/200 [09:26<10:52,  5.99s/it]

Epoch: 91/200 | Train Loss: 1.4686915164929815e-05 | Time: 6.47s (cuda) | LR: 1.7592186044416019e-06


 46%|████▌     | 92/200 [09:32<10:29,  5.83s/it]

Epoch: 92/200 | Train Loss: 1.2878527741122525e-05 | Time: 5.47s (cuda) | LR: 1.7592186044416019e-06


 46%|████▋     | 93/200 [09:38<10:40,  5.99s/it]

Epoch: 93/200 | Train Loss: 1.2308186342124827e-05 | Time: 6.35s (cuda) | LR: 1.7592186044416019e-06


 47%|████▋     | 94/200 [09:43<10:13,  5.79s/it]

Epoch: 94/200 | Train Loss: 1.2218693882459775e-05 | Time: 5.31s (cuda) | LR: 1.7592186044416019e-06


 48%|████▊     | 95/200 [09:50<10:28,  5.98s/it]

Epoch: 95/200 | Train Loss: 1.2180093108327128e-05 | Time: 6.44s (cuda) | LR: 1.7592186044416019e-06


 48%|████▊     | 96/200 [09:55<10:01,  5.79s/it]

Epoch: 96/200 | Train Loss: 1.2073998732375912e-05 | Time: 5.32s (cuda) | LR: 1.7592186044416019e-06


 48%|████▊     | 97/200 [10:02<10:11,  5.94s/it]

Epoch: 97/200 | Train Loss: 1.192901618196629e-05 | Time: 6.29s (cuda) | LR: 1.7592186044416019e-06


 49%|████▉     | 98/200 [10:07<09:54,  5.83s/it]

Epoch: 98/200 | Train Loss: 1.1789223208324984e-05 | Time: 5.57s (cuda) | LR: 1.7592186044416019e-06


 50%|████▉     | 99/200 [10:12<09:31,  5.66s/it]

Epoch: 99/200 | Train Loss: 1.164587683888385e-05 | Time: 5.25s (cuda) | LR: 1.7592186044416019e-06


 50%|█████     | 100/200 [10:19<09:46,  5.87s/it]

Epoch: 100/200 | Train Loss: 1.1499580068630166e-05 | Time: 6.36s (cuda) | LR: 1.7592186044416019e-06


 50%|█████     | 101/200 [10:24<09:21,  5.67s/it]

Epoch: 101/200 | Train Loss: 1.1378539056750014e-05 | Time: 5.22s (cuda) | LR: 1.7592186044416019e-06


 51%|█████     | 102/200 [10:31<09:42,  5.94s/it]

Epoch: 102/200 | Train Loss: 1.1186569281562697e-05 | Time: 6.56s (cuda) | LR: 1.7592186044416019e-06


 52%|█████▏    | 103/200 [10:36<09:19,  5.76s/it]

Epoch: 103/200 | Train Loss: 1.1100564734078944e-05 | Time: 5.35s (cuda) | LR: 1.7592186044416019e-06


 52%|█████▏    | 104/200 [10:42<09:33,  5.98s/it]

Epoch: 104/200 | Train Loss: 1.0982342246279586e-05 | Time: 6.48s (cuda) | LR: 1.7592186044416019e-06


 52%|█████▎    | 105/200 [10:48<09:09,  5.79s/it]

Epoch: 105/200 | Train Loss: 1.0844923053809907e-05 | Time: 5.34s (cuda) | LR: 1.7592186044416019e-06


 53%|█████▎    | 106/200 [10:54<09:26,  6.03s/it]

Epoch: 106/200 | Train Loss: 1.0717850273067597e-05 | Time: 6.58s (cuda) | LR: 1.4073748835532816e-06


 54%|█████▎    | 107/200 [11:00<09:02,  5.84s/it]

Epoch: 107/200 | Train Loss: 9.453588063479401e-06 | Time: 5.39s (cuda) | LR: 1.4073748835532816e-06


 54%|█████▍    | 108/200 [11:06<09:13,  6.02s/it]

Epoch: 108/200 | Train Loss: 9.031826266436838e-06 | Time: 6.43s (cuda) | LR: 1.4073748835532816e-06


 55%|█████▍    | 109/200 [11:12<08:57,  5.90s/it]

Epoch: 109/200 | Train Loss: 9.052893801708706e-06 | Time: 5.64s (cuda) | LR: 1.1258999068426254e-06


 55%|█████▌    | 110/200 [11:17<08:44,  5.82s/it]

Epoch: 110/200 | Train Loss: 8.17463660496287e-06 | Time: 5.63s (cuda) | LR: 1.1258999068426254e-06


 56%|█████▌    | 111/200 [11:24<08:46,  5.92s/it]

Epoch: 111/200 | Train Loss: 7.946956429805141e-06 | Time: 6.14s (cuda) | LR: 1.1258999068426254e-06


 56%|█████▌    | 112/200 [11:29<08:26,  5.75s/it]

Epoch: 112/200 | Train Loss: 7.934586392366327e-06 | Time: 5.37s (cuda) | LR: 1.1258999068426254e-06


 56%|█████▋    | 113/200 [11:35<08:42,  6.01s/it]

Epoch: 113/200 | Train Loss: 7.896715942479204e-06 | Time: 6.6s (cuda) | LR: 1.1258999068426254e-06


 57%|█████▋    | 114/200 [11:41<08:18,  5.79s/it]

Epoch: 114/200 | Train Loss: 7.831221410015132e-06 | Time: 5.29s (cuda) | LR: 1.1258999068426254e-06


 57%|█████▊    | 115/200 [11:47<08:28,  5.98s/it]

Epoch: 115/200 | Train Loss: 7.776048732921481e-06 | Time: 6.42s (cuda) | LR: 1.1258999068426254e-06


 58%|█████▊    | 116/200 [11:52<08:05,  5.78s/it]

Epoch: 116/200 | Train Loss: 7.713934792263899e-06 | Time: 5.3s (cuda) | LR: 1.1258999068426254e-06


 58%|█████▊    | 117/200 [11:59<08:16,  5.98s/it]

Epoch: 117/200 | Train Loss: 7.623189958394505e-06 | Time: 6.45s (cuda) | LR: 1.1258999068426254e-06


 59%|█████▉    | 118/200 [12:04<07:56,  5.81s/it]

Epoch: 118/200 | Train Loss: 7.567256943730172e-06 | Time: 5.39s (cuda) | LR: 1.1258999068426254e-06


 60%|█████▉    | 119/200 [12:11<08:00,  5.93s/it]

Epoch: 119/200 | Train Loss: 7.487641596526373e-06 | Time: 6.21s (cuda) | LR: 1.1258999068426254e-06


 60%|██████    | 120/200 [12:16<07:44,  5.81s/it]

Epoch: 120/200 | Train Loss: 7.426900538121117e-06 | Time: 5.51s (cuda) | LR: 1.1258999068426254e-06


 60%|██████    | 121/200 [12:22<07:35,  5.77s/it]

Epoch: 121/200 | Train Loss: 7.334077963605523e-06 | Time: 5.69s (cuda) | LR: 9.007199254741003e-07


 61%|██████    | 122/200 [12:28<07:33,  5.81s/it]

Epoch: 122/200 | Train Loss: 6.676857537968317e-06 | Time: 5.9s (cuda) | LR: 9.007199254741003e-07


 62%|██████▏   | 123/200 [12:33<07:18,  5.70s/it]

Epoch: 123/200 | Train Loss: 6.523503088828875e-06 | Time: 5.44s (cuda) | LR: 9.007199254741003e-07


 62%|██████▏   | 124/200 [12:39<07:28,  5.90s/it]

Epoch: 124/200 | Train Loss: 6.481851869466482e-06 | Time: 6.35s (cuda) | LR: 9.007199254741003e-07


 62%|██████▎   | 125/200 [12:45<07:12,  5.77s/it]

Epoch: 125/200 | Train Loss: 6.4576524891890585e-06 | Time: 5.47s (cuda) | LR: 9.007199254741003e-07


 63%|██████▎   | 126/200 [12:51<07:20,  5.96s/it]

Epoch: 126/200 | Train Loss: 6.410508831322659e-06 | Time: 6.4s (cuda) | LR: 9.007199254741003e-07


 64%|██████▎   | 127/200 [12:57<07:00,  5.76s/it]

Epoch: 127/200 | Train Loss: 6.364515229506651e-06 | Time: 5.3s (cuda) | LR: 9.007199254741003e-07


 64%|██████▍   | 128/200 [13:03<07:11,  5.99s/it]

Epoch: 128/200 | Train Loss: 6.3165371102513745e-06 | Time: 6.54s (cuda) | LR: 9.007199254741003e-07


 64%|██████▍   | 129/200 [13:09<06:52,  5.81s/it]

Epoch: 129/200 | Train Loss: 6.265262072702171e-06 | Time: 5.39s (cuda) | LR: 9.007199254741003e-07


 65%|██████▌   | 130/200 [13:15<07:01,  6.01s/it]

Epoch: 130/200 | Train Loss: 6.213703272806015e-06 | Time: 6.48s (cuda) | LR: 9.007199254741003e-07


 66%|██████▌   | 131/200 [13:20<06:41,  5.82s/it]

Epoch: 131/200 | Train Loss: 6.170053893583827e-06 | Time: 5.35s (cuda) | LR: 9.007199254741003e-07


 66%|██████▌   | 132/200 [13:26<06:30,  5.75s/it]

Epoch: 132/200 | Train Loss: 6.105942702561151e-06 | Time: 5.58s (cuda) | LR: 9.007199254741003e-07


 66%|██████▋   | 133/200 [13:32<06:29,  5.82s/it]

Epoch: 133/200 | Train Loss: 6.05967034061905e-06 | Time: 5.98s (cuda) | LR: 9.007199254741003e-07


 67%|██████▋   | 134/200 [13:37<06:17,  5.72s/it]

Epoch: 134/200 | Train Loss: 6.016886345605599e-06 | Time: 5.5s (cuda) | LR: 9.007199254741003e-07


 68%|██████▊   | 135/200 [13:44<06:29,  5.99s/it]

Epoch: 135/200 | Train Loss: 5.964640877209604e-06 | Time: 6.6s (cuda) | LR: 9.007199254741003e-07


 68%|██████▊   | 136/200 [13:49<06:10,  5.79s/it]

Epoch: 136/200 | Train Loss: 5.914996563660679e-06 | Time: 5.31s (cuda) | LR: 7.205759403792803e-07


 68%|██████▊   | 137/200 [13:56<06:13,  5.94s/it]

Epoch: 137/200 | Train Loss: 5.435798811959103e-06 | Time: 6.28s (cuda) | LR: 7.205759403792803e-07


 69%|██████▉   | 138/200 [14:01<05:54,  5.72s/it]

Epoch: 138/200 | Train Loss: 5.33256343260291e-06 | Time: 5.23s (cuda) | LR: 7.205759403792803e-07


 70%|██████▉   | 139/200 [14:07<06:04,  5.97s/it]

Epoch: 139/200 | Train Loss: 5.308048457663972e-06 | Time: 6.54s (cuda) | LR: 7.205759403792803e-07


 70%|███████   | 140/200 [14:13<05:46,  5.77s/it]

Epoch: 140/200 | Train Loss: 5.2884201977576595e-06 | Time: 5.31s (cuda) | LR: 7.205759403792803e-07


 70%|███████   | 141/200 [14:19<05:51,  5.96s/it]

Epoch: 141/200 | Train Loss: 5.255395535641583e-06 | Time: 6.39s (cuda) | LR: 7.205759403792803e-07


 71%|███████   | 142/200 [14:24<05:33,  5.75s/it]

Epoch: 142/200 | Train Loss: 5.218447768129408e-06 | Time: 5.25s (cuda) | LR: 7.205759403792803e-07


 72%|███████▏  | 143/200 [14:30<05:23,  5.68s/it]

Epoch: 143/200 | Train Loss: 5.190965566725936e-06 | Time: 5.51s (cuda) | LR: 7.205759403792803e-07


 72%|███████▏  | 144/200 [14:36<05:27,  5.86s/it]

Epoch: 144/200 | Train Loss: 5.1592123782029375e-06 | Time: 6.27s (cuda) | LR: 7.205759403792803e-07


 72%|███████▎  | 145/200 [14:42<05:13,  5.70s/it]

Epoch: 145/200 | Train Loss: 5.115273779665586e-06 | Time: 5.32s (cuda) | LR: 7.205759403792803e-07


 73%|███████▎  | 146/200 [14:48<05:18,  5.90s/it]

Epoch: 146/200 | Train Loss: 5.0816015573218465e-06 | Time: 6.36s (cuda) | LR: 7.205759403792803e-07


 74%|███████▎  | 147/200 [14:53<05:03,  5.73s/it]

Epoch: 147/200 | Train Loss: 5.04785384691786e-06 | Time: 5.35s (cuda) | LR: 7.205759403792803e-07


 74%|███████▍  | 148/200 [15:00<05:11,  6.00s/it]

Epoch: 148/200 | Train Loss: 5.012972906115465e-06 | Time: 6.62s (cuda) | LR: 7.205759403792803e-07


 74%|███████▍  | 149/200 [15:05<04:57,  5.84s/it]

Epoch: 149/200 | Train Loss: 4.97570181323681e-06 | Time: 5.45s (cuda) | LR: 7.205759403792803e-07


 75%|███████▌  | 150/200 [15:12<04:59,  6.00s/it]

Epoch: 150/200 | Train Loss: 4.943174644722603e-06 | Time: 6.37s (cuda) | LR: 7.205759403792803e-07


 76%|███████▌  | 151/200 [15:17<04:43,  5.79s/it]

Epoch: 151/200 | Train Loss: 4.9192340156878345e-06 | Time: 5.32s (cuda) | LR: 5.764607523034243e-07


 76%|███████▌  | 152/200 [15:23<04:46,  5.98s/it]

Epoch: 152/200 | Train Loss: 4.558860382530838e-06 | Time: 6.4s (cuda) | LR: 5.764607523034243e-07


 76%|███████▋  | 153/200 [15:29<04:30,  5.75s/it]

Epoch: 153/200 | Train Loss: 4.487210389925167e-06 | Time: 5.21s (cuda) | LR: 5.764607523034243e-07


 77%|███████▋  | 154/200 [15:34<04:25,  5.78s/it]

Epoch: 154/200 | Train Loss: 4.468001861823723e-06 | Time: 5.85s (cuda) | LR: 5.764607523034243e-07


 78%|███████▊  | 155/200 [15:40<04:21,  5.81s/it]

Epoch: 155/200 | Train Loss: 4.454514964891132e-06 | Time: 5.9s (cuda) | LR: 5.764607523034243e-07


 78%|███████▊  | 156/200 [15:46<04:07,  5.63s/it]

Epoch: 156/200 | Train Loss: 4.433964022609871e-06 | Time: 5.21s (cuda) | LR: 5.764607523034243e-07


 78%|███████▊  | 157/200 [15:52<04:12,  5.87s/it]

Epoch: 157/200 | Train Loss: 4.40819621871924e-06 | Time: 6.43s (cuda) | LR: 5.764607523034243e-07


 79%|███████▉  | 158/200 [15:57<03:58,  5.68s/it]

Epoch: 158/200 | Train Loss: 4.3815161916427314e-06 | Time: 5.23s (cuda) | LR: 5.764607523034243e-07


 80%|███████▉  | 159/200 [16:04<04:00,  5.87s/it]

Epoch: 159/200 | Train Loss: 4.356323643150972e-06 | Time: 6.31s (cuda) | LR: 5.764607523034243e-07


 80%|████████  | 160/200 [16:09<03:49,  5.75s/it]

Epoch: 160/200 | Train Loss: 4.336805432103574e-06 | Time: 5.45s (cuda) | LR: 5.764607523034243e-07


 80%|████████  | 161/200 [16:15<03:51,  5.94s/it]

Epoch: 161/200 | Train Loss: 4.311411885282723e-06 | Time: 6.38s (cuda) | LR: 5.764607523034243e-07


 81%|████████  | 162/200 [16:21<03:38,  5.74s/it]

Epoch: 162/200 | Train Loss: 4.282331701688236e-06 | Time: 5.27s (cuda) | LR: 5.764607523034243e-07


 82%|████████▏ | 163/200 [16:27<03:33,  5.78s/it]

Epoch: 163/200 | Train Loss: 4.259857632860076e-06 | Time: 5.88s (cuda) | LR: 5.764607523034243e-07


 82%|████████▏ | 164/200 [16:32<03:27,  5.76s/it]

Epoch: 164/200 | Train Loss: 4.2304391172365285e-06 | Time: 5.7s (cuda) | LR: 5.764607523034243e-07


 82%|████████▎ | 165/200 [16:38<03:17,  5.64s/it]

Epoch: 165/200 | Train Loss: 4.2101564758922905e-06 | Time: 5.38s (cuda) | LR: 5.764607523034243e-07


 83%|████████▎ | 166/200 [16:44<03:18,  5.85s/it]

Epoch: 166/200 | Train Loss: 4.18280342273647e-06 | Time: 6.32s (cuda) | LR: 4.6116860184273944e-07


 84%|████████▎ | 167/200 [16:49<03:07,  5.69s/it]

Epoch: 167/200 | Train Loss: 3.91786488762591e-06 | Time: 5.31s (cuda) | LR: 4.6116860184273944e-07


 84%|████████▍ | 168/200 [16:56<03:09,  5.92s/it]

Epoch: 168/200 | Train Loss: 3.8706775740138255e-06 | Time: 6.47s (cuda) | LR: 4.6116860184273944e-07


 84%|████████▍ | 169/200 [17:01<02:56,  5.70s/it]

Epoch: 169/200 | Train Loss: 3.86031251764507e-06 | Time: 5.2s (cuda) | LR: 4.6116860184273944e-07


 85%|████████▌ | 170/200 [17:07<02:57,  5.93s/it]

Epoch: 170/200 | Train Loss: 3.844495950033888e-06 | Time: 6.45s (cuda) | LR: 4.6116860184273944e-07


 86%|████████▌ | 171/200 [17:13<02:45,  5.71s/it]

Epoch: 171/200 | Train Loss: 3.834471954178298e-06 | Time: 5.21s (cuda) | LR: 4.6116860184273944e-07


 86%|████████▌ | 172/200 [17:19<02:45,  5.92s/it]

Epoch: 172/200 | Train Loss: 3.8144112295412924e-06 | Time: 6.4s (cuda) | LR: 4.6116860184273944e-07


 86%|████████▋ | 173/200 [17:24<02:33,  5.70s/it]

Epoch: 173/200 | Train Loss: 3.792488314502407e-06 | Time: 5.18s (cuda) | LR: 4.6116860184273944e-07


 87%|████████▋ | 174/200 [17:30<02:27,  5.68s/it]

Epoch: 174/200 | Train Loss: 3.7766772038594354e-06 | Time: 5.62s (cuda) | LR: 4.6116860184273944e-07


 88%|████████▊ | 175/200 [17:36<02:24,  5.77s/it]

Epoch: 175/200 | Train Loss: 3.755343868760974e-06 | Time: 5.99s (cuda) | LR: 4.6116860184273944e-07


 88%|████████▊ | 176/200 [17:41<02:15,  5.66s/it]

Epoch: 176/200 | Train Loss: 3.743026354641188e-06 | Time: 5.4s (cuda) | LR: 4.6116860184273944e-07


 88%|████████▊ | 177/200 [17:47<02:14,  5.84s/it]

Epoch: 177/200 | Train Loss: 3.7256208997860085e-06 | Time: 6.25s (cuda) | LR: 4.6116860184273944e-07


 89%|████████▉ | 178/200 [17:53<02:05,  5.68s/it]

Epoch: 178/200 | Train Loss: 3.703330094140256e-06 | Time: 5.32s (cuda) | LR: 4.6116860184273944e-07


 90%|████████▉ | 179/200 [17:59<02:03,  5.88s/it]

Epoch: 179/200 | Train Loss: 3.6845206068392145e-06 | Time: 6.35s (cuda) | LR: 4.6116860184273944e-07


 90%|█████████ | 180/200 [18:04<01:53,  5.69s/it]

Epoch: 180/200 | Train Loss: 3.6678052310890052e-06 | Time: 5.26s (cuda) | LR: 4.6116860184273944e-07


 90%|█████████ | 181/200 [18:11<01:52,  5.94s/it]

Epoch: 181/200 | Train Loss: 3.6511689813778503e-06 | Time: 6.5s (cuda) | LR: 3.689348814741916e-07


 91%|█████████ | 182/200 [18:16<01:43,  5.73s/it]

Epoch: 182/200 | Train Loss: 3.4483091440051794e-06 | Time: 5.24s (cuda) | LR: 3.689348814741916e-07


 92%|█████████▏| 183/200 [18:22<01:39,  5.84s/it]

Epoch: 183/200 | Train Loss: 3.4178287933173124e-06 | Time: 6.1s (cuda) | LR: 3.689348814741916e-07


 92%|█████████▏| 184/200 [18:28<01:33,  5.83s/it]

Epoch: 184/200 | Train Loss: 3.40724704983586e-06 | Time: 5.8s (cuda) | LR: 3.689348814741916e-07


 92%|█████████▎| 185/200 [18:34<01:26,  5.76s/it]

Epoch: 185/200 | Train Loss: 3.394984560145531e-06 | Time: 5.59s (cuda) | LR: 3.689348814741916e-07


 93%|█████████▎| 186/200 [18:40<01:23,  5.95s/it]

Epoch: 186/200 | Train Loss: 3.3851774787763134e-06 | Time: 6.4s (cuda) | LR: 3.689348814741916e-07


 94%|█████████▎| 187/200 [18:45<01:14,  5.74s/it]

Epoch: 187/200 | Train Loss: 3.372001174284378e-06 | Time: 5.24s (cuda) | LR: 3.689348814741916e-07


 94%|█████████▍| 188/200 [18:52<01:11,  5.93s/it]

Epoch: 188/200 | Train Loss: 3.3582141441002022e-06 | Time: 6.39s (cuda) | LR: 3.689348814741916e-07


 94%|█████████▍| 189/200 [18:57<01:03,  5.76s/it]

Epoch: 189/200 | Train Loss: 3.345701088619535e-06 | Time: 5.35s (cuda) | LR: 3.689348814741916e-07


 95%|█████████▌| 190/200 [19:04<01:00,  6.00s/it]

Epoch: 190/200 | Train Loss: 3.334504071972333e-06 | Time: 6.57s (cuda) | LR: 3.689348814741916e-07


 96%|█████████▌| 191/200 [19:09<00:53,  5.91s/it]

Epoch: 191/200 | Train Loss: 3.317104528832715e-06 | Time: 5.68s (cuda) | LR: 3.689348814741916e-07


 96%|█████████▌| 192/200 [19:16<00:48,  6.07s/it]

Epoch: 192/200 | Train Loss: 3.307384304207517e-06 | Time: 6.46s (cuda) | LR: 3.689348814741916e-07


 96%|█████████▋| 193/200 [19:21<00:40,  5.85s/it]

Epoch: 193/200 | Train Loss: 3.29325303027872e-06 | Time: 5.32s (cuda) | LR: 3.689348814741916e-07


 97%|█████████▋| 194/200 [19:27<00:36,  6.01s/it]

Epoch: 194/200 | Train Loss: 3.279236580056022e-06 | Time: 6.39s (cuda) | LR: 3.689348814741916e-07


 98%|█████████▊| 195/200 [19:33<00:29,  5.81s/it]

Epoch: 195/200 | Train Loss: 3.2624909636069788e-06 | Time: 5.35s (cuda) | LR: 3.689348814741916e-07


 98%|█████████▊| 196/200 [19:39<00:23,  5.86s/it]

Epoch: 196/200 | Train Loss: 3.25239852827508e-06 | Time: 5.97s (cuda) | LR: 2.9514790517935326e-07


 98%|█████████▊| 197/200 [19:45<00:17,  5.94s/it]

Epoch: 197/200 | Train Loss: 3.0981609597802162e-06 | Time: 6.14s (cuda) | LR: 2.9514790517935326e-07


 99%|█████████▉| 198/200 [19:50<00:11,  5.76s/it]

Epoch: 198/200 | Train Loss: 3.0763708309677895e-06 | Time: 5.32s (cuda) | LR: 2.9514790517935326e-07


100%|█████████▉| 199/200 [19:57<00:05,  6.00s/it]

Epoch: 199/200 | Train Loss: 3.0682040232932195e-06 | Time: 6.56s (cuda) | LR: 2.9514790517935326e-07


100%|██████████| 200/200 [20:02<00:00,  6.01s/it]

Epoch: 200/200 | Train Loss: 3.056057266803691e-06 | Time: 5.36s (cuda) | LR: 2.9514790517935326e-07

Epoch with Least Loss: 200 | Loss: 3.0560573e-06 






In [22]:
# --- Reconstruct the std field from the saved network -------------------------
total_vars = 1  # the std network predicts a single variable

# NOTE(review): hard-coded to the epoch-200 checkpoint written by the training
# cell as f'{best_epoch}_std.pth'; update if the best epoch differs.
std_checkpoint = os.path.join(outpath, '200_std.pth')

# Load into model_std (not model_mean) so the mean network is left untouched.
model_std = MyResidualSirenNet(obj).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
state_dict = torch.load(std_checkpoint, map_location=device)['model_state_dict']
model_std.load_state_dict(state_dict)
model_std.eval()

# Evaluate all coordinates in chunks of `group_size` to bound device memory.
std_chunks = []
with torch.no_grad():
    for start_idx in range(0, torch_coords.shape[0], group_size):
        coords = torch_coords[start_idx:start_idx + group_size].to(device, dtype=torch.float32)
        vals = model_std(coords).cpu()
        std_chunks.append(vals[:, :total_vars])

# Normalized predictions, one column per variable: (n_points, total_vars)
n_predictions_stds = torch.cat(std_chunks).numpy().astype(np.float32)
print(n_predictions_stds.shape)

# Quality metrics against the normalized ground-truth std column
print("std", compute_PSNR(n_val[:, 1], n_predictions_stds[:, 0]))
rmse = compute_rmse(n_val[:, 1], n_predictions_stds[:, 0])
print("RMSE:", rmse)

(262144, 1)
std 61.882343845798715
RMSE: 0.001610322285992644


In [24]:
# Join the mean and std predictions side by side: column 0 = mean, column 1 = std.
n_predictions = np.concatenate((n_predictions_means, n_predictions_stds), axis=1)

In [None]:
# Report the size on disk of a saved checkpoint, in MiB.
# NOTE(review): absolute Kaggle path, and this file name encodes 320 neurons while
# the networks trained above use 150 — confirm this is the intended checkpoint.
model_file = '/kaggle/working/models/train_3d_data_200ep_6rb_320n_512bs_5e-05lr_Truedecay_0.8dr_decayingAtInterval15.pth'
size_mb = os.path.getsize(model_file) / (1024 ** 2)
print(size_mb, 'MB')

4.731752395629883 MB


In [28]:
# Destination directory and file name for the reconstructed .vti volume.
vti_path = args.vti_path
vti_name = args.vti_name
if not os.path.exists(vti_path):
    os.makedirs(vti_path)

# No masking; two variables are written (mean and std columns of n_predictions).
isMaskPresent = False
mask_arr = []
total_vars = 2
makeVTI(data, real_data, n_predictions, n_pts, total_vars, var_name, dim, isMaskPresent, mask_arr, vti_path, vti_name)

Vti File written successfully at ./data/predicted.vti
