In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# change main directory
import sys
sys.path.append('C:/Users/amaguaya/OneDrive - Kienzle Automotive GmbH/Desktop/tesis_code/repos/loci')

import os
print("Working directory:", os.getcwd())

Working directory: c:\Users\amaguaya\OneDrive - Kienzle Automotive GmbH\Desktop\tesis_code\repos\loci\local_run_test


## Main imports

In [7]:
from sklearn.preprocessing import StandardScaler
from causa.loci import loci, loci_w_marginal, compute_marginal_likelihood_nn
from causa.datasets import MNU, Tuebingen, SIM, SIMc, SIMG, SIMln, Cha, Multi, Net
from causa.utils import plot_pair

import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

In [4]:
# initial results
initial_ds = pd.read_csv('../partial_results/lik_scores.csv', sep = ',')

# dataset = MNU(100, preprocessor=None, double=True) # 100
# dataset = Cha(300, preprocessor=None, double=True) # 300
# dataset = Multi(300, preprocessor=None, double=True) # 300
# dataset = Net(300, preprocessor=None, double=True) # 300
# dataset = SIM(100, preprocessor=None, double=True) # 100
# dataset = SIMc(90, preprocessor=None, double=True) # 100
# dataset = SIMG(100, preprocessor=None, double=True) # 100
# dataset = SIMln(100, preprocessor=None, double=True) # 100
# dataset = Tuebingen(108, preprocessor=None, double=True) # 108, but here the just estimated 99

In [5]:
ls_dataset_used = [
            ('MNU', MNU, 100),
            ('Cha', Cha, 300),
            ('Multi', Multi, 300),
            ('Net', Net, 300),
            ('SIM', SIM, 100),
            ('SIMc', SIMc, 100),
            ('SIMG', SIMG, 100),
            ('SIMln', SIMln, 100),
            ('Tuebingen', Tuebingen, 108)
    ]

ls_all_results_marginals = {name: [] for name, _, _ in ls_dataset_used[:1]}

def estimate_sample(ds_class, idx, n_steps, seed=711, device='cpu', preprocessor = None):
    try:
        print(f"start estimation, sample {idx}")
        dataset = ds_class(idx, preprocessor=preprocessor, double=True)
        x = dataset.cause.flatten().numpy()
        y = dataset.effect.flatten().numpy()

        log_mg_x, _, _ = compute_marginal_likelihood_nn(x, n_steps = n_steps, seed=seed, device=device)
        log_mg_y, _, _ = compute_marginal_likelihood_nn(y, n_steps = n_steps, seed=seed, device=device)

        marginal_diff = log_mg_x - log_mg_y
        print(f"successful estimation, sample {idx}")
        return idx, marginal_diff
    except Exception as e:
        print(f"[WARNING] Sample {idx} failed with error: {e}")
        return idx, np.nan

In [9]:
ls_all_results_marginals

{'MNU': []}

In [10]:
# -------------------------------------------
# Main loop: over datasets
# -------------------------------------------
max_workers = 5  # you can increase this depending on CPU
for ds_name, ds_class, n_samples in ls_dataset_used:
    print(f"⏳ Estimating dataset: {ds_name} with {n_samples} samples...***")
    results = []

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                estimate_sample, 
                ds_class = ds_class, idx = idx, 
                n_steps=500, preprocessor=StandardScaler()
                ): idx for idx in range(1, n_samples+1)
            }
        for future in tqdm(as_completed(futures), total=n_samples, desc=f"{ds_name}"):
            idx, marginal_diff = future.result()
            results.append((idx, marginal_diff))

    # Sort by sample index to keep consistent order
    ls_all_results_marginals[ds_name] = sorted(results, key=lambda x: x[0])




⏳ Estimating dataset: MNU with 100 samples...***


Traceback (most recent call last):
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\Lib\multiprocessing\queues.py", line 246, in _feed
    send_bytes(obj)
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\Lib\multiprocessing\connection.py", line 184, in send_bytes
    self._check_closed()
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\Lib\multiprocessing\connection.py", line 137, in _check_closed
    raise OSError("handle is closed")
OSError: handle is closed
Traceback (most recent call last):
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\Lib\multiprocessing\queues.py", line 246, in _feed
    send_bytes(obj)
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\Lib\multiprocessing\connection.py", line 184, in se

OSError: handle is closed

In [None]:
# for ds_name in ls_all_results_marginals.key():
# # We use the MNU pair 55 as an example and without standardized data
# dataset = MNU(55, preprocessor=None, double=True)  # 1000
# x, y = dataset.cause.flatten().numpy(), dataset.effect.flatten().numpy()
# for future in tqdm(as_completed(futures), total=n_samples, desc=f"{ds_name}"):
#     idx, marginal_diff = future.result()
#     results.append((idx, marginal_diff))

MNU:   5%|▌         | 5/100 [00:00<00:00, 4973.09it/s]


In [18]:
initial_ds

Unnamed: 0.1,Unnamed: 0,AN,AN-s,LS,LS-s,MN-U,Cha,Multi,Net,SIM,SIMc,SIMG,SIMln,Tuebingen
0,1,0.389261,0.405831,0.200040,0.394601,0.107985,0.060222,-0.081008,0.207976,-0.390620,-0.358075,-0.027334,0.712540,0.003210
1,2,0.192209,0.098788,0.256571,0.278756,0.176186,-0.020870,-0.059659,0.129776,0.001075,0.011600,-0.000150,-0.010007,0.006439
2,3,0.334150,0.409898,1.280860,0.143243,0.242862,-0.149908,0.905643,1.643242,-0.109915,0.223593,0.002840,0.235984,0.038288
3,4,0.647672,0.284067,0.363610,0.724972,0.043958,-0.124860,0.704741,0.070228,0.162352,0.093823,0.039827,1.447786,-0.108906
4,5,0.427314,0.294612,0.754618,0.217951,0.215037,-0.149560,1.021243,-0.004363,-0.287919,-0.075830,0.019790,0.100354,0.159380
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,296,,,,,,0.120675,0.071921,-0.205145,,,,,
296,297,,,,,,-0.712165,-0.004386,0.043045,,,,,
297,298,,,,,,-0.021618,0.104018,0.018218,,,,,
298,299,,,,,,-0.221457,0.121402,1.068829,,,,,
