In [19]:
import numpy as np
import os
import pandas as pd
from DataSim2 import DataSim
from Lin_regression import LinRegression

In [86]:
# Build the simulated dataset that the following cells analyse.
# NOTE(review): the meaning of each keyword is defined in DataSim2, which is
# not visible from this notebook — confirm there before tuning any value.
ds = DataSim(
    n_samples=10000,
    latent_dim=20,
    high_dim=100,
    std_A=5,
    random_seed=2111,  # fixed seed, presumably for reproducibility — confirm
    non_linear_ratio=0.5,
    cross_ratio=0,
    sparsity=0.5,
    s2nr=1,
)

In [87]:
# Candidate design matrices, each passed to LinRegression with the same target.
# The column splits (10 for the latent matrix, 50 for both high-dimensional
# matrices) are hard-coded; presumably they mirror latent_dim=20, high_dim=100
# and non_linear_ratio=0.5 from the DataSim call — TODO confirm.
# NOTE(review): 'hd__x' carries a double underscore unlike its sibling keys; it
# is kept unchanged because it is the row label of the printed score table.
x_dict = dict(
    latent_x=ds.latent_x,
    nl_latent_x=ds.latent_x[:, :10],
    lin_latent_x=ds.latent_x[:, 10:],
    hd__x=ds.hd_x,
    nl_hd_x=ds.hd_x[:, :50],
    lin_hd_x=ds.hd_x[:, 50:],
    transformed_hd_x=ds.non_linear_data_noisy,
    nl_transformed_hd_x=ds.non_linear_data_noisy[:, :50],
    lin_transformed_hd_x=ds.non_linear_data_noisy[:, 50:],
)

y = ds.y

# NOTE(review): the third positional argument (True) is interpreted by
# LinRegression, which is not visible here — confirm what it toggles.
lr = LinRegression(x_dict, y, True)
scores = lr.get_scores()

# Transposed so that each design matrix is a row and each metric a column.
print(pd.DataFrame(scores).transpose())

                             mse       rmse        mae        r2  adjusted_r2
latent_x                0.978047   0.988963   0.787408  0.998153     0.998149
nl_latent_x           311.158337  17.639681  14.091954  0.412335     0.411688
lin_latent_x          223.475478  14.949096  11.947853  0.577936     0.577471
hd__x                   0.978047   0.988963   0.787408  0.998153     0.998134
nl_hd_x               311.158337  17.639681  14.091954  0.412335     0.409323
lin_hd_x              223.475478  14.949096  11.947853  0.577936     0.575772
transformed_hd_x      529.436843  23.009495  18.314035  0.000087    -0.010117
nl_transformed_hd_x   529.436843  23.009495  18.314035  0.000087    -0.005040
lin_transformed_hd_x  281.795663  16.786770  13.344451  0.467791     0.465062


In [79]:
# Compare the magnitude of the first 10 coefficients with the remaining ones.
# The "non-linear" / "linear" labels follow the printed messages; the split at
# index 10 is hard-coded — presumably latent_dim * non_linear_ratio, confirm.
norm_nl, norm_lin = (
    np.linalg.norm(coef_part) for coef_part in (ds.beta[:10], ds.beta[10:])
)
print(f'Norm of non-linear part: {norm_nl}')
print(f'Norm of linear part: {norm_lin}')

Norm of non-linear part: 14.992706267886545
Norm of linear part: 17.58935845030641


In [89]:
# Manual least-squares fit on one stored simulation set.
# The CSV's first row is skipped as a header; the last column is used as the
# target and every other column as a feature.
data = np.genfromtxt('data4/39_sim_10000_1000_50_1_0_0.5_0.1/set_1/data.csv', delimiter=',', skip_header=1)
x = data[:, :-1]
y = data[:, -1]

# Solve min ||x @ beta - y|| directly instead of through the normal equations
# (x.T @ x) beta = x.T @ y: forming x.T @ x squares the condition number and
# np.linalg.solve raises LinAlgError when that matrix is singular (constant or
# collinear feature columns). lstsq returns the same coefficients for a
# full-rank x and the minimum-norm solution otherwise.
beta_hat = np.linalg.lstsq(x, y, rcond=None)[0]
y_hat = x @ beta_hat

# NOTE(review): no intercept column is added to x, yet R^2 is measured against
# the mean-only baseline, so this value can differ from an intercept-fitted
# R^2 — confirm this is intended.
r2 = 1 - np.sum((y - y_hat) ** 2) / np.sum((y - np.mean(y)) ** 2)

print(f"R^2: {r2}")

R^2: 0.5902745382456951


In [90]:
# Toy example: three score records with identical keys.
dicts = [
    dict(score1=5, score2=7, score3=9),
    dict(score1=6, score2=8, score3=10),
    dict(score1=4, score2=6, score3=11),
]

# One row per record, one column per score name.
df = pd.DataFrame(data=dicts)

# Column-wise mean and sample standard deviation, as {column: {stat: value}}.
result = df.aggregate(["mean", "std"]).to_dict()

print(result)

{'score1': {'mean': 5.0, 'std': 1.0}, 'score2': {'mean': 7.0, 'std': 1.0}, 'score3': {'mean': 10.0, 'std': 1.0}}


In [92]:
# Rebuild the frame from `dicts` and leave it as the cell's last expression so
# the notebook renders it as a table.
df = pd.DataFrame(data=dicts)
df

Unnamed: 0,score1,score2,score3
0,5,7,9
1,6,8,10
2,4,6,11


In [100]:
import pandas as pd

# Nested results: dataset name -> model name -> list of per-run score dicts.
data = {
    "dataset1": {
        "modelA": [
            {"score1": 5, "score2": 7, "score3": 9},
            {"score1": 6, "score2": 8, "score3": 10},
        ],
        "modelB": [
            {"score1": 4, "score2": 6, "score3": 11},
            {"score1": 5, "score2": 7, "score3": 12},
        ],
    },
    "dataset2": {
        "modelA": [
            {"score1": 7, "score2": 9, "score3": 13},
            {"score1": 8, "score2": 10, "score3": 14},
        ],
        "modelB": [
            {"score1": 6, "score2": 8, "score3": 12},
            {"score1": 7, "score2": 9, "score3": 13},
        ],
    }
}

# One frame per dataset: pool the runs of every model (the model name is
# deliberately ignored) and tag each row with the dataset it came from.
# Built with comprehensions so that no loop temporaries leak into the notebook
# namespace; a plain for-loop here rebinds the globals `dicts` and `df`, which
# silently changes what other cells using those names see when re-run.
dfs = [
    pd.DataFrame(
        [run for model_runs in models.values() for run in model_runs]
    ).assign(dataset=dataset)
    for dataset, models in data.items()
]

# combine all datasets into one
df_all = pd.concat(dfs, ignore_index=True)

# group by dataset and aggregate: mean and sample std of every score column
summary = df_all.groupby("dataset").agg(["mean", "std"])

print(summary)


         score1           score2           score3          
           mean       std   mean       std   mean       std
dataset                                                    
dataset1    5.0  0.816497    7.0  0.816497   10.5  1.290994
dataset2    7.0  0.816497    9.0  0.816497   13.0  0.816497


In [101]:
# Display the pooled table: one row per score record, tagged with its dataset.
df_all

Unnamed: 0,score1,score2,score3,dataset
0,5,7,9,dataset1
1,6,8,10,dataset1
2,4,6,11,dataset1
3,5,7,12,dataset1
4,7,9,13,dataset2
5,8,10,14,dataset2
6,6,8,12,dataset2
7,7,9,13,dataset2


In [123]:
# Load one stored simulation set into a DataFrame; read_csv takes the first
# row as column names by default. This rebinds `df` used by earlier cells.
df = pd.read_csv('data4/108_sim_10000_100_10_0.3_0.3_0.5_1/set_3/data.csv')

In [124]:
# Count columns whose sample standard deviation is below 1e-10, i.e. columns
# that are numerically constant.
# NOTE(review): the cutoff is absolute, so a column whose spread is only
# slightly above 1e-10 is not counted — confirm the threshold is intended.
(df.std(axis=0) < 1e-10).sum()

np.int64(1)

In [115]:
# Per-column summary statistics (count, mean, std, min, quartiles, max) of the
# frame currently bound to `df`.
# NOTE(review): this cell's execution count (115) is lower than that of the
# read_csv cell (123), so the stored output may come from an earlier load.
df.describe()

Unnamed: 0,Feature 0,Feature 1,Feature 2,Feature 3,Feature 4,Feature 5,Feature 6,Feature 7,Feature 8,Feature 9,Feature 10,Feature 11,Feature 12,Feature 13,Feature 14,Feature 15,Feature 16,Feature 17,Feature 18,Feature 19,Feature 20,Feature 21,Feature 22,Feature 23,Feature 24,Feature 25,Feature 26,Feature 27,Feature 28,Feature 29,Feature 30,Feature 31,Feature 32,Feature 33,Feature 34,Feature 35,Feature 36,Feature 37,Feature 38,Feature 39,Feature 40,Feature 41,Feature 42,Feature 43,Feature 44,Feature 45,Feature 46,Feature 47,Feature 48,Feature 49,Feature 50,Feature 51,Feature 52,Feature 53,Feature 54,Feature 55,Feature 56,Feature 57,Feature 58,Feature 59,Feature 60,Feature 61,Feature 62,Feature 63,Feature 64,Feature 65,Feature 66,Feature 67,Feature 68,Feature 69,Feature 70,Feature 71,Feature 72,Feature 73,Feature 74,Feature 75,Feature 76,Feature 77,Feature 78,Feature 79,Feature 80,Feature 81,Feature 82,Feature 83,Feature 84,Feature 85,Feature 86,Feature 87,Feature 88,Feature 89,Feature 90,Feature 91,Feature 92,Feature 93,Feature 94,Feature 95,Feature 96,Feature 97,Feature 98,Feature 99,Target
count,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0
mean,18.733177,-0.1603477,3.343012,0.018677,0.005643,-0.526594,0.011525,4.736628,0.004028,0.00305,-131.304667,0.000593,215.861933,287784.7,12.628195,1121.947679,3.267156,2.658057,3.665199,3.395374,4.851178,21.541518,3.647579,-0.015711,-0.008489,0.535185,1.3e-05,16.362631,133634.8,576.590809,0.224964,-0.037376,0.040034,0.249761,0.153361,0.152813,0.449469,-0.261433,-0.189564,-0.171597,0.948832,-0.051538,-0.160196,0.325651,0.2319,-0.098513,0.047189,-0.491228,0.419778,-0.334377,-0.020257,0.170557,-0.05155,0.382746,-0.426254,0.219432,0.292124,0.679048,0.717578,-0.449341,0.002037,0.084773,-0.010025,0.280699,-0.134818,0.20407,6.689e-13,0.044039,0.057046,-0.052178,-5.7143e-12,0.098576,-0.136893,-0.065904,0.345282,-0.380426,0.021743,-0.05352,0.184911,0.07845,-0.007684,0.184374,-0.043615,0.526291,-0.041548,0.077546,-0.005903,-0.091948,-4.9037e-12,-0.144025,-0.419119,-1.777e-12,0.154328,-0.3064,-0.13468,-0.03597,-0.073532,0.046665,0.045847,-0.346988,-0.13998
std,19.875074,4.752314e-13,2.651381,1.012748,0.997067,0.185317,0.997891,2.980468,0.993022,0.995405,12195.32923,0.996536,425.428068,1278898.0,13.66796,2230.14328,2.646475,2.178581,3.819957,2.580492,2.82414,22.745726,2.619333,1.004609,1.003064,6421.007454,1.004035,17.521286,667245.3,1139.203798,14.41505,19.879782,10.683346,15.568192,21.933925,17.123032,28.020204,18.001335,12.934465,9.686076,50.310428,10.873277,18.367254,22.993921,26.716966,12.184937,26.415851,26.898796,20.681481,28.937165,9.480192,11.877532,32.128615,23.867638,20.649924,7.606688,42.250905,46.011371,24.576381,39.027178,0.257188,13.445705,16.425166,29.978514,8.292276,30.757072,2.848216e-10,7.079232,5.825112,3.432067,2.813392e-10,12.612303,14.959887,13.591049,35.250294,28.5025,6.150648,6.236439,13.220934,28.714594,14.033891,9.677866,6.400832,32.415102,13.077426,10.20838,17.695734,17.121576,2.793791e-10,18.221383,26.082934,2.841269e-10,10.978688,33.472742,20.154967,15.077379,9.551292,6.334382,22.039309,37.760439,6.374201
min,-50.712254,-0.1603477,-7.161203,-3.412391,-3.627137,-1.093435,-3.381665,-7.440108,-3.037259,-3.267438,-112082.524015,-3.415137,-1087.647057,-3194945.0,-34.161375,-5587.422703,-7.757965,-4.543758,-9.05867,-6.291347,-7.917438,-51.950133,-6.368567,-3.516983,-3.486479,-75521.768076,-3.424966,-37.514703,-1753026.0,-2988.756492,-56.428197,-72.853196,-43.588155,-58.865028,-75.313434,-74.324036,-110.692309,-66.447221,-46.780844,-34.893335,-207.072704,-40.213817,-65.141222,-87.801264,-126.965834,-47.006407,-98.349941,-114.974445,-73.768661,-107.02195,-37.451594,-44.902218,-111.2064,-87.048713,-76.270892,-33.33152,-199.637372,-165.749934,-87.729154,-143.705507,-1.003279,-50.640221,-64.375035,-127.57323,-32.384477,-117.214561,-1.046e-09,-23.84865,-19.411174,-14.668679,-1.033e-09,-53.810061,-58.959214,-61.310866,-128.03775,-114.632013,-24.625714,-23.588906,-53.101232,-106.113442,-51.726259,-37.181137,-26.736246,-116.438186,-50.849453,-35.058079,-92.721164,-61.894239,-1.022e-09,-70.078692,-114.707383,-1.147e-09,-44.049672,-131.993631,-82.847139,-61.436662,-35.7667,-23.671118,-90.030046,-150.220489,-23.545755
25%,4.909801,-0.1603477,1.628117,-0.697492,-0.711069,-0.65292,-0.703865,2.940584,-0.703082,-0.712646,-6791.878144,-0.713065,-65.36032,-464643.2,3.129908,-354.861503,1.565672,1.155786,1.072311,1.685942,3.053838,5.646959,1.939187,-0.73862,-0.738735,-3523.518512,-0.719092,4.371277,-256113.6,-183.587471,-9.354135,-13.548996,-7.143443,-10.239818,-14.581569,-11.610093,-18.217234,-12.712921,-8.91626,-6.793979,-33.168998,-7.374704,-12.485552,-15.078616,-17.443529,-8.331215,-17.548742,-18.596153,-13.632453,-19.878531,-6.358332,-7.935487,-21.741689,-15.768959,-14.443157,-4.874405,-27.8787,-30.38308,-16.060384,-26.873194,-0.176285,-9.1127,-10.912792,-20.292447,-5.809918,-20.847291,-1.93e-10,-4.731507,-3.850598,-2.350284,-1.96e-10,-8.529165,-10.212706,-9.081429,-23.165965,-19.770514,-4.136044,-4.286759,-8.916016,-18.815392,-9.446575,-6.423395,-4.358267,-21.515147,-8.850379,-6.933051,-12.308994,-11.664432,-1.9e-10,-12.545762,-18.107767,-1.96e-10,-7.303584,-23.224518,-13.954456,-10.257968,-6.411313,-4.201857,-14.728126,-25.641413,-4.46802
50%,17.660671,-0.1603477,3.467113,0.011863,0.010386,-0.547166,0.026463,4.891466,-0.000909,0.01329,-166.3729,-0.011672,175.750443,187408.3,11.746213,910.381585,3.414102,2.687649,3.40829,3.545931,5.046673,20.191819,3.778488,-0.008782,-0.014898,3.742078,0.008415,15.210989,83914.74,462.433699,0.141768,0.054453,0.119341,0.170245,0.243115,-0.04852,-0.145025,-0.216556,-0.151947,-0.230089,0.193097,-0.182998,-0.214961,0.331521,0.145103,-0.158314,0.20714,-0.320385,0.43958,-0.149046,0.046628,0.113528,-0.021696,0.613547,-0.672232,0.176148,-0.201066,1.229769,0.702879,-0.704702,0.00129,0.239338,0.028118,-0.119238,-0.122054,0.392529,3e-12,-0.021487,0.033653,-0.022305,-4e-12,0.098706,-0.266631,0.100065,0.559442,0.174156,0.039657,0.023844,0.10577,0.081162,-0.273737,0.186224,0.020855,0.697935,0.022985,0.134661,0.370354,0.043067,-4e-12,-0.068187,-0.282056,-3e-12,0.034564,-0.413699,-0.192931,0.064289,-0.020494,0.019439,0.167483,-0.368315,-0.157484
75%,31.314478,-0.1603477,5.206301,0.739952,0.729909,-0.424323,0.727807,6.791932,0.699333,0.710269,6547.711189,0.709614,442.317132,860466.8,21.207782,2281.218058,5.131325,4.195584,6.101673,5.217873,6.809004,35.94478,5.485654,0.71315,0.710926,3422.937185,0.719487,27.281638,440707.4,1177.063791,9.856447,13.417931,7.357851,10.597972,14.880791,11.809573,18.803559,12.036807,8.762049,6.359907,34.941827,7.325288,12.156278,15.683381,18.124032,8.160911,17.514288,17.841853,14.347667,19.018863,6.398438,8.199924,21.34251,16.475036,13.795662,5.397195,28.666573,32.011838,17.400256,26.37559,0.177362,9.092509,11.088017,20.236286,5.317989,21.113425,1.92e-10,4.875192,4.074412,2.240698,1.88e-10,8.759976,9.963963,9.163306,23.888949,19.050583,4.18496,4.142821,9.200974,19.395682,9.256245,6.666675,4.254985,22.337538,8.781431,7.016483,12.164359,11.506634,1.81e-10,12.350528,17.216763,1.91e-10,7.489257,21.95049,13.360652,10.290495,6.392952,4.270248,15.008679,25.059432,4.208981
max,109.007286,-0.1603477,11.906599,3.076362,3.616124,0.614709,3.261001,14.21305,3.467488,3.427769,111024.049127,3.367477,3070.338638,20491480.0,83.084694,21371.196766,12.316485,10.181427,24.923615,12.443312,14.498968,155.113363,12.378777,3.088478,3.361865,79077.215494,3.220452,106.553617,17398860.0,11418.549693,53.892551,75.358059,40.904433,58.824396,78.624234,67.477412,110.624715,67.478045,54.005652,38.118079,174.800154,41.375289,66.454847,86.445536,92.559947,42.767368,104.138348,107.221245,78.973268,121.915659,41.700761,44.349595,114.78962,97.552493,81.755779,27.505736,164.704551,173.296181,94.641863,169.018358,1.002489,49.151135,66.30271,138.335172,31.64127,116.098726,1.06e-09,25.152584,21.179952,13.882129,1.162e-09,46.846155,54.957389,56.874075,144.890539,105.703805,22.950676,24.457674,52.250048,137.664982,52.84291,34.496741,28.215127,113.909934,45.448701,41.026232,61.903483,63.921138,1.173e-09,74.227249,100.418042,1.267e-09,52.063507,124.423291,77.187419,60.499225,37.902308,23.356801,84.135784,137.284193,21.025214
