In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange
import networkx as nx
from fa2 import ForceAtlas2

# Load the model-output total-DOS table. The CSV has no header row and its
# first column is used as the index; transposing turns each of those index
# labels into a column (presumably one column per material -- confirm against
# the data export).
materials = (
    pd.read_csv(
        "../Data/model_output_all_kpoint_total_dos_epoch_401_432_batches.csv",
        header=None,
        index_col=0,
    )
    .T
)


def _unit_sum(column):
    """Return *column* divided by its own sum, so its entries add up to 1.

    A column whose sum is 0 yields inf/NaN entries.
    """
    return column / column.sum()


# Each column rescaled to sum to 1.
materials_normalised = materials.apply(_unit_sum)

# Running total down each normalised column (a cumulative distribution per column).
cumul_materials = materials_normalised.apply(pd.Series.cumsum)

# Size and seed of the reproducible random subset of material columns to compare.
N_SAMPLE_MATERIALS = 6
SAMPLE_RANDOM_STATE = 6

sample_materials_normalised = materials_normalised.sample(
    n=N_SAMPLE_MATERIALS, axis='columns', random_state=SAMPLE_RANDOM_STATE
)

# Cumulative sum down each sampled, normalised column.
sample_cumul_materials = sample_materials_normalised.apply(lambda col: col.cumsum())

sample_cols = sample_cumul_materials.columns
sample_idx = sample_cols.copy()
# One column per sampled material; rows follow the original row order.
sample_cumul_mat = sample_cumul_materials.to_numpy(dtype=float, na_value=np.nan, copy=False)

# Pairwise dissimilarity = L1 distance between two columns' cumulative sums.
# For non-negative columns that each sum to 1 this is the 1-D earth mover's
# (Wasserstein-1) distance with unit spacing between rows.
n_sample = sample_cumul_mat.shape[1]
sample_diff_mat = np.zeros((n_sample, n_sample))

for i in trange(n_sample):
    # Compare column i against every later column in one vectorised step
    # (the original summed each pair element-by-element with builtin sum/abs).
    # Only the upper triangle is filled here; symmetry fills the rest below.
    sample_diff_mat[i, i + 1:] = np.abs(
        sample_cumul_mat[:, i + 1:] - sample_cumul_mat[:, [i]]
    ).sum(axis=0)

# Mirror the upper triangle into a full symmetric matrix (diagonal stays 0).
sample_adj_mat = sample_diff_mat + sample_diff_mat.T
sample_adj_df = pd.DataFrame(sample_adj_mat, index=sample_idx, columns=sample_cols)
print(sample_adj_df)

forceatlas2 = ForceAtlas2(
    # Behavior alternatives
    outboundAttractionDistribution=True,  # Dissuade hubs
    linLogMode=False,  # NOT IMPLEMENTED
    adjustSizes=False,  # Prevent overlap (NOT IMPLEMENTED)
    edgeWeightInfluence=1.0,

    # Performance
    jitterTolerance=1.0,  # Tolerance
    barnesHutOptimize=True,
    barnesHutTheta=1.2,
    multiThreaded=False,  # NOT IMPLEMENTED

    # Tuning
    scalingRatio=2.0,
    strongGravityMode=False,
    gravity=1.0,

    # Log
    verbose=True
)

# ForceAtlas2 treats edge weights as attraction strengths: a heavier edge pulls
# its endpoints closer together. sample_adj_df holds *distances*, so passing it
# in directly would pull the most dissimilar materials together. Convert to a
# similarity that decreases monotonically with distance; 1 / (1 + d) keeps every
# off-diagonal weight finite and positive for d >= 0. The diagonal is zeroed so
# there are no self-loops, matching the zero-diagonal distance matrix.
sample_dist_mat = sample_adj_df.to_numpy(dtype=float)
sample_weight_mat = 1.0 / (1.0 + sample_dist_mat)
np.fill_diagonal(sample_weight_mat, 0.0)

sample_positions = forceatlas2.forceatlas2(sample_weight_mat, pos=None, iterations=2000)
# Positions are assumed to come back in matrix-row order, which is also the
# column order of sample_adj_df.
sample_node_positions = dict(zip(sample_adj_df.columns.values, sample_positions))
sample_G = nx.from_pandas_adjacency(sample_adj_df)

nx.draw(sample_G, sample_node_positions, with_labels=True)
plt.axis('off')
plt.show()