# Test autoFRK

**Title**: Test autoFRK Functionality

**Author**: Hsu, Yao-Chih

**Reviewer**: Xie, Yi-Xuan

**Version**: 1141020

**Description**: This script tests the Python version of autoFRK in different scenarios.

**Reference**: Resolution Adaptive Fixed Rank Kriging by ShengLi Tzeng & Hsin-Cheng Huang

## Install our Python autoFRK

In [1]:
import sys
import numpy as np
import torch
print("=" * 50)
print("Python Environment Info")
print("=" * 50)
print(f"Python executable: {sys.executable}")
print(f"Python version: {sys.version.split()[0]}")
print()

print("=" * 50)
print("Package Locations")
print("=" * 50)
print(f"Torch location: {torch.__file__}")
print(f"NumPy location: {np.__file__}")
print()

print("=" * 50)
print("PyTorch Info")
print("=" * 50)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA build: {torch.version.cuda}")
print()

print("=" * 50)
print("Test PyTorch Computation")
print("=" * 50)
x = torch.rand(3, 3)
print(f"Random tensor:\n{x}")
print(f"Sum: {x.sum().item():.4f}")

Python Environment Info
Python executable: d:\Github\autoFRK-python\test\.venv-gpu\Scripts\python.exe
Python version: 3.12.10

Package Locations
Torch location: d:\Github\autoFRK-python\test\.venv-gpu\Lib\site-packages\torch\__init__.py
NumPy location: d:\Github\autoFRK-python\test\.venv-gpu\Lib\site-packages\numpy\__init__.py

PyTorch Info
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA build: 12.4

Test PyTorch Computation
Random tensor:
tensor([[0.6049, 0.7359, 0.7075],
        [0.0236, 0.9033, 0.3018],
        [0.2798, 0.1713, 0.7152]])
Sum: 4.4434


In [2]:
import shutil
if shutil.which("dot") is None:
    error_msg = "Graphviz 'dot' executable not found. Please install Graphviz from https://graphviz.org/download/ and ensure it is added to your system PATH. Then restart your computer to apply the changes."
    raise EnvironmentError(error_msg)

In [3]:
# install autoFRK in development mode
import os
import sys
module_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))

!"{sys.executable}" -m pip uninstall -y autoFRK
!"{sys.executable}" -m pip install --upgrade pip build setuptools wheel matplotlib pandas torchviz graphviz
!"{sys.executable}" -m build "{module_root}"
!"{sys.executable}" -m pip install -e "{module_root}"

Found existing installation: autoFRK 1.2.0
Uninstalling autoFRK-1.2.0:
  Successfully uninstalled autoFRK-1.2.0
* Creating isolated environment: venv+pip...
* Installing packages in isolated environment:
  - setuptools>=61.0
  - wheel
* Getting build dependencies for sdist...
running egg_info
writing src\autoFRK.egg-info\PKG-INFO
writing dependency_links to src\autoFRK.egg-info\dependency_links.txt
writing requirements to src\autoFRK.egg-info\requires.txt
writing top-level names to src\autoFRK.egg-info\top_level.txt
reading manifest file 'src\autoFRK.egg-info\SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'src\autoFRK.egg-info\SOURCES.txt'
* Building sdist...
running sdist
running egg_info
writing src\autoFRK.egg-info\PKG-INFO
writing dependency_links to src\autoFRK.egg-info\dependency_links.txt
writing requirements to src\autoFRK.egg-info\requires.txt
writing top-level names to src\autoFRK.egg-info\top_level.txt
reading manifes

!!

        ********************************************************************************
        Pattern 'LICENCE*' did not match any files.

        By 2026-Mar-20, you need to update your project and remove deprecated calls
        or your builds will no longer be supported.
        ********************************************************************************

!!
  for path in sorted(cls._find_pattern(pattern, enforce_match))

Obtaining file:///D:/Github/autoFRK-python
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Checking if build backend supports build_editable: started
  Checking if build backend supports build_editable: finished with status 'done'
  Getting requirements to build editable: started
  Getting requirements to build editable: finished with status 'done'
  Preparing editable metadata (pyproject.toml): started
  Preparing editable metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: autoFRK
  Building editable for autoFRK (pyproject.toml): started
  Building editable for autoFRK (pyproject.toml): finished with status 'done'
  Created wheel for autoFRK: filename=autofrk-1.2.0-0.editable-py3-none-any.whl size=19825 sha256=c9d8229bd962c8ed7f8514793771d5a31351c81c05c2c5e79def585e02ca0b95
  Stored in directory: C:\Users\Yi-Xuan\AppData\Local\Temp\pip-ephem-wheel-cache-3q4euk3m\wheels\b5\c9\34\8b4aeb82

## Import modules

In [4]:
# import modules
import os
import sys
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from autoFRK import AutoFRK
from autoFRK.utils.utils import to_tensor, p

## Version

In [5]:
!pip show autoFRK

Name: autoFRK
Version: 1.0.0
Summary: autoFRK: Automatic Fixed Rank Kriging. The Python version with PyTorch
Home-page: https://github.com/Josh-test-lab/autoFRK-python
Author: ShengLi Tzeng, Hsin-Cheng Huang, Wen-Ting Wang
Author-email: Yao-Chih Hsu <hyc0113@hlc.edu.tw>
License: GNU GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                            Preamble

  The GNU General Public License is a free, copyleft license for
software and other kinds of works.

  The licenses for most software and other practical works are designed
to take away your freedom to share and change the works.  By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users

In [6]:
from importlib.metadata import metadata

meta = metadata("autoFRK")
for key in meta:
    print(f"{key}: {meta[key]}")

Metadata-Version: 2.4
Name: autoFRK
Version: 1.0.0
Summary: autoFRK: Automatic Fixed Rank Kriging. The Python version with PyTorch
Author: ShengLi Tzeng, Hsin-Cheng Huang, Wen-Ting Wang
Author-email: Yao-Chih Hsu <hyc0113@hlc.edu.tw>
License: GNU GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                            Preamble

  The GNU General Public License is a free, copyleft license for
software and other kinds of works.

  The licenses for most software and other practical works are designed
to take away your freedom to share and change the works.  By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.  We, the Free Software Foundation, 

## Load data

In [32]:
# load data
datasets_path = '../test datasets/matrixForTest'
data = pd.read_csv(os.path.join(datasets_path, 'matrixForTest_data.csv'))
locs = pd.read_csv(os.path.join(datasets_path, 'matrixForTest_locs.csv'))
data_missing = pd.read_csv(os.path.join(datasets_path, 'matrixForTest_data_missing.csv'))

In [33]:
# data
data_missing

Unnamed: 0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,...,V21,V22,V23,V24,V25,V26,V27,V28,V29,V30
0,-0.652166,-5.366158,3.143780,,-7.997662,1.136705,0.960326,3.201232,1.992290,2.001078,...,-0.448844,-0.577043,1.269652,0.347800,,2.661689,-0.436828,-3.561380,-2.566957,3.640283
1,,,2.749028,0.729822,-4.332897,-2.829327,,-2.724882,3.281578,-0.628886,...,-3.152476,-0.342614,7.929967,-2.177937,0.556117,-2.043759,4.111130,-5.352733,1.529268,-1.917891
2,-4.549512,-7.763640,-1.872094,2.033659,-5.619806,-2.764594,6.608074,-1.398568,0.612283,-2.144521,...,-11.434087,3.868390,,7.577723,-6.269609,-1.125877,-13.487372,-0.227024,,12.927128
3,-5.870963,-5.904798,-0.655968,7.860149,2.625713,-3.847751,7.439885,2.407433,3.273375,2.692631,...,-7.528925,6.847024,7.220447,4.332599,-3.998274,2.552754,,3.684432,0.137623,7.370367
4,,-8.789871,0.440326,6.932291,-0.442761,-4.145054,5.732851,0.655010,0.159350,-3.153300,...,,4.540188,1.572350,7.350779,-3.499281,-0.169003,-10.041218,1.220481,-3.726887,8.765699
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,2.790659,0.130805,5.131183,,-4.404798,,1.197340,-0.427687,1.653997,4.622978,...,-0.952464,4.506945,-0.397343,-4.158340,0.220648,-1.642949,10.328322,-4.605859,1.894557,
96,1.685645,1.215962,,-3.899476,-1.595553,1.209884,-0.674625,-3.748975,-2.862852,1.977467,...,3.481229,0.753463,-4.839027,-5.032350,3.904153,6.617573,6.643477,3.399834,,
97,-1.395409,-3.042111,1.776225,-0.018072,-8.091349,-0.487496,-5.728769,0.954178,,4.115169,...,-0.294438,1.971506,0.779557,-3.610582,,4.502825,2.828657,-1.532787,5.843775,3.382403
98,-4.765434,-4.096297,-6.376321,3.124180,10.963501,-6.188876,,4.612062,-2.310631,-7.859488,...,-4.794723,1.663460,3.688412,12.393074,,-0.586811,-14.112674,6.708079,-2.698385,7.515730


In [36]:
locs

tensor([[100.3103, 100.5517],
        [100.5172, 100.8276],
        [100.4138, 100.1034],
        [100.5172, 100.3448],
        [100.6897, 100.0345],
        [100.4138, 100.8276],
        [100.6552, 100.3103],
        [100.6552, 100.7241],
        [100.2759, 100.1379],
        [100.0000, 100.9310],
        [100.3793, 100.3103],
        [100.7586, 100.1379],
        [100.7241, 100.4828],
        [100.4828, 100.3103],
        [100.7241, 100.7586],
        [100.5862, 100.0345],
        [100.6897, 100.5517],
        [100.8621, 100.7931],
        [100.0000, 100.5862],
        [100.8276, 100.3103],
        [100.8621, 100.5862],
        [100.0690, 100.7931],
        [100.1034, 100.2414],
        [100.9310, 100.8966],
        [100.5517, 100.8276],
        [100.2759, 100.3793],
        [100.5172, 100.3793],
        [100.4828, 100.7586],
        [100.4138, 100.0345],
        [100.0000, 100.0000],
        [100.9655, 100.0000],
        [100.6552, 100.6552],
        [100.9310, 100.9310],
        [1

## Convert data to tensor

In [35]:
# convert to tensor
data = to_tensor(data.to_numpy())
locs = to_tensor((locs + 100).to_numpy())
data_missing = to_tensor(data_missing.to_numpy())

## Test on known locations

### Fit and predict

In [49]:
model = AutoFRK(
    logger_level=10
)
model.forward(
    data=data,
    loc=locs,
    tps_method=2
)

pred = model.predict()
pred

2025-10-27 16:40:14 - autoFRK.utils.logger - INFO: Calculate TPS with spherical_fast.
2025-10-27 16:40:14 - autoFRK.utils.logger - INFO: Calculate TPS with spherical_fast.


{'pred.value': tensor([[ 0.1392, -4.8864,  4.3818,  ..., -2.9401, -1.2612,  1.1492],
         [-0.0320, -6.5027,  4.4390,  ..., -4.4927,  3.1296, -4.4571],
         [-2.3568, -0.1397, -8.6816,  ..., -3.8314, -6.7192, 12.2255],
         ...,
         [-0.6428, -5.8943,  7.4998,  ...,  1.3138, 10.3329, -1.3814],
         [-1.7356, -2.5282, -2.6529,  ...,  8.3243,  1.5824, -0.9917],
         [ 4.3167,  0.3840,  5.2094,  ...,  0.3785,  2.6360, -2.0869]],
        dtype=torch.float64),
 'se': None}

### Compute MSE

In [50]:
import torch.nn.functional as F

F.mse_loss(pred['pred.value'].cpu(), data.cpu()).item()

13.362707942065565


| tps_method setting | Overall MSE/MSPE (Python) | Overall MSE/MSPE (R) | Conclusion |
| :--- | :--- | :--- | :--- |
| **1. rectangular** | `11.853958747726011` | `11.8539587159961` | The results are highly consistent. |
| **2. spherical** | `17.376517734701945` | Not available in R autoFRK. | TBA |
| **3. spherical_fast** | `13.362707942065565` | Not available in R autoFRK. | TBA |


```r
11.8539587159961 # R mse result
```

In [14]:
# Initialize dictionary to store MSE for each time step
mse_per_time_step = {}

for i in range(data.shape[1]):
    y_pred = pred['pred.value'][:, i].cpu()
    y_true = data[:, i].cpu()

    tmp = F.mse_loss(y_pred, y_true)
    # Store the MSE for each time step
    mse_per_time_step[i] = tmp.item()
    print(tmp.item())

7.976630668132241
12.011682111478525
8.008125196422162
3.5373059028827014
13.852875181335873
1.7858863282237092
9.885953708503628
30.04240291322086
27.07296601812779
0.21597662692960537
0.07526554448449785
5.888799359289733
12.59872059602959
11.49324381705484
37.753388832946065
1.7491903035899157
4.247365189475335
6.6476434386807535
4.336457775050738
5.5149082674081455
23.493679801438724
5.361282734455358
10.36399040456185
22.065062283664886
3.2049127751719357
5.186910569055705
30.85543202965407
1.3647053212324893
9.09903479700059
39.928964116902954
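
The per-time-step loop above can also be written as a single column-wise reduction. The following is a minimal sketch on synthetic tensors (the shapes and the helper name `mse_per_column` are illustrative assumptions, not part of autoFRK); it checks that the vectorized form agrees with calling `F.mse_loss` column by column.

```python
import torch
import torch.nn.functional as F


def mse_per_column(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return one MSE per column (time step), averaged over rows (locations)."""
    return ((pred - target) ** 2).mean(dim=0)


# Synthetic stand-ins for pred['pred.value'] and data; shapes are arbitrary.
gen = torch.Generator().manual_seed(0)
demo_pred = torch.rand(100, 30, generator=gen, dtype=torch.float64)
demo_true = torch.rand(100, 30, generator=gen, dtype=torch.float64)

vectorized = mse_per_column(demo_pred, demo_true)

# Reference: the same loop structure used in the notebook cell.
looped = torch.stack(
    [F.mse_loss(demo_pred[:, i], demo_true[:, i]) for i in range(demo_true.shape[1])]
)

print(torch.allclose(vectorized, looped))
```

Because every column has the same number of rows, the mean of the per-column values equals the overall `F.mse_loss` over the full matrix.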


### Compare with R result

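After verification, the R and Python results are broadly consistent.
Both the overall MSE and the per-time-step MSE differ only slightly from the R results, agreeing to between 6 and 10 decimal places.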
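
One way to turn a raw difference into a "number of agreeing decimal places" is to take the floor of its negative base-10 logarithm. The helper below is a hypothetical illustration (its name and the cap of 15 are assumptions), applied to the overall rectangular MSE pair from the table above and to two of the per-time-step differences printed by the comparison cell.

```python
import math


def agreement_decimals(diff: float, cap: int = 15) -> int:
    """Largest k (0 <= k <= cap) such that |diff| <= 10**(-k)."""
    diff = abs(diff)
    if diff == 0.0:
        return cap  # float64 carries roughly 15-16 significant digits
    return min(cap, max(0, math.floor(-math.log10(diff))))


# Overall MSE for the rectangular setting (Python vs R), copied from the table.
python_mse = 11.853958747726011
r_mse = 11.8539587159961
print(agreement_decimals(python_mse - r_mse))

# Two per-time-step differences copied from the comparison output.
print(agreement_decimals(-3.444839240529518e-07))
print(agreement_decimals(6.376765782079019e-11))
```

This measures the size of the difference rather than comparing printed digits one by one, so it is only a rough summary of agreement.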

In [15]:
# Load external MSE data for comparison
datasets_path = '../test datasets/mseForCompare'
mse_each_time_known_locs = pd.read_csv(os.path.join(datasets_path, 'mse_each_time_known_locs.csv'))

# Compare MSE values
a = mse_each_time_known_locs.values
b = np.array(list(mse_per_time_step.values())).reshape(-1, 1)
temp = a - b

# Print differences
for i in range(temp.shape[0]):
    print(f"Time step {i} MSE difference: {temp[i][0]}")

Time step 0 MSE difference: -5.591825136264106e-08
Time step 1 MSE difference: -3.444839240529518e-07
Time step 2 MSE difference: 6.376765782079019e-11
Time step 3 MSE difference: -5.085158161222125e-08
Time step 4 MSE difference: -8.658574301989574e-09
Time step 5 MSE difference: -7.619199271502453e-09
Time step 6 MSE difference: -4.955669830053466e-08
Time step 7 MSE difference: -1.653059911177479e-09
Time step 8 MSE difference: 6.991029977143626e-11
Time step 9 MSE difference: -1.6433723826203561e-09
Time step 10 MSE difference: -1.851784420248137e-10
Time step 11 MSE difference: 3.3431746260248474e-10
Time step 12 MSE difference: -1.3039898050237753e-09
Time step 13 MSE difference: 3.4675560200980726e-08
Time step 14 MSE difference: -4.171226208882217e-08
Time step 15 MSE difference: 2.209543659148494e-10
Time step 16 MSE difference: -1.9867925082905913e-08
Time step 17 MSE difference: -6.277512376584582e-08
Time step 18 MSE difference: 3.605177223420242e-08
Time step 19 MSE differ

### Plotting functions

In [16]:
# for i in range(data.shape[1]):
#     y_pred = pred['pred.value'][:, i].detach().cpu()
#     y_true = data[:, i].detach().cpu()

#     tmp = F.mse_loss(y_pred, y_true)
#     print(tmp)

#     plt.figure(figsize=(8,5))
#     plt.plot(y_true, label='True Value', marker='o')
#     plt.plot(y_pred, label='Predicted Value', marker='x')
#     plt.xlabel('Sample Index')
#     plt.ylabel('Value')
#     plt.title(f'Prediction vs True Value at {i} Step with error {round(tmp.item(), 2)}')
#     plt.legend()
#     plt.grid(True)
#     plt.show()

### Check gradient tracking

In [17]:
# from torchviz import make_dot
# SAVE_DIR = "gradient tracking/all known locations"
# os.makedirs(SAVE_DIR, exist_ok=True)

# for k, v in model.obj.items():
#     if isinstance(v, torch.Tensor):
#         print(f"{k}: requires_grad={v.requires_grad};  grad_fn={v.grad_fn}; is_leaf={v.is_leaf}")
#         file_path = os.path.join(SAVE_DIR, k)
#         dot = make_dot(v, show_attrs=False, show_saved=False)
#         dot.attr(dpi='100')
#         dot.render(file_path, format='pdf')

# for k, v in model.obj['G'].items():
#     if isinstance(v, torch.Tensor):
#         print(f"{k}: requires_grad={v.requires_grad};  grad_fn={v.grad_fn}; is_leaf={v.is_leaf}")
#         file_path = os.path.join(SAVE_DIR, k)
#         dot = make_dot(v, show_attrs=False, show_saved=False)
#         dot.attr(dpi='100')
#         dot.render(file_path, format='pdf')

# for k, v in pred.items():
#     if isinstance(v, torch.Tensor):
#         print(f"{k}: requires_grad={v.requires_grad};  grad_fn={v.grad_fn}; is_leaf={v.is_leaf}")
#         file_path = os.path.join(SAVE_DIR, k)
#         dot = make_dot(v, show_attrs=False, show_saved=False)
#         dot.attr(dpi='100')
#         dot.render(file_path, format='pdf')
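
For readers unfamiliar with the three attributes printed in the commented-out cell, here is a small standalone sketch using plain PyTorch tensors (not autoFRK objects; the tensor names are illustrative). It shows what `requires_grad`, `grad_fn`, and `is_leaf` report for a leaf tensor, a tensor derived from it, and a detached copy.

```python
import torch


def describe(name: str, t: torch.Tensor) -> str:
    """Format the same three autograd attributes inspected in the cell above."""
    return f"{name}: requires_grad={t.requires_grad}; grad_fn={t.grad_fn}; is_leaf={t.is_leaf}"


# A leaf tensor created by the user with gradient tracking enabled.
leaf = torch.ones(3, dtype=torch.float64, requires_grad=True)
# A tensor produced by operations on the leaf: it carries a grad_fn.
derived = (leaf * 2).sum()
# Detaching cuts the tensor out of the graph: no grad_fn, no tracking.
detached = derived.detach()

for name, t in {"leaf": leaf, "derived": derived, "detached": detached}.items():
    print(describe(name, t))
```

A model output that still supports backpropagation should look like `derived` here: `requires_grad=True` with a non-`None` `grad_fn`.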

## Test on unknown locations

### Train:Test = 7:3

In [18]:
## split data into train and test
train_data = data[:70, :]
test_data = data[70:, :]
train_locs = locs[:70, :]
test_locs = locs[70:, :]
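
The split above takes the first 70 rows for training, which keeps it reproducible and presumably aligned with the reference file used for the R comparison later in this section. If the row order were not already arbitrary, a seeded shuffled split is a common alternative. The sketch below is an assumption-labelled illustration (the helper `shuffled_split` and the synthetic tensors are hypothetical), not the split used in this notebook.

```python
import torch


def shuffled_split(data, locs, train_fraction=0.7, seed=0):
    """Split rows of data and locs with the same seeded random permutation."""
    n = data.shape[0]
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n, generator=gen)
    n_train = int(round(train_fraction * n))
    train_idx, test_idx = perm[:n_train], perm[n_train:]
    return (data[train_idx], locs[train_idx]), (data[test_idx], locs[test_idx])


# Synthetic stand-ins: 100 locations, 3 time steps, 2 coordinates.
demo_data = torch.arange(300, dtype=torch.float64).reshape(100, 3)
demo_locs = torch.arange(200, dtype=torch.float64).reshape(100, 2)

(train_d, train_l), (test_d, test_l) = shuffled_split(demo_data, demo_locs)
print(train_d.shape, test_d.shape)
```

Results from a shuffled split would not be directly comparable with per-time-step values computed on a different set of held-out rows.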

### Fit and predict

In [19]:
## training on train data
model = AutoFRK()
model.forward(
    data=train_data,
    loc=train_locs,
    requires_grad=True
)

## predict on test data
pred = model.predict(
    newloc = test_locs
)

2025-10-27 16:29:56 - autoFRK.utils.logger - INFO: Gradient tracking has been enabled for autoFRK.
2025-10-27 16:29:56 - autoFRK.utils.logger - INFO: Calculate TPS with rectangular.


### Compute MSE

In [20]:
# evaluate mse on test data
F.mse_loss(pred['pred.value'].cpu(), test_data.cpu()).item()

18.888619455132325

In [21]:
# Initialize dictionary to store MSPE for each time step
mspe_per_time_step = {}

for i in range(test_data.shape[1]):
    y_pred = pred['pred.value'][:, i].cpu()
    y_true = test_data[:, i].cpu()

    tmp = F.mse_loss(y_pred, y_true)
    # Store the MSPE for each time step
    mspe_per_time_step[i] = tmp.item()
    print(tmp.item())

11.439408970380754
25.107778178701714
27.450073363075305
11.824834253939056
14.635019965032132
9.478019542839203
13.727797085435684
12.17243613570705
17.848309463210754
11.41419537370464
16.1900735581536
12.385319828043881
10.433204064648319
41.77850737083709
38.01213741613706
27.143946173924853
9.501057893334155
18.938319339768736
17.35661542130891
19.450701190034582
15.180028640916511
17.513520186980326
15.105158817105114
35.53883333424458
21.32619819557542
9.027353235450725
30.130724139200915
19.640067721809594
12.26785218998604
24.641092604483084


### Compare with R result

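After verification, the R and Python results are broadly consistent.
Both the overall MSPE and the per-time-step MSPE differ only slightly from the R results, agreeing to between 6 and 10 decimal places.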

In [22]:
# Load external MSPE data for comparison
datasets_path = '../test datasets/mseForCompare'
mspe_each_time_unknown_locs = pd.read_csv(os.path.join(datasets_path, 'mspe_each_time_unknown_locs.csv'))

# Compare MSPE values
a = mspe_each_time_unknown_locs.values
b = np.array(list(mspe_per_time_step.values())).reshape(-1, 1)
temp = a - b

# Print differences
for i in range(temp.shape[0]):
    print(f"Time step {i} MSPE difference: {temp[i][0]}")

Time step 0 MSPE difference: 1.5363404592960705e-07
Time step 1 MSPE difference: 3.7310618594688094e-07
Time step 2 MSPE difference: -1.1042406100614244e-08
Time step 3 MSPE difference: 9.999204486632607e-08
Time step 4 MSPE difference: 2.2927706844200202e-07
Time step 5 MSPE difference: 1.0739272759963114e-07
Time step 6 MSPE difference: 6.891901627170682e-08
Time step 7 MSPE difference: -1.0218149881779937e-08
Time step 8 MSPE difference: -6.544524921991979e-10
Time step 9 MSPE difference: 1.2479560496103659e-08
Time step 10 MSPE difference: 9.967198622007345e-09
Time step 11 MSPE difference: -8.022881914371283e-09
Time step 12 MSPE difference: 3.136398163405829e-08
Time step 13 MSPE difference: 1.4478440846232843e-07
Time step 14 MSPE difference: 6.104570431375578e-07
Time step 15 MSPE difference: 4.05845455020426e-08
Time step 16 MSPE difference: 4.400175512841997e-08
Time step 17 MSPE difference: 2.2016176259853637e-07
Time step 18 MSPE difference: -7.642080745995372e-08
Time step

### Plotting functions

In [23]:
# for i in range(test_data.shape[1]):
#     y_pred = pred['pred.value'][:, i].detach().cpu()
#     y_true = test_data[:, i].detach().cpu()

#     tmp = F.mse_loss(y_pred, y_true)
#     print(tmp.item())

#     plt.figure(figsize=(8,5))
#     plt.plot(y_true, label='True Value', marker='o')
#     plt.plot(y_pred, label='Predicted Value', marker='x')
#     plt.xlabel('Sample Index')
#     plt.ylabel('Value')
#     plt.title(f'Prediction vs True Value at {i} Step with error {round(tmp.item(), 2)}')
#     plt.legend()
#     plt.grid(True)
#     plt.show()

### Check gradient tracking

In [24]:
# from torchviz import make_dot
# SAVE_DIR = "gradient tracking/with unknown locations"
# os.makedirs(SAVE_DIR, exist_ok=True)

# for k, v in model.obj.items():
#     if isinstance(v, torch.Tensor):
#         print(f"{k}: requires_grad={v.requires_grad};  grad_fn={v.grad_fn}; is_leaf={v.is_leaf}")
#         # file_path = os.path.join(SAVE_DIR, k)
#         # dot = make_dot(v)
#         # dot.attr(dpi='300')
#         # dot.render(file_path, format='png')

# for k, v in model.obj['G'].items():
#     if isinstance(v, torch.Tensor):
#         print(f"{k}: requires_grad={v.requires_grad};  grad_fn={v.grad_fn}; is_leaf={v.is_leaf}")
#         # file_path = os.path.join(SAVE_DIR, k)
#         # dot = make_dot(v)
#         # dot.attr(dpi='300')
#         # dot.render(file_path, format='png')

# for k, v in pred.items():
#     if isinstance(v, torch.Tensor):
#         print(f"{k}: requires_grad={v.requires_grad};  grad_fn={v.grad_fn}; is_leaf={v.is_leaf}")
#         # file_path = os.path.join(SAVE_DIR, k)
#         # dot = make_dot(v)
#         # dot.attr(dpi='300')
#         # dot.render(file_path, format='png')

## Test on missing data (EM)

### Fit and predict

In [25]:
# training on data with missing values
model = AutoFRK(
    logger_level=20
)
model.forward(
    data=data_missing,
    loc=locs,
    method='EM',
    requires_grad=True,
    maxit=18
)

# predict on data with missing values
pred = model.predict(
    obj = model.obj
)
pred

2025-10-27 16:29:56 - autoFRK.utils.logger - INFO: Gradient tracking has been enabled for autoFRK.
2025-10-27 16:29:56 - autoFRK.utils.logger - INFO: Calculate TPS with rectangular.
2025-10-27 16:30:00 - autoFRK.utils.logger - INFO: Number of iteration: 18


{'pred.value': tensor([[-3.9923e-04, -3.9498e+00,  4.0037e+00,  ..., -1.4526e+00,
          -3.7519e-01,  4.1146e+00],
         [-2.9209e-01, -3.0439e+00,  2.9015e+00,  ..., -9.5498e-01,
           2.6059e-01,  2.9861e+00],
         [-6.3041e+00, -8.5653e+00,  7.0040e-01,  ...,  3.5755e-01,
          -2.5926e+00,  1.3498e+01],
         ...,
         [ 1.6149e+00, -1.1902e+00,  3.4781e+00,  ..., -1.4641e+00,
           6.4208e-01, -2.6462e-01],
         [-7.2337e+00, -4.5718e+00, -5.4361e+00,  ...,  1.9453e+00,
          -3.4074e+00,  9.6046e+00],
         [-8.9564e-01, -7.6079e-01, -1.4871e-02,  ...,  2.4362e-01,
           5.3259e-01,  8.8942e-01]], dtype=torch.float64,
        grad_fn=<AddBackward0>),
 'se': None}

### Compute MSE

In [26]:
F.mse_loss(pred['pred.value'].cpu(), data.cpu()).item()

5.115453562007837

In [27]:

# Initialize dictionary to store MSE for each time step
mse_per_time_step = {}

# Each time step MSE
for i in range(data.shape[1]):
    y_pred = pred['pred.value'][:, i].cpu()
    y_true = data[:, i].cpu()

    tmp = F.mse_loss(y_pred, y_true)
    # Store the MSE for each time step
    mse_per_time_step[i] = tmp.item()
    print(f"Time step {i} MSE: {tmp.item()}")

Time step 0 MSE: 4.87591310173512
Time step 1 MSE: 4.36952470964358
Time step 2 MSE: 4.602030801878972
Time step 3 MSE: 6.587195417893313
Time step 4 MSE: 5.064184366648315
Time step 5 MSE: 5.202673232392881
Time step 6 MSE: 5.3724275942811985
Time step 7 MSE: 5.224291304748667
Time step 8 MSE: 5.3981614698822
Time step 9 MSE: 4.887758970522188
Time step 10 MSE: 5.384249154673589
Time step 11 MSE: 4.882234535267592
Time step 12 MSE: 3.452850525307342
Time step 13 MSE: 5.230932513007399
Time step 14 MSE: 5.776500950614442
Time step 15 MSE: 5.168875908794735
Time step 16 MSE: 5.5766354109483025
Time step 17 MSE: 5.836962497351717
Time step 18 MSE: 5.143071764861567
Time step 19 MSE: 5.8278109832192175
Time step 20 MSE: 5.915852331903822
Time step 21 MSE: 4.021226662024549
Time step 22 MSE: 4.681773734928682
Time step 23 MSE: 5.7275508658542265
Time step 24 MSE: 3.4086516730412497
Time step 25 MSE: 4.571510938322366
Time step 26 MSE: 4.726495724857156
Time step 27 MSE: 5.61636017401451
Ti

### Compare with R result

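After verification, the R and Python results differ substantially. According to Josh, the EM implementation in the R version differs from the Python version in some details, which leads to the difference in the final results.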
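
When investigating a discrepancy like this, it may help to report the error separately on cells that were observed and cells that were missing in the training input, since the EM step only has to impute the latter. The sketch below assumes missing values are encoded as NaN (as in `data_missing`); the helper `mse_by_mask` and the toy tensors are hypothetical and do not reproduce either implementation.

```python
import torch


def mse_by_mask(pred, truth, observed_with_nan):
    """MSE on observed cells, on missing (NaN) cells, and overall."""
    missing = torch.isnan(observed_with_nan)
    sq_err = (pred - truth) ** 2
    # Note: if a group is empty, its mean is NaN.
    return {
        "observed": sq_err[~missing].mean().item(),
        "missing": sq_err[missing].mean().item(),
        "overall": sq_err.mean().item(),
    }


# Toy example: one of four cells is missing in the training input.
truth = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
observed = torch.tensor([[1.0, float("nan")], [3.0, 4.0]], dtype=torch.float64)
pred = torch.tensor([[1.0, 4.0], [3.0, 5.0]], dtype=torch.float64)

result = mse_by_mask(pred, truth, observed)
print(result)
```

In the notebook, the corresponding call would pass `pred['pred.value'].detach().cpu()`, `data.cpu()`, and `data_missing.cpu()`; whether the R reference values were computed over the same set of cells is not stated here and would need to be checked.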

In [28]:
# Load external MSE data for comparison
datasets_path = '../test datasets/mseForCompare'
mse_each_time_em = pd.read_csv(os.path.join(datasets_path, 'mse_each_time_em.csv'))

# Compare MSE values
a = mse_each_time_em.values
b = np.array(list(mse_per_time_step.values())).reshape(-1, 1)
temp = a - b

# Print differences
for i in range(temp.shape[0]):
    print(f"Time step {i} MSE difference: {temp[i][0]}")

Time step 0 MSE difference: 11.627957652714482
Time step 1 MSE difference: 9.70860586818732
Time step 2 MSE difference: -2.897410894692252
Time step 3 MSE difference: -5.7415280099303105
Time step 4 MSE difference: -1.4878969730327452
Time step 5 MSE difference: -1.240664212621601
Time step 6 MSE difference: -0.6927686900454182
Time step 7 MSE difference: 0.5054460793620432
Time step 8 MSE difference: 2.19614751859251
Time step 9 MSE difference: -3.300681023231758
Time step 10 MSE difference: -4.646403897782375
Time step 11 MSE difference: -1.7704413991467125
Time step 12 MSE difference: -2.8574451098388667
Time step 13 MSE difference: -0.6876608602761092
Time step 14 MSE difference: -3.759954945571742
Time step 15 MSE difference: 2.5917414576829243
Time step 16 MSE difference: -3.0225149488788023
Time step 17 MSE difference: 6.3116201272843835
Time step 18 MSE difference: -1.650134761554217
Time step 19 MSE difference: 0.32448446082705207
Time step 20 MSE difference: 76.00323994235077