In [1]:
import pandas as pd
import numpy as np
from scipy.spatial.distance import pdist, squareform
import numba

In [2]:
data = pd.read_parquet('../../data/features/DSP_remove_real.parquet', engine='pyarrow')  # You can use 'fastparquet' as the engine
data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,29572,29573,29574,29575,29576,29577,29578,29579,Target,Train
0,845.0,365.00,866.5,284.750,277.00,306.00,509.50,510.250,449.00,212.0000,...,449.00,510.250,509.50,306.00,277.00,284.750,866.5,365.00,BA.5.1,2
1,856.0,362.25,870.5,288.500,267.50,299.75,506.75,487.500,422.25,192.1250,...,422.25,487.500,506.75,299.75,267.50,288.500,870.5,362.25,AY.19,2
2,807.0,465.00,815.5,235.625,472.00,143.25,654.50,217.125,559.50,76.5625,...,559.50,217.125,654.50,143.25,472.00,235.625,815.5,465.00,C.1,2
3,843.0,366.75,875.0,271.250,284.75,313.75,489.75,494.500,405.25,199.3750,...,405.25,494.500,489.75,313.75,284.75,271.250,875.0,366.75,L.3,1
4,759.0,453.00,924.5,251.250,272.50,398.75,507.00,544.000,315.00,200.7500,...,315.00,544.000,507.00,398.75,272.50,251.250,924.5,453.00,BA.4.6,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
34056,834.0,382.50,878.5,265.000,280.00,332.25,528.00,538.500,424.50,208.1250,...,424.50,538.500,528.00,332.25,280.00,265.000,878.5,382.50,BA.4.1.8,1
34057,811.0,397.00,890.5,271.000,262.75,350.75,494.50,520.000,364.00,190.8750,...,364.00,520.000,494.50,350.75,262.75,271.000,890.5,397.00,BA.4.1.8,2
34058,850.0,357.25,867.0,297.000,251.50,314.75,478.00,497.500,389.50,190.2500,...,389.50,497.500,478.00,314.75,251.50,297.000,867.0,357.25,BA.5.11,2
34059,829.0,386.25,895.0,265.250,264.25,341.75,497.25,518.500,379.00,186.2500,...,379.00,518.500,497.25,341.75,264.25,265.250,895.0,386.25,BA.4,2


In [3]:
targets = data["Target"]
train = data["Train"]

In [4]:
pair_data = data[data["Train"] == 0]

In [5]:
data = data.drop(columns=["Train", "Target"]).to_numpy().astype(np.float32)
pair_data = pair_data.drop(columns=["Train", "Target"]).to_numpy().astype(np.float32)

print("Standardizing Data")
# Function to standardize data
def standardize(data):
    mean = np.mean(data, axis=1, keepdims=True).astype(np.float32)  # Use higher precision for mean and std calculations
    std = np.std(data, axis=1, keepdims=True, ddof=1).astype(np.float32)
    return ((data.astype(np.float32) - mean) / std).astype(np.float32)  # Convert back to float16 after standardization

# Standardize the data and pair_data arrays
data = standardize(data)
pair_data = standardize(pair_data)

print("Computing Dot Product")
@numba.njit
def fast_dot_product(data, pair_data):
    # Ensure the input is float32 for the dot product to minimize precision issues
    result = np.dot(data.astype(np.float32), pair_data.T.astype(np.float32))
    return result.astype(np.float32)  # Convert back to float16 if needed

# Compute the dot product using Numba
dot_product = fast_dot_product(data, pair_data)

n = data.shape[1]  # Number of features
correlation_matrix = (dot_product / (n - 1) + 1) / 2  # Normalize the correlation matrix to be between 0 and 1

Standardizing Data
Computing Dot Product


In [6]:
correlation_matrix = (1 - correlation_matrix) / 2

In [7]:
correlation_matrix = pd.DataFrame(correlation_matrix)

In [8]:
correlation_matrix["Target"] = targets.to_list()
correlation_matrix["Train"] = train.to_list()

In [9]:
correlation_matrix.to_parquet('../../data/features/DSP_dist_real_remove.parquet', engine='pyarrow')

  table = self.api.Table.from_pandas(df, **from_pandas_kwargs)


In [10]:
correlation_matrix

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,7087,7088,7089,7090,7091,7092,7093,7094,Target,Train
0,1.305914e-01,0.183110,0.198871,0.220046,0.132452,0.148904,0.213202,0.146502,0.145587,0.164780,...,0.189007,0.181352,0.175472,0.178967,0.133028,0.125790,0.197084,0.214581,BA.5.1,2
1,1.971181e-01,0.220196,0.188837,0.214766,0.193044,0.212138,0.200858,0.207942,0.205332,0.199029,...,0.205630,0.191170,0.217300,0.204568,0.195878,0.207191,0.209397,0.225215,AY.19,2
2,2.039739e-01,0.206309,0.214843,0.226664,0.203169,0.205384,0.224610,0.203596,0.202662,0.209302,...,0.203681,0.198743,0.206360,0.208458,0.204693,0.206207,0.210573,0.223394,C.1,2
3,1.859704e-01,0.206877,0.190235,0.219773,0.181324,0.200331,0.211599,0.198142,0.193632,0.177253,...,0.192095,0.182602,0.216754,0.187768,0.183451,0.205146,0.202923,0.224066,L.3,1
4,1.192093e-07,0.187414,0.187479,0.218131,0.039811,0.138381,0.210451,0.116800,0.110824,0.168169,...,0.179388,0.161821,0.182825,0.138585,0.039080,0.146488,0.195654,0.215954,BA.4.6,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
34056,1.669297e-01,0.204979,0.206039,0.218559,0.170828,0.160313,0.212481,0.196542,0.192800,0.165864,...,0.208411,0.197297,0.201681,0.172056,0.170147,0.197481,0.209840,0.217021,BA.4.1.8,1
34057,1.200868e-01,0.195630,0.180903,0.217813,0.132179,0.129969,0.208818,0.163068,0.160390,0.131748,...,0.190813,0.182783,0.179665,0.157758,0.131401,0.169858,0.208695,0.218441,BA.4.1.8,2
34058,1.863979e-01,0.201847,0.197765,0.213053,0.186289,0.168799,0.212713,0.193413,0.196142,0.171635,...,0.205224,0.197828,0.170525,0.190780,0.187672,0.188207,0.214411,0.223217,BA.5.11,2
34059,1.876532e-01,0.199260,0.202925,0.221094,0.187223,0.187037,0.211845,0.190332,0.191009,0.185688,...,0.205026,0.196099,0.187770,0.193998,0.186557,0.203433,0.214639,0.219721,BA.4,2
