In [1]:
!pip install vtk

Collecting vtk
  Downloading vtk-9.5.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.5 kB)
Downloading vtk-9.5.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (112.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m112.1/112.1 MB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: vtk
Successfully installed vtk-9.5.0


In [None]:
# -*- coding: utf-8 -*-
"""parameter_learning_main.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1O6m5CVsAt85AmsY18hWW8Eoj_JgiLqQ8
"""

import numpy as np
import torch
from torch import nn, optim
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
import seaborn as sns
import matplotlib.pyplot as plt
import vtk
from vtk import *
from vtk.util.numpy_support import vtk_to_numpy
import random
import os
import sys
import time


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Device running:', device)

class SineLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        # self.enable_dropout = enable_dropout
        # self.dropout_prob = dropout_prob
        self.in_features = in_features
        # if enable_dropout:
        #     if not self.is_first:
        #         self.dropout = nn.Dropout(dropout_prob)
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features,
                                             1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                             np.sqrt(6 / self.in_features) / self.omega_0)


    def forward(self, x):
        x = self.linear(x)
        # if self.enable_dropout:
        #     if not self.is_first:
        #         x = self.dropout(x)
        return torch.sin(self.omega_0 * x)

class ResidualSineLayer(nn.Module):
    def __init__(self, features, bias=True, ave_first=False, ave_second=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        # self.enable_dropout = enable_dropout
        # self.dropout_prob = dropout_prob
        self.features = features
        # if enable_dropout:
        #     self.dropout_1 = nn.Dropout(dropout_prob)
        self.linear_1 = nn.Linear(features, features, bias=bias)
        self.linear_2 = nn.Linear(features, features, bias=bias)
        self.weight_1 = .5 if ave_first else 1
        self.weight_2 = .5 if ave_second else 1

        self.init_weights()


    def init_weights(self):
        with torch.no_grad():
            self.linear_1.weight.uniform_(-np.sqrt(6 / self.features) / self.omega_0,
                                           np.sqrt(6 / self.features) / self.omega_0)
            self.linear_2.weight.uniform_(-np.sqrt(6 / self.features) / self.omega_0,
                                           np.sqrt(6 / self.features) / self.omega_0)

    def forward(self, input):
        linear_1 = self.linear_1(self.weight_1*input)
        # if self.enable_dropout:
        #     linear_1 = self.dropout_1(linear_1)
        sine_1 = torch.sin(self.omega_0 * linear_1)
        sine_2 = torch.sin(self.omega_0 * self.linear_2(sine_1))
        return self.weight_2*(input+sine_2)

class MyResidualSirenNet(nn.Module):
    def __init__(self, obj):
        super(MyResidualSirenNet, self).__init__()
        # self.enable_dropout = obj['enable_dropout']
        # self.dropout_prob = obj['dropout_prob']
        self.Omega_0=30
        self.n_layers = obj['n_layers']
        self.input_dim = obj['dim']
        self.output_dim = obj['total_vars']
        self.neurons_per_layer = obj['n_neurons']
        self.layers = [self.input_dim]
        for i in range(self.n_layers-1):
            self.layers.append(self.neurons_per_layer)
        self.layers.append(self.output_dim)
        self.net_layers = nn.ModuleList()
        for idx in np.arange(self.n_layers):
            layer_in = self.layers[idx]
            layer_out = self.layers[idx+1]
            ## if not the final layer
            if idx != self.n_layers-1:
                ## if first layer
                if idx==0:
                    self.net_layers.append(SineLayer(layer_in,layer_out,bias=True,is_first=idx==0))
                ## if an intermdeiate layer
                else:
                    self.net_layers.append(ResidualSineLayer(layer_in,bias=True,ave_first=idx>1,ave_second=idx==(self.n_layers-2)))
            ## if final layer
            else:
                final_linear = nn.Linear(layer_in,layer_out)
                ## initialize weights for the final layer
                with torch.no_grad():
                    final_linear.weight.uniform_(-np.sqrt(6 / (layer_in)) / self.Omega_0, np.sqrt(6 / (layer_in)) / self.Omega_0)
                self.net_layers.append(final_linear)

    def forward(self,x):
        for net_layer in self.net_layers:
            x = net_layer(x)
        return x

def size_of_network(n_layers, n_neurons, d_in, d_out, is_residual = True):
    # Adding input layer
    layers = [d_in]
    # layers = [3]

    # Adding hidden layers
    layers.extend([n_neurons]*n_layers)
    # layers = [3, 5, 5, 5]

    # Adding output layer
    layers.append(d_out)
    # layers = [3, 5, 5, 5, 1]

    # Number of steps
    n_layers = len(layers)-1
    # n_layers = 5 - 1 = 4

    n_params = 0

    # np.arange(4) = [0, 1, 2, 3]
    for ndx in np.arange(n_layers):

        # number of neurons in below layer
        layer_in = layers[ndx]

        # number of neurons in above layer
        layer_out = layers[ndx+1]

        # max number of neurons in both the layer
        og_layer_in = max(layer_in,layer_out)

        # if lower layer is the input layer
        # or the upper layer is the output layer
        if ndx==0 or ndx==(n_layers-1):
            # Adding weight corresponding to every neuron for every input neuron
            # Adding bias for every neuron in the upper layer
            n_params += ((layer_in+1)*layer_out)

        else:

            # If the layer is residual then proceed as follows as there will be more weights if residual layer is included
            if is_residual:
                # doubt in the following two lines
                n_params += (layer_in*og_layer_in)+og_layer_in
                n_params += (og_layer_in*layer_out)+layer_out

            # if the layer is non residual then simply add number of weights and biases as follows
            else:
                n_params += ((layer_in+1)*layer_out)
            #
        #
    #

    return n_params

def compute_PSNR(arrgt,arr_recon):
    diff = arrgt - arr_recon
    sqd_max_diff = (np.max(arrgt)-np.min(arrgt))**2
    snr = 10*np.log10(sqd_max_diff/np.mean(diff**2))
    return snr

def srs(numOfPoints, valid_pts, percentage, isMaskPresent, mask_array):

    # getting total number of sampled points
    numberOfSampledPoints = int((valid_pts/100) * percentage)

    # storing corner indices in indices variable
    indices = set()

    # As long as we don't get the required amount of sample points keep finding the random numbers
    while(len(indices) < numberOfSampledPoints):
        rp = random.randint(0, numOfPoints-1)
        if isMaskPresent and mask_array[rp] == 0:
            continue
        indices.add(rp)

    # return indices
    return indices

def findMultiVariatePSNR(var_name, total_vars, actual, pred):
    # print('Printing PSNR')
    tot = 0
    psnr_list = []
    for j in range(total_vars):
        psnr = compute_PSNR(actual[:,j], pred[:,j])
        psnr_list.append(psnr)
        tot += psnr
        print(var_name[j], ' PSNR:', psnr)
    avg_psnr = tot/total_vars
    print('\nAverage psnr : ', avg_psnr)
     #this function is calculating the psnr of final epoch (or whenever it is called) of each variable and then averaging it
     #Thus individual epochs psnr is not calculated

    return psnr_list, avg_psnr

def compute_rmse(actual, predicted):
    mse = np.mean((actual - predicted) ** 2)
    return np.sqrt(mse)

def denormalizeValue(total_vars, to, ref):
    to_arr = np.array(to)
    for i in range(total_vars):
        min_data = np.min(ref[:, i])
        max_data = np.max(ref[:, i])
        to_arr[:, i] = (((to[:, i] * 0.5) + 0.5) * (max_data - min_data)) + min_data
    return to_arr

def makeVTI(data, val, n_predictions, n_pts, total_vars, var_name, dim, isMaskPresent, mask_arr, vti_path, vti_name, normalizedVersion = False):
    nn_predictions = denormalizeValue(total_vars, n_predictions, val) if not normalizedVersion else n_predictions
    writer = vtkXMLImageDataWriter()
    writer.SetFileName(vti_path + vti_name)
    img = vtkImageData()
    img.CopyStructure(data)
    if not isMaskPresent:
        for i in range(total_vars):
            f = var_name[i]
            temp = nn_predictions[:, i]
            arr = vtkFloatArray()
            for j in range(n_pts):
                arr.InsertNextValue(temp[j])
            arr.SetName(f)
            img.GetPointData().AddArray(arr)
        # print(img)
        writer.SetInputData(img)
        writer.Write()
        print(f'Vti File written successfully at {vti_path}{vti_name}')
    else:
        for i in range(total_vars):
            f = var_name[i]
            temp = nn_predictions[:, i]
            idx = 0
            arr = vtkFloatArray()
            for j in range(n_pts):
                if(mask_arr[j] == 1):
                    arr.InsertNextValue(temp[idx])
                    idx += 1
                else:
                    arr.InsertNextValue(0.0)
            arr.SetName('p_' + f)
            data.GetPointData().AddArray(arr)
        # print(data)
        writer.SetInputData(data)
        writer.Write()
        print(f'Vti File written successfully at {vti_path}{vti_name}')

def getImageData(actual_img, val, n_pts, var_name, isMaskPresent, mask_arr):
    img = vtkImageData()
    img.CopyStructure(actual_img)
    # if isMaskPresent:
    #     img.DeepCopy(actual_img)
    # img.SetDimensions(dim)
    # img.SetOrigin(actual_img.GetOrigin())
    # img.SetSpacing(actual_img.GetSpacing())
    if not isMaskPresent:
        f = var_name
        data = val
        arr = vtkFloatArray()
        for j in range(n_pts):
            arr.InsertNextValue(data[j])
        arr.SetName(f)
        img.GetPointData().SetScalars(arr)
    else:
        f = var_name
        data = val
        idx = 0
        arr = vtkFloatArray()
        for j in range(n_pts):
            if(mask_arr[j] == 1):
                arr.InsertNextValue(data[idx])
                idx += 1
            else:
                arr.InsertNextValue(0.0)
        arr.SetName(f)
        img.GetPointData().SetScalars(arr)
    return img

from argparse import Namespace

# Parameters (simulating argparse in a Jupyter Notebook)
args = Namespace(
    n_neurons=150,
    n_layers=6,
    epochs=200,  # Required argument: Set the number of epochs
    batchsize=2048,
    lr=0.00005,
    no_decay=False,
    decay_rate=0.8,
    decay_at_interval=True,
    decay_interval=15,
    datapath='/content/new_gmm_isabel_week_6.vti',  # Required: Set the path to your data
    outpath='./models/',
    exp_path='../logs/',
    modified_data_path='./data/',
    dataset_name='ema',  # Required: Set the dataset name
    vti_name='new_gmm_isabel_ema_week_6_predict.vti',  # Required: Name of the dataset
    vti_path='/data/'
)

print(args, end='\n\n')

# Assigning parameters to variables
LR = args.lr
BATCH_SIZE = args.batchsize
decay_rate = args.decay_rate
decay_at_equal_interval = args.decay_at_interval

decay = not args.no_decay
MAX_EPOCH = args.epochs

n_neurons = args.n_neurons
n_layers = args.n_layers + 2
decay_interval = args.decay_interval
outpath = args.outpath
exp_path = args.exp_path
datapath = args.datapath
modified_data_path = args.modified_data_path
dataset_name = args.dataset_name
vti_name = args.vti_name
vti_path = args.vti_path

# Displaying the final configuration
print(f"Learning Rate: {LR}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Decay Rate: {decay_rate}")
print(f"Max Epochs: {MAX_EPOCH}")
print(f"Number of Neurons per Layer: {n_neurons}")
print(f"Number of Layers (including input/output): {n_layers}")
print(f"Data Path: {datapath}")
print(f"Output Path: {outpath}")
print(f"Dataset Name: {dataset_name}")
print(f"Vti Name: {vti_name}")

# Variable Initialization
var_name = []
total_vars = None  # Number of variables
univariate = None  # True if dataset has one variable, else False
group_size = 5000  # Group size during testing


# Constructing the log file name
log_file = (
    f'train_{dataset_name}_{n_layers-2}rb_{n_neurons}n_{BATCH_SIZE}bs_'
    f'{LR}lr_{decay}decay_{decay_rate}dr_'
    f'{"decayingAtInterval" + str(decay_interval) if decay_at_equal_interval else "decayingWhenLossIncr"}'
)

print(log_file)

n_pts = None  # Number of points in the dataset
n_dim = None  # Dimensionality of the data
dim = None  # Other dimension-specific information

print("Decay:", decay)
print(f'Extracting variables from path: {datapath}', end="\n\n")

# Placeholder for data
data_array = []
scalar_data = None

# # Reading values from .vti files
# reader = vtk.vtkXMLImageDataReader()
# reader.SetFileName(datapath)
# reader.Update()

# data = reader.GetOutput()
# scalar_data = data
# pdata = data.GetPointData()
# n_pts = data.GetNumberOfPoints()
# dim = data.GetDimensions()
# n_dim = len(dim)
# total_arr = pdata.GetNumberOfArrays()

# print("n_pts:", n_pts, "dim:", dim, "n_dim:", n_dim, "total_arr:", total_arr)

# mask_arr = []
# valid_pts = 0
# var_name = []
# data_array = []

# # Extracting data from the .vti file
# for i in range(total_arr):
#     a_name = pdata.GetArrayName(i)
#     if a_name in ['vtkValidPointMask', 'Swirl']:
#         continue

#     cur_arr = pdata.GetArray(a_name)
#     n_components = cur_arr.GetNumberOfComponents()

#     if n_components == 1:
#         var_name.append(a_name)
#         data_array.append(vtk_to_numpy(cur_arr))
#     else:
#         component_names = [f"{a_name}_{c}" for c in ['x', 'y', 'z'][:n_components]]
#         var_name.extend(component_names)
#         for c in range(n_components):
#             c_data = [cur_arr.GetComponent(j, c) for j in range(n_pts)]
#             data_array.append(np.array(c_data))

# valid_pts = n_pts  # Assume all points are valid for simplicity
# total_vars = len(var_name)
# univariate = total_vars == 1

# # Prepare numpy arrays for coordinates and variable values
# cord = np.zeros((valid_pts, n_dim))
# val = np.zeros((valid_pts, total_vars))

# # Store data in numpy arrays
# for i in range(n_pts):
#     pt = scalar_data.GetPoint(i)
#     cord[i, :] = pt
#     val[i, :] = [arr[i] for arr in data_array]

# # Display final information
# print("Total Variables:", total_vars)
# print("Univariate:", univariate)
# print("Coordinates Shape:", cord.shape)
# print("Values Shape:", val.shape)

# Reading values from .vti files
reader = vtk.vtkXMLImageDataReader()
reader.SetFileName(datapath)
reader.Update()

data = reader.GetOutput()
scalar_data = data
pdata = data.GetPointData()
n_pts = data.GetNumberOfPoints()
dim = data.GetDimensions()
n_dim = len(dim)
total_arr = pdata.GetNumberOfArrays()

print("n_pts:", n_pts, "dim:", dim, "n_dim:", n_dim, "total_arr:", total_arr)

var_name = []
data_array = []

# Extracting data from the .vti file
for i in range(total_arr):
    a_name = pdata.GetArrayName(i)

    cur_arr = pdata.GetArray(a_name)
    n_components = cur_arr.GetNumberOfComponents()

    if n_components == 1:
        var_name.append(a_name)
        data_array.append(vtk_to_numpy(cur_arr))
    else:
        component_names = [f"{a_name}_{c}" for c in ['x', 'y', 'z'][:n_components]]
        var_name.extend(component_names)
        for c in range(n_components):
            c_data = [cur_arr.GetComponent(j, c) for j in range(n_pts)]
            data_array.append(np.array(c_data))

total_vars = len(var_name)
univariate = total_vars == 1

# Prepare numpy arrays for coordinates and variable values
cord = np.zeros((n_pts, n_dim))
val = np.zeros((n_pts, total_vars))

# Store data in numpy arrays
for i in range(n_pts):
    pt = scalar_data.GetPoint(i)
    cord[i, :] = pt
    val[i, :] = [arr[i] for arr in data_array]

# Display final information
print("Total Variables:", total_vars)
print("Univariate:", univariate)
print("Coordinates Shape:", cord.shape)
print("Values Shape:", val.shape)

# # Ensure modified data path exists
# if not os.path.exists(modified_data_path):
#     os.mkdir(modified_data_path)

# Save raw coordinates and values
# np.save(f'{modified_data_path}cord.npy', cord)
# np.save(f'{modified_data_path}val.npy', val)

# # Create copies of non-normalized data
# nn_cord = cord.copy()
# nn_val = val.copy()

# === Separate Normalization for Values ===
# We assume the variable order is:
#   - Means: indices 0,1,2
#   - Std Devs: indices 3,4,5
#   - Weights: indices 6,7,8

# # We'll store normalization parameters so that we can invert normalization later.
# norm_params = {}
# epsilon = 1e-8  # to avoid log(0)

# # Normalize Means to [-1,1] using min–max normalization
# for i in range(3):
#     min_val = np.min(val[:, i])
#     max_val = np.max(val[:, i])
#     norm_params[var_name[i]] = (min_val, max_val)
#     val[:, i] = 2.0 * ((val[:, i] - min_val) / (max_val - min_val) - 0.5)

# # Normalize Std Devs: first take log, then min–max to [-1,1]
# for i in range(3, 6):
#     log_vals = np.log(val[:, i] + epsilon)
#     min_val = np.min(log_vals)
#     max_val = np.max(log_vals)
#     norm_params[var_name[i]] = (min_val, max_val)
#     val[:, i] = 2.0 * ((log_vals - min_val) / (max_val - min_val) - 0.5)

# # Normalize Weights: take log, then min–max to [-1,1]
# for i in range(6, 9):
#     log_vals = np.log(val[:, i] + epsilon)
#     min_val = np.min(log_vals)
#     max_val = np.max(log_vals)
#     norm_params[var_name[i]] = (min_val, max_val)
#     val[:, i] = 2.0 * ((log_vals - min_val) / (max_val - min_val) - 0.5)

# norm_params = {}
real_data=val.copy()
# Normalize values between -1 and 1
for i in range(total_vars):
    min_data = np.min(val[:, i])
    max_data = np.max(val[:, i])
    # norm_params[var_name[i]] = (min_data, max_data)
    val[:, i] = 2.0 * ((val[:, i] - min_data) / (max_data - min_data) - 0.5)

# Normalize Coordinates to [-1,1]
for i in range(n_dim):
    # Use (dim[i]-1] so that coordinates go from 0 to dim[i]-1.
    cord[:, i] = 2.0 * (cord[:, i] / (dim[i] - 1) - 0.5)

# # Normalize coordinates between 0 and 1
# for i in range(n_dim):
#     cord[:, i] = cord[:, i] / dim[i]


# # Save normalized values and coordinates
# np.save(f'{modified_data_path}n_cord.npy', cord)
# np.save(f'{modified_data_path}n_val.npy', val)
n_cord = cord.copy()
n_val = val.copy()

# # Reload data for verification
# n_cord = np.load(f'{modified_data_path}n_cord.npy')
# n_val = np.load(f'{modified_data_path}n_val.npy')
# cord = np.load(f'{modified_data_path}cord.npy')
# val = np.load(f'{modified_data_path}val.npy')

# Convert normalized data to PyTorch tensors
torch_coords = torch.from_numpy(n_cord)
torch_vals = torch.from_numpy(n_val)

# Display dataset details
print('Dataset Name:', dataset_name)
print('Total Variables:', total_vars)
print('Variables Name:', var_name, end="\n\n")
print('Total Points in Data:', n_pts)
print('Dimension of the Dataset:', dim)
print('Number of Dimensions:', n_dim)
print('Coordinate Tensor Shape:', torch_coords.shape)
print('Scalar Values Tensor Shape:', torch_vals.shape)

print('\n###### Data setup is complete, now starting training ######\n')

# Prepare the DataLoader
train_dataloader = DataLoader(
    TensorDataset(torch_coords, torch_vals),
    batch_size=BATCH_SIZE,
    pin_memory=True,
    shuffle=True,
    num_workers=4
)

# Model configuration
obj = {
    'total_vars': total_vars,
    'dim': n_dim,
    'n_neurons': n_neurons,
    'n_layers': n_layers
}

# Initialize the model, optimizer, and loss function
model = MyResidualSirenNet(obj).to(device)
print(model)

optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999))
print(optimizer)

criterion = nn.MSELoss()
print(criterion)

# Training configuration summary
print('\nLearning Rate:', LR)
print('Max Epochs:', MAX_EPOCH)
print('Batch Size:', BATCH_SIZE)
print('Number of Hidden Layers:', obj['n_layers'] - 2)
print('Number of Neurons per Layer:', obj['n_neurons'])

if decay:
    print('Decay Rate:', decay_rate)
    if decay_at_equal_interval:
        print(f'Rate decays every {decay_interval} epochs.')
    else:
        print('Rate decays when the current epoch loss is greater than the previous epoch loss.')
else:
    print('No decay!')
print()

train_loss_list = []
best_epoch = -1
best_loss = 1e8
from tqdm import tqdm
# Ensure the output path exists
if not os.path.exists(outpath):
    os.makedirs(outpath)

# Training loop

# --- Setup for Loss EMA and Helpers ---
ema_decay = 0.99
grad_ema = {
    "mean": 1.0,
    "std": 1.0,
    "weight": 1.0
}

def compute_grad_norm(loss):
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().data.norm(2).item() ** 2
    return (total ** 0.5)

# --- Main Training Loop ---
train_loss_list = []
grad_norm_list = []

best_loss = float('inf')
best_epoch = -1

for epoch in tqdm(range(MAX_EPOCH)):
    model.train()
    temp_loss_list = []
    start = time.time()

    for X_train, y_train in train_dataloader:
        X_train = X_train.float().to(device)
        y_train = y_train.float().to(device)

        optimizer.zero_grad()

        predictions = model(X_train)  # shape: [batch, 9]
        predictions = predictions.squeeze()

        # Split losses
        # pred_mean, pred_std, pred_weight = predictions[:, :3], predictions[:, 3:6], predictions[:, 6:]
        # gt_mean, gt_std, gt_weight = y_train[:, :3], y_train[:, 3:6], y_train[:, 6:]
        pred_mean=predictions[:,0]
        pred_std=predictions[:,1]
        gt_mean=y_train[:,0]
        gt_std=y_train[:,1]
        # Compute MSE losses for each group
        loss_mean = F.mse_loss(pred_mean, gt_mean)
        loss_std = F.mse_loss(pred_std, gt_std)
        #loss_weight = F.mse_loss(pred_weight, gt_weight)

        # --- Gradient Norms for each ---
        g_mean = compute_grad_norm(loss_mean)
        g_std = compute_grad_norm(loss_std)
        #g_weight = compute_grad_norm(loss_weight)

        # --- EMA Update ---
        grad_ema["mean"] = ema_decay * grad_ema["mean"] + (1 - ema_decay) * g_mean
        grad_ema["std"] = ema_decay * grad_ema["std"] + (1 - ema_decay) * g_std
        #grad_ema["weight"] = ema_decay * grad_ema["weight"] + (1 - ema_decay) * g_weight

        #total_ema = grad_ema["mean"] + grad_ema["std"] + grad_ema["weight"]
        total_ema=grad_ema["mean"] + grad_ema["std"]
        # --- Adaptive Weights ---
        w_mean = total_ema / (grad_ema["mean"])
        w_std = total_ema / (grad_ema["std"])
        #w_weight = total_ema / (3.0 * grad_ema["weight"])

        # --- Final Combined Loss ---
        #final_loss = w_mean * loss_mean + w_std * loss_std + w_weight * loss_weight
        final_loss = w_mean * loss_mean + w_std * loss_std
        final_loss.backward()

        # Gradient norm for logging
        total_grad_norm = 0.0
        for param in model.parameters():
            if param.grad is not None:
                total_grad_norm += param.grad.data.norm(2).item() ** 2
        total_grad_norm = total_grad_norm ** 0.5

        optimizer.step()
        temp_loss_list.append(final_loss.detach().cpu().item())

    # Epoch-wise processing
    epoch_loss = np.mean(temp_loss_list)

    # Learning rate decay
    if decay:
        if decay_at_equal_interval:
            if epoch >= decay_interval and epoch % decay_interval == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= decay_rate
        elif epoch > 0 and epoch_loss > train_loss_list[-1]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= decay_rate

    train_loss_list.append(epoch_loss)
    grad_norm_list.append(total_grad_norm)

    if epoch_loss < best_loss:
        best_loss = epoch_loss
        best_epoch = epoch + 1

    end = time.time()
    print(
        f"Epoch: {epoch + 1}/{MAX_EPOCH} | Train Loss: {epoch_loss:.6f} | "
        f"Grad Norm: {round(total_grad_norm, 4)} | Time: {round(end - start, 2)}s ({device}) | "
        f"LR: {optimizer.param_groups[0]['lr']:.2e}"
    )

    # Save model checkpoint
    if (epoch + 1) % 50 == 0:
        model_name = (
            f'train_{dataset_name}_{epoch + 1}ep_{n_layers - 2}rb_{n_neurons}n_'
            f'{BATCH_SIZE}bs_{LR}lr_{decay}decay_{decay_rate}dr_'
            f'{"interval" if decay_at_equal_interval else "lossIncr"}'
        )
        torch.save(
            {"epoch": epoch + 1, "model_state_dict": model.state_dict()},
            os.path.join(outpath, f'{model_name}.pth')
        )

# Final model save
model_name = f'siren_compressor_ema'
torch.save(
    {"epoch": MAX_EPOCH, "model_state_dict": model.state_dict()},
    os.path.join(outpath, f'{model_name}.pth')
)

# --- Plot Loss and Gradient Norm ---
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(train_loss_list, label='Training Loss', color='blue')
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('Training Loss over Epochs')
plt.grid(True); plt.legend()

plt.subplot(1, 2, 2)
plt.plot(grad_norm_list, label='Gradient Norm (L2)', color='orange')
plt.xlabel('Epoch'); plt.ylabel('Gradient Norm'); plt.title('Gradient Norm over Epochs')
plt.grid(True); plt.legend()

plt.tight_layout()
plt.show()

# Initialize prediction lists (optional continuation)
prediction_list = [[] for _ in range(total_vars)]


# Inference loop
with torch.no_grad():
    for i in range(0, torch_coords.shape[0], group_size):
        coords = torch_coords[i:min(i + group_size, torch_coords.shape[0])].type(torch.float32).to(device)
        vals = model(coords)
        vals = vals.to('cpu')

        for j in range(total_vars):
            prediction_list[j].append(vals[:, j])

# Extract and concatenate predictions
extracted_list = [[] for _ in range(total_vars)]
for i in range(len(prediction_list[0])):
    for j in range(total_vars):
        el = prediction_list[j][i].detach().numpy()
        extracted_list[j].append(el)

for j in range(total_vars):
    extracted_list[j] = np.concatenate(extracted_list[j], dtype='float32')

# Final prediction (normalized)
n_predictions = np.array(extracted_list).T

# Compute PSNR
findMultiVariatePSNR(var_name, total_vars, n_val, n_predictions)

# Compute RMSE
print("file name:",datapath)
rmse = compute_rmse(n_val, n_predictions)
print("RMSE:", rmse)

# Initialize prediction lists
prediction_list = [[] for _ in range(total_vars)]

# Inference loop
with torch.no_grad():
    for i in range(0, torch_coords.shape[0], group_size):
        coords = torch_coords[i:min(i + group_size, torch_coords.shape[0])].type(torch.float32).to(device)
        vals = model(coords)
        vals = vals.to('cpu')

        for j in range(total_vars):
            prediction_list[j].append(vals[:, j])

# Extract and concatenate predictions
extracted_list = [[] for _ in range(total_vars)]
for i in range(len(prediction_list[0])):
    for j in range(total_vars):
        el = prediction_list[j][i].detach().numpy()
        extracted_list[j].append(el)

for j in range(total_vars):
    extracted_list[j] = np.concatenate(extracted_list[j], dtype='float32')

# Final prediction (normalized)
n_predictions = np.array(extracted_list).T

# Compute PSNR
findMultiVariatePSNR(var_name, total_vars, n_val, n_predictions)

# Compute RMSE
rmse = compute_rmse(n_val, n_predictions)
print("RMSE:", rmse)

# !rm -rf /kaggle/working/*

#print(os.path.getsize('/kaggle/working/models/train_3d_data_200ep_6rb_320n_512bs_5e-05lr_Truedecay_0.8dr_decayingAtInterval15.pth') / (1024 ** 2), 'MB')

# # vti saving path
vti_path = args.vti_path
if not os.path.exists(vti_path):
       os.makedirs(vti_path)
## vti name
vti_name = args.vti_name
mask_arr=[]
isMaskPresent=False
makeVTI(data,real_data, n_predictions, n_pts, total_vars, var_name, dim, isMaskPresent, mask_arr, vti_path, vti_name)



Device running: cuda
Namespace(n_neurons=150, n_layers=6, epochs=200, batchsize=2048, lr=5e-05, no_decay=False, decay_rate=0.8, decay_at_interval=True, decay_interval=15, datapath='/content/new_gmm_isabel_week_6.vti', outpath='./models/', exp_path='../logs/', modified_data_path='./data/', dataset_name='ema', vti_name='new_gmm_isabel_ema_week_6_predict.vti', vti_path='/data/')

Learning Rate: 5e-05
Batch Size: 2048
Decay Rate: 0.8
Max Epochs: 200
Number of Neurons per Layer: 150
Number of Layers (including input/output): 8
Data Path: /content/new_gmm_isabel_week_6.vti
Output Path: ./models/
Dataset Name: ema
Vti Name: new_gmm_isabel_ema_week_6_predict.vti
train_ema_6rb_150n_2048bs_5e-05lr_Truedecay_0.8dr_decayingAtInterval15
Decay: True
Extracting variables from path: /content/new_gmm_isabel_week_6.vti

n_pts: 3125000 dim: (250, 250, 50) n_dim: 3 total_arr: 9
Total Variables: 9
Univariate: False
Coordinates Shape: (3125000, 3)
Values Shape: (3125000, 9)
Dataset Name: ema
Total Variables



Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 5e-05
    maximize: False
    weight_decay: 0
)
MSELoss()

Learning Rate: 5e-05
Max Epochs: 200
Batch Size: 2048
Number of Hidden Layers: 6
Number of Neurons per Layer: 150
Decay Rate: 0.8
Rate decays every 15 epochs.



  0%|          | 1/200 [00:57<3:09:51, 57.24s/it]

Epoch: 1/200 | Train Loss: 0.036182 | Grad Norm: 0.3411 | Time: 57.24s (cuda) | LR: 5.00e-05
