In [7]:
# https://scikit-learn.org/1.5/modules/generated/sklearn.mixture.GaussianMixture.html#sklearn.mixture.GaussianMixture
# https://plotly.com/python/t-sne-and-umap-projections/
from sklearn.mixture import GaussianMixture
from sklearn.manifold import TSNE
import plotly.express as px
import numpy as np

In [16]:
# Number of Gaussian components used by the baseline clustering below.
n_comp = 10

# Pre-computed embeddings that are both clustered and visualized in later cells.
# NOTE(review): presumably a 2-D array (n_documents, embedding_dim) — confirm.
# The filename spelling ("embeded") is the file on disk; keep it as-is.
dataset = np.load(file='./fineweb_dataset_embeded.npy')

In [17]:
# Fit the baseline Gaussian mixture and assign every embedding to its most likely component.
# random_state pins the model's random initialisation so the cluster labels (and hence the
# colours in the t-SNE plot) are reproducible under Restart & Run All; 4774 is the same
# seed used for TSNE in this notebook.
# NOTE(review): covariance_type is left at sklearn's default ('full'), i.e. one d x d matrix
# per component — on high-dimensional embeddings consider 'diag' if fitting gets slow.
gmm = GaussianMixture(n_components=n_comp, random_state=4774)
clusters = gmm.fit_predict(dataset)

In [19]:
# t-SNE projection of the embeddings to 2-D, used only for visualization.
tsne = TSNE(n_components=2, random_state=4774)
projected_data = tsne.fit_transform(dataset)

# x=0 / y=1 select the two t-SNE columns of the array. Without them Plotly Express treats
# a 2-column array as wide-form data and plots row index vs. value for each column, which
# is not a scatter of the 2-D projection.
fig = px.scatter(
    projected_data,
    x=0,
    y=1,
    color=clusters.astype(np.int32).astype(str),
    labels={'color': 'GMM cluster'},
    # Legend in numeric cluster order instead of order of first appearance.
    category_orders={'color': [str(k) for k in range(n_comp)]},
    width=1200,
    height=700,
    template='plotly_white',
    title=f'TSNE Visualization, Colored by Baseline GMM with n={n_comp}',
)
fig.update_layout(xaxis_title='t-SNE 1', yaxis_title='t-SNE 2')
fig


In [22]:
# Mean vector of each fitted GMM component, one row per component (numpy truncates long arrays).
print('Means:', gmm.means_, sep='\n')

Means:
[[-0.00844813  0.00733825  0.03482859 ... -0.00358404 -0.00580507
  -0.00961842]
 [ 0.00791782  0.01567039  0.0082259  ...  0.00082654 -0.01197084
   0.00976766]
 [ 0.01366084  0.01330728 -0.00695871 ...  0.00485464  0.00447715
  -0.00109669]
 ...
 [ 0.00902192  0.00271049 -0.01298192 ...  0.00861321  0.00115718
  -0.00190228]
 [ 0.01492974  0.00537062  0.01319154 ...  0.00254293 -0.00797953
   0.00344092]
 [ 0.01502316  0.00942799 -0.03095092 ... -0.00049978  0.00201509
  -0.00505816]]


In [23]:
# The fitted model holds one covariance per component; its layout depends on
# covariance_type (the default 'full' gives shape (n_components, n_features, n_features)).
# Report type and shape in the header so the truncated dump below can be interpreted.
print(f'Covariance Matrices (covariance_type={gmm.covariance_type!r}, shape={gmm.covariances_.shape})')
print(gmm.covariances_)

Covariance Matrix
[[[ 7.67265699e-04 -8.64242569e-05  1.87111098e-04 ...  6.49634429e-05
   -7.95464773e-06  8.67865164e-05]
  [-8.64242569e-05  7.84848721e-04  2.28228200e-04 ...  3.31638192e-05
   -4.78033984e-05 -1.01069045e-04]
  [ 1.87111098e-04  2.28228200e-04  9.95930782e-04 ...  9.81782880e-05
   -9.34984446e-05  3.10089853e-05]
  ...
  [ 6.49634429e-05  3.31638192e-05  9.81782880e-05 ...  3.11138278e-04
   -1.75913797e-06  2.62154823e-05]
  [-7.95464773e-06 -4.78033984e-05 -9.34984446e-05 ... -1.75913797e-06
    4.82456988e-04  2.44292563e-05]
  [ 8.67865164e-05 -1.01069045e-04  3.10089853e-05 ...  2.62154823e-05
    2.44292563e-05  3.84729364e-04]]

 [[ 6.17193192e-04  4.57746396e-05 -1.53257553e-05 ...  6.16839835e-05
    8.82293425e-07  4.91205027e-05]
  [ 4.57746396e-05  5.87393582e-04  2.10657206e-04 ... -9.63451100e-05
    2.31329006e-05 -2.65227510e-05]
  [-1.53257553e-05  2.10657206e-04  9.00963913e-04 ... -7.62481887e-05
   -7.09275172e-05 -1.76874852e-05]
  ...
  [ 6

In [24]:
# Mixing proportion of each fitted component (the weights of the mixture).
print('Cluster Weights', gmm.weights_, sep='\n')

Cluster Weights
[0.09819639 0.11523046 0.09218437 0.11022044 0.10721443 0.0991984
 0.08717435 0.08917836 0.08416834 0.11723447]
