In [1]:
import onnxruntime
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts  
import numpy as np 
import onnx 
import netron 
import tqdm

# Load module 
import yaml 
import os 



### Load configs and the pre-trained model

In [2]:
device      = "cuda"
batch_size  = 64
config_path = "configs/sunset_configs.yaml"


with open(config_path, 'r') as file:
    configs = yaml.safe_load(file)

Pretrained models

In [3]:
onnx_model = onnx.load(os.path.join(configs["model_dir"], "sunset_model.onnx")) 

# Check that the IR is well formed
onnx.checker.check_model(onnx_model)

print(onnx.helper.printable_graph(onnx_model.graph))

graph main_graph (
  %input_image[FLOAT, batch_sizex48x64x64]
  %input_scalar[FLOAT, batch_sizex16]
) initializers (
  %conv1.weight[FLOAT, 48x48x3x3]
  %conv1.bias[FLOAT, 48]
  %batchnorm1.weight[FLOAT, 48]
  %batchnorm1.bias[FLOAT, 48]
  %batchnorm1.running_mean[FLOAT, 48]
  %batchnorm1.running_var[FLOAT, 48]
  %conv2.weight[FLOAT, 96x48x3x3]
  %conv2.bias[FLOAT, 96]
  %batchnorm2.weight[FLOAT, 96]
  %batchnorm2.bias[FLOAT, 96]
  %batchnorm2.running_mean[FLOAT, 96]
  %batchnorm2.running_var[FLOAT, 96]
  %concat.weight[FLOAT, 1024x24592]
  %concat.bias[FLOAT, 1024]
  %dense1.weight[FLOAT, 1024x1024]
  %dense1.bias[FLOAT, 1024]
  %dense2.weight[FLOAT, 15x1024]
  %dense2.bias[FLOAT, 15]
  %dense3.weight[FLOAT, 15x15]
  %dense3.bias[FLOAT, 15]
) {
  %/conv1/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%input_image, %conv1.weight, %conv1.bias)
  %/relu/Relu_output_0 = Relu(%/conv1/Conv_output_0)
  %/batchnorm1/BatchNorma

Params for fine-tuning

In [4]:
with open(os.path.join(configs["model_dir"], 'param_names.txt'), 'r') as f:
    list_text      = f.readlines()   # read into a list
    all_parameters = [list_text[i].rstrip('\n') for i in range(len(list_text))]
all_parameters

['conv1.weight',
 'conv1.bias',
 'batchnorm1.weight',
 'batchnorm1.bias',
 'conv2.weight',
 'conv2.bias',
 'batchnorm2.weight',
 'batchnorm2.bias',
 'concat.weight',
 'concat.bias',
 'dense1.weight',
 'dense1.bias',
 'dense2.weight',
 'dense2.bias',
 'dense3.weight',
 'dense3.bias']

In [5]:
frozen_params = [all_parameters[i] for i in range(10)]
requires_grad = [all_parameters[i] for i in range(10, len(all_parameters))]

In [6]:
frozen_params

['conv1.weight',
 'conv1.bias',
 'batchnorm1.weight',
 'batchnorm1.bias',
 'conv2.weight',
 'conv2.bias',
 'batchnorm2.weight',
 'batchnorm2.bias',
 'concat.weight',
 'concat.bias']

In [7]:
requires_grad

['dense1.weight',
 'dense1.bias',
 'dense2.weight',
 'dense2.bias',
 'dense3.weight',
 'dense3.bias']

#### Offline artifacts generation for training

In [8]:
os.makedirs(configs["artifacts_dir"], exist_ok = True)

artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.MSELoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory=configs["artifacts_dir"],
    additional_output_names=["output"])


checkpoint_path    = os.path.join(configs["artifacts_dir"], "checkpoint")
train_model_path   = os.path.join(configs["artifacts_dir"], "training_model.onnx")
eval_model_path    = os.path.join(configs["artifacts_dir"], "eval_model.onnx")
optimizer_path     = os.path.join(configs["artifacts_dir"], "optimizer_model.onnx")


# load checkpoints.
state     = CheckpointState.load_checkpoint(checkpoint_path)

# Create module.
model     = Module(train_model_path, state, eval_model_path )

# Create optimizer.
optimizer = Optimizer(optimizer_path, model)
optimizer.set_learning_rate(configs["learning_rate"])

print(optimizer.get_learning_rate())

2025-04-07 15:15:40.006756182 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer ConstantSharing modified: 0 with status: OK
2025-04-07 15:15:40.006774436 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer LayerNormFusion modified: 0 with status: OK
2025-04-07 15:15:40.006791874 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer CommonSubexpressionElimination modified: 0 with status: OK
2025-04-07 15:15:40.006796084 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer GeluFusion modified: 0 with status: OK
2025-04-07 15:15:40.006799979 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer SimplifiedLayerNormFusion modified: 0 with status: OK
2025-04-07 15:15:40.006805045 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer FastGeluFusion modified: 0 with status: OK
2025-04-07 15:15:40.006808778 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer Qui

9.999999747378752e-06


In [9]:
# visualize the pre-trained model
netron.start(os.path.join(configs["artifacts_dir"], "eval_model.onnx"))

Serving 'sunset_onnex/artifacts/eval_model.onnx' at http://localhost:8081


('localhost', 8081)

## Fine-tuning the model

The following script generates the list of index for training/validation/testing.  

In [10]:
''' 
The following script generates the list of index for training/validation/testing. 

However, we have already generated them for you in 'sirta_data/2023',
and set the number of samples to 13360, which is the total samples for month 05 in year 2023.

To generate the new lists, you need to check the number of samples, and adjust the parameter accordingly.  
Then, just uncomment the following 'Python' code, so that you can generate the file. 
''' 

'''  
import random

data_path   = "sirta_data/2023"
num_samples = 13360
seq = np.arange(num_samples).tolist()

training_samples = int(np.floor(0.90*num_samples))
training_index_list = random.sample(seq, training_samples)

seq_rm_training     = list(set(seq) - set(training_index_list))  
validating_samples  = int(np.floor(0.05*num_samples))
validate_index_list = random.sample(seq_rm_training, validating_samples)

testing_index_list = list(set(seq_rm_training) - set(validate_index_list))    

with open(os.path.join(data_path, "training.txt"), 'w') as file:
    for index, item in enumerate(training_index_list):
        if index == len(training_index_list)-1:
            file.write(str(item) )
        else:
            file.write(str(item) + '\n')

with open(os.path.join(data_path, "validate.txt"), 'w') as file: 
    for index, item in enumerate(validate_index_list):
        if index == len(validate_index_list)-1:
            file.write(str(item) )
        else:
            file.write(str(item) + '\n')

with open(os.path.join(data_path, "testing.txt"), 'w') as file:
    for index, item in enumerate(testing_index_list): 
        if index == len(testing_index_list)-1:
            file.write(str(item) )
        else:
            file.write(str(item) + '\n') 
''' 

'  \nimport random\n\ndata_path   = "sirta_data/2023"\nnum_samples = 13360\nseq = np.arange(num_samples).tolist()\n\ntraining_samples = int(np.floor(0.90*num_samples))\ntraining_index_list = random.sample(seq, training_samples)\n\nseq_rm_training     = list(set(seq) - set(training_index_list))  \nvalidating_samples  = int(np.floor(0.05*num_samples))\nvalidate_index_list = random.sample(seq_rm_training, validating_samples)\n\ntesting_index_list = list(set(seq_rm_training) - set(validate_index_list))    \n\nwith open(os.path.join(data_path, "training.txt"), \'w\') as file:\n    for index, item in enumerate(training_index_list):\n        if index == len(training_index_list)-1:\n            file.write(str(item) )\n        else:\n            file.write(str(item) + \'\n\')\n\nwith open(os.path.join(data_path, "validate.txt"), \'w\') as file: \n    for index, item in enumerate(validate_index_list):\n        if index == len(validate_index_list)-1:\n            file.write(str(item) )\n        e

The following script sets up the data loader for fine-tunning the model

In [11]:
from dataloader import sirta_dataset
from torch.utils.data import Dataset, DataLoader

batch_size = 64

training_dataset = sirta_dataset( mode = "Train",
                                irrad_path  = "sirta_data/2023",
                                image_path  = "sirta_data/2023/images",
                                seq_length  = 16,
                                pred_length = 15,
                                image_size  = 64,
                                batch_size  = batch_size,
                                training_index_file = "sirta_data/2023/training.txt",
                                validate_index_file = "sirta_data/2023/validate.txt",
                                testing_index_file  = "sirta_data/2023/testing.txt") 

valid_dataset = sirta_dataset( mode = "Valid",
                                irrad_path  = "sirta_data/2023",
                                image_path  = "sirta_data/2023/images",
                                seq_length  = 16,
                                pred_length = 15,
                                image_size  = 64,
                                batch_size  = batch_size,
                                training_index_file = "sirta_data/2023/training.txt",
                                validate_index_file = "sirta_data/2023/validate.txt",
                                testing_index_file  = "sirta_data/2023/testing.txt") 
    
train_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, drop_last=True)  
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, drop_last=True) 

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  selected_csv_data["Datetime"] = pd.to_datetime(selected_csv_data["Datetime"])
image transformation: 100%|██████████| 14290/14290 [03:22<00:00, 70.52it/s]


Total number of stacked samples: 13360
For [Train] mode: the number of stacked samples: 12024
               the number of batches: 187


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  selected_csv_data["Datetime"] = pd.to_datetime(selected_csv_data["Datetime"])
image transformation: 100%|██████████| 14290/14290 [01:33<00:00, 152.47it/s]


Total number of stacked samples: 13360
For [Valid] mode: the number of stacked samples: 668
               the number of batches: 10


Training Loop

In [12]:
from myonnxutils.onnx_utils import EarlyStopping, AdjustLR, mean_square_error

early_stopper = EarlyStopping(patience=configs["early_stopper_patience"])
adjustlr      = AdjustLR(patience=1)

for epoch in range(50):
    
    model.train()
    losses = []
    first_rmse_list = []
    last_rmse_list = []
    batch_rmse_list = []
    pbar = tqdm.tqdm(train_loader)
    for _, data_batch in enumerate(pbar): 
        input_index       = data_batch[0]
        input_iclr        = data_batch[1]
        input_skyimage    = data_batch[2]
        output_irr        = data_batch[3] 
    
        forward_inputs = [input_skyimage.float().numpy(), input_iclr.float().numpy(), output_irr.float().numpy()]
        train_loss, pred_irradiance = model(*forward_inputs)
        optimizer.step()
        model.lazy_reset_grad()

        losses.append(train_loss)
        first_rmse , last_rmse, batch_rmse = mean_square_error(pred_irradiance, output_irr.float().numpy())
        first_rmse_list.append(first_rmse)
        last_rmse_list.append(last_rmse)
        batch_rmse_list.append(batch_rmse)
        pbar.set_description("Loss: %.4f" % (sum(losses)/len(losses)))

    train_loss = sum(losses)/len(losses)
    print(f'Epoch: {epoch+1}, Train Loss: {sum(losses)/len(losses):.4f}, First RMSE: {sum(first_rmse_list)/len(first_rmse_list):.4f}, Last RMSE {sum(last_rmse_list)/len(last_rmse_list):.4f}, Batch RMSE {sum(batch_rmse_list)/len(batch_rmse_list):.4f}')
 
 
    model.eval()
    losses = []
    first_rmse_list = []
    last_rmse_list = [] 
    batch_rmse_list = [] 

    vpbar = tqdm.tqdm(valid_loader)
    for _, data_batch in enumerate(vpbar): 
        input_index       = data_batch[0]
        input_iclr        = data_batch[1]
        input_skyimage    = data_batch[2]
        output_irr        = data_batch[3] 

    
        forward_inputs = [input_skyimage.float().numpy(), input_iclr.float().numpy(), output_irr.float().numpy()]
        test_loss, pred_irradiance = model(*forward_inputs)  
        first_rmse , last_rmse, batch_rmse = mean_square_error(pred_irradiance, output_irr.float().numpy())
        first_rmse_list.append(first_rmse)
        last_rmse_list.append(last_rmse)
        batch_rmse_list.append(batch_rmse)
        losses.append(test_loss)
        vpbar.set_description("Loss: %.4f" % (sum(losses)/len(losses)))

    valid_loss = sum(losses)/len(losses)

   # metrics = metric.compute()
    print(f'========== Valid Loss: {valid_loss:.4f}, First RMSE: {sum(first_rmse_list)/len(first_rmse_list):.4f}, Last RMSE {sum(last_rmse_list)/len(last_rmse_list):.4f}, Batch RMSE {sum(batch_rmse_list)/len(batch_rmse_list):.4f}')

    saved_params = {
        "model": model 
    }
    early_stopper(sum(first_rmse_list)/len(first_rmse_list), saved_params, configs)
    adjustlr(valid_loss)

    if adjustlr.do_adjust: 
        current_lr = optimizer.get_learning_rate()
        optimizer.set_learning_rate(0.5*current_lr) 
        print(f'------- Adjust LR: {current_lr:.4f} ==> {0.5*current_lr}')
        

    if early_stopper.early_stop:
        break
 

Loss: 6185.8691: 100%|██████████| 187/187 [00:17<00:00, 10.89it/s]


Epoch: 1, Train Loss: 6185.8691, First RMSE: 70.3813, Last RMSE 87.3075, Batch RMSE 76.4005


Loss: 5608.9600: 100%|██████████| 10/10 [00:00<00:00, 10.31it/s]




Loss: 6174.2090: 100%|██████████| 187/187 [00:17<00:00, 10.96it/s]


Epoch: 2, Train Loss: 6174.2090, First RMSE: 70.3453, Last RMSE 87.6623, Batch RMSE 76.3585


Loss: 5592.9224: 100%|██████████| 10/10 [00:00<00:00, 12.19it/s]




Loss: 6158.6470: 100%|██████████| 187/187 [00:17<00:00, 10.93it/s]


Epoch: 3, Train Loss: 6158.6470, First RMSE: 70.2371, Last RMSE 87.6622, Batch RMSE 76.4045


Loss: 5602.6025: 100%|██████████| 10/10 [00:00<00:00, 11.92it/s]


EarlyStopping counter: 1 out of 3
------- Adjust LR: 0.0000 ==> 4.999999873689376e-06


Loss: 6161.6851: 100%|██████████| 187/187 [00:16<00:00, 11.15it/s]


Epoch: 4, Train Loss: 6161.6851, First RMSE: 69.8727, Last RMSE 87.6479, Batch RMSE 76.3507


Loss: 5595.8203: 100%|██████████| 10/10 [00:00<00:00, 11.74it/s]


------- Adjust LR: 0.0000 ==> 2.499999936844688e-06


Loss: 6147.9404: 100%|██████████| 187/187 [00:17<00:00, 10.90it/s]


Epoch: 5, Train Loss: 6147.9404, First RMSE: 70.4702, Last RMSE 87.7631, Batch RMSE 76.3800


Loss: 5594.2822: 100%|██████████| 10/10 [00:00<00:00, 12.04it/s]


------- Adjust LR: 0.0000 ==> 1.249999968422344e-06


Loss: 6134.5864: 100%|██████████| 187/187 [00:17<00:00, 10.98it/s]


Epoch: 6, Train Loss: 6134.5864, First RMSE: 69.6849, Last RMSE 87.6275, Batch RMSE 76.1080


Loss: 5589.0347: 100%|██████████| 10/10 [00:00<00:00, 11.73it/s]


EarlyStopping counter: 1 out of 3


Loss: 6144.7988: 100%|██████████| 187/187 [00:17<00:00, 10.77it/s]


Epoch: 7, Train Loss: 6144.7988, First RMSE: 69.5998, Last RMSE 87.6347, Batch RMSE 76.2262


Loss: 5588.9639: 100%|██████████| 10/10 [00:00<00:00, 11.60it/s]


EarlyStopping counter: 2 out of 3


Loss: 6134.9136: 100%|██████████| 187/187 [00:17<00:00, 10.65it/s]


Epoch: 8, Train Loss: 6134.9136, First RMSE: 69.6063, Last RMSE 87.5390, Batch RMSE 76.0798


Loss: 5588.3311: 100%|██████████| 10/10 [00:00<00:00, 10.86it/s]

EarlyStopping counter: 3 out of 3





In [13]:
dir(model)

['__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_device',
 '_device_type',
 '_model',
 '_session_options',
 '_state',
 'copy_buffer_to_parameters',
 'eval',
 'export_model_for_inferencing',
 'get_contiguous_parameters',
 'get_parameters_size',
 'input_names',
 'lazy_reset_grad',
 'output_names',
 'train',
 'training']

### Test the fine-tunned model for inference

In [14]:
batch_size   = 16
test_dataset = sirta_dataset( mode = "Test",
                                irrad_path  = "sirta_data/2023",
                                image_path  = "sirta_data/2023/images",
                                seq_length  = 16,
                                pred_length = 15,
                                image_size  = 64,
                                batch_size  = batch_size,
                                training_index_file = "sirta_data/2023/training.txt",
                                validate_index_file = "sirta_data/2023/validate.txt",
                                testing_index_file  = "sirta_data/2023/testing.txt") 
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True) 

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  selected_csv_data["Datetime"] = pd.to_datetime(selected_csv_data["Datetime"])
image transformation: 100%|██████████| 14290/14290 [01:34<00:00, 150.87it/s]


Total number of stacked samples: 13360
For [Test] mode: the number of stacked samples: 668
               the number of batches: 41


In [15]:
from onnxruntime import InferenceSession 

first_rmse_list = []
last_rmse_list  = [] 
batch_rmse_list = []
 
pbar = tqdm.tqdm(test_loader)

print(onnxruntime.get_available_providers()) 

session = InferenceSession(os.path.join(configs["artifacts_dir"], 'inference_model.onnx'),providers=['CPUExecutionProvider'])

input_name1  = session.get_inputs()[0].name
input_name2  = session.get_inputs()[1].name
output_name  = session.get_outputs()[0].name  

for _, data_batch in enumerate(pbar): 

    input_index       = data_batch[0]
    input_iclr        = data_batch[1]
    input_skyimage    = data_batch[2]
    output_irr        = data_batch[3] 



    pred_irradiance_list_of_batches  = session.run(output_names=[output_name], input_feed={input_name1: input_skyimage.float().numpy(), input_name2: input_iclr.float().numpy() })

    
    first_rmse , last_rmse, batch_rmse = mean_square_error(np.concat(pred_irradiance_list_of_batches, axis=0), output_irr.float().numpy())

    first_rmse_list.append(first_rmse)
    last_rmse_list.append(last_rmse)
    batch_rmse_list.append(batch_rmse)
 
 
print(f'========== First RMSE: {sum(first_rmse_list)/len(first_rmse_list):.4f}, Last RMSE {sum(last_rmse_list)/len(last_rmse_list):.4f}, Batch RMSE {sum(batch_rmse_list)/len(batch_rmse_list):.4f}')

  2%|▏         | 1/41 [00:00<00:05,  6.79it/s]

['CUDAExecutionProvider', 'CPUExecutionProvider']


100%|██████████| 41/41 [00:02<00:00, 19.46it/s]




