# Synthetic Data Generation with GMM: User-Defined Parameters

This notebook demonstrates how to generate synthetic data using a Gaussian Mixture Model (GMM) when the model parameters are explicitly defined by the user.

In this example, we manually specify:

- Component means (`mus`)
- Covariance matrices (`covs`)
- Cluster probabilities (`cluster_probs`)
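
Together, these three arrays fully determine a standard Gaussian mixture. Write $K$ for the number of components, and read `cluster_probs` as the weights $\pi_k$, `mus` as the means $\boldsymbol{\mu}_k$ and `covs` as the covariances $\boldsymbol{\Sigma}_k$. This symbol mapping is our interpretation of the variable names. The mixture density is then:

$$
p(\mathbf{x}) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}\left(\mathbf{x} \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k\right),
\qquad \pi_k \ge 0, \quad \sum_{k=1}^{K} \pi_k = 1,
$$

Each $\boldsymbol{\Sigma}_k$ must be symmetric positive definite for the Gaussian density to exist.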

---

## 📌 Workflow Overview

1. Instantiate the `SyntheticGMMGenerator`
   - Define GMM parameters manually
   - Set parameters using `set_params(...)`
   - Generate synthetic samples using `generate_data(...)`
2. Instantiate the `ScatterPlotVisualizer`
   - Visualize the generated data using:
     - PCA projection
     - Scatter Plot Matrix (SPLOM)
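
Generating data from a GMM usually means *ancestral sampling*. First, pick a component for each sample according to the mixing weights. Then draw the sample from that component's multivariate normal. The internals of `SyntheticGMMGenerator.generate_data(...)` are not shown in this notebook, so the sketch below is only an assumption about what such a method does. `sample_gmm` is a hypothetical helper written with plain NumPy, not the library's API.

```python
import numpy as np

def sample_gmm(mus, covs, cluster_probs, n_samples, seed=None):
    """Hypothetical helper: draw samples from a GMM by ancestral sampling.

    mus: (K, D) means, covs: (K, D, D) covariances, cluster_probs: (K,) weights.
    Returns X with shape (n_samples, D) and the component label of each row.
    """
    rng = np.random.default_rng(seed)
    mus = np.asarray(mus, dtype=float)
    covs = np.asarray(covs, dtype=float)
    probs = np.asarray(cluster_probs, dtype=float)

    # Step 1: choose a component for every sample using the mixing weights
    labels = rng.choice(len(probs), size=n_samples, p=probs)

    # Step 2: fill the rows of each component from its multivariate normal
    X = np.empty((n_samples, mus.shape[1]))
    for k in range(len(probs)):
        idx = np.flatnonzero(labels == k)
        if idx.size:
            X[idx] = rng.multivariate_normal(mus[k], covs[k], size=idx.size)
    return X, labels

# Toy usage with two well-separated 2-D components
toy_mus = np.array([[0.0, 0.0], [5.0, 5.0]])
toy_covs = np.array([np.eye(2), 0.5 * np.eye(2)])
X, labels = sample_gmm(toy_mus, toy_covs, [0.7, 0.3], n_samples=500, seed=0)
print(X.shape)  # (500, 2)
```

With the parameters defined later in this notebook, the equivalent call would be `sample_gmm(mus, covs, cluster_probs, 1000)`. Keeping the labels alongside `X` is what makes the synthetic data useful as ground truth for clustering or classification benchmarks.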

---

## 🧪 Experiment: PAAS (Post-Archean Australian Shale), NASC (North American Shale Composite), UCC (Upper Continental Crust), Average Greywacke, Average Marine Carbonate

In this experiment, synthetic datasets are constructed to represent reference compositions of sedimentary rocks and the upper continental crust. Instead of estimating parameters from real geochemical databases, we define Gaussian components that approximate the central tendency and dispersion of each rock group, so that the generated samples can be examined in classical geochemical diagrams:

- **K₂O vs SiO₂ diagram** (Peccerillo & Taylor, 1976)
- **AFM diagram** (Irvine & Baragar, 1971)
- **TAS diagram** (Total Alkali vs Silica)
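
Each of these diagrams is built from a few oxide combinations. The sketch below assumes the column order used in this notebook (`[SiO2, Al2O3, Fe2O3, MgO, CaO, Na2O, K2O]`) and that all iron is reported as Fe2O3. TAS plots Na2O + K2O against SiO2. The Peccerillo & Taylor diagram plots K2O against SiO2. The AFM ternary uses A = Na2O + K2O, F = total iron as FeO (FeO = 0.8998 × Fe2O3), and M = MgO, normalized to 100. `diagram_coordinates` is a hypothetical helper, not part of the `visualization` module, and it computes only point coordinates, not the classification boundaries.

```python
import numpy as np

# Column order assumed throughout this notebook
OXIDES = ["SiO2", "Al2O3", "Fe2O3", "MgO", "CaO", "Na2O", "K2O"]

# Ratio 2*M(FeO)/M(Fe2O3): converts Fe2O3 (wt%) to equivalent FeO (wt%)
FE2O3_TO_FEO = 0.8998

def diagram_coordinates(X):
    """Hypothetical helper: coordinates for K2O-SiO2, TAS and AFM plots.

    X: array of shape (n, 7) or (7,) in wt%, columns ordered as OXIDES.
    Assumes all iron is reported as Fe2O3.
    """
    X = np.atleast_2d(np.asarray(X, dtype=float))
    col = {name: X[:, i] for i, name in enumerate(OXIDES)}
    alkali = col["Na2O"] + col["K2O"]
    feot = FE2O3_TO_FEO * col["Fe2O3"]          # total iron expressed as FeO
    afm = np.column_stack([alkali, feot, col["MgO"]])
    return {
        "k2o_sio2": np.column_stack([col["SiO2"], col["K2O"]]),
        "tas": np.column_stack([col["SiO2"], alkali]),
        # A, F, M rescaled so that each row sums to 100 (ternary coordinates)
        "afm": 100.0 * afm / afm.sum(axis=1, keepdims=True),
    }

shale = np.array([60.0, 17.0, 7.0, 3.0, 2.0, 1.5, 3.5])
coords = diagram_coordinates(shale)
print(coords["tas"], coords["afm"].round(1))
```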


<!-- This is especially helpful for validating classification methods and benchmarking machine learning models. -->

<!-- ## 🎯 Why Use User-Defined Parameters?

Defining parameters explicitly allows you to:

- Control cluster separation
- Define anisotropic covariance structures
- Simulate imbalanced datasets
- Create structured test scenarios -->


In [None]:
import numpy as np
import pandas as pd
import sys
import os
sys.path.append(os.path.abspath(".."))
from main import SyntheticGMMGenerator
from visualization import ScatterPlotVisualizer
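
`ScatterPlotVisualizer` is imported above for the PCA and SPLOM views, but its implementation is not shown in this notebook. The sketch below is a hypothetical helper, `pca_project`, and not the visualizer's code. It illustrates what a 2-D PCA projection involves: center the data, then project it onto the leading right-singular vectors from NumPy's SVD.

```python
import numpy as np

def pca_project(X, n_components=2):
    """Hypothetical helper: project rows of X onto leading principal components.

    Returns the PC scores (n, n_components) and the explained-variance ratios.
    """
    X = np.asarray(X, dtype=float)
    Xc = X - X.mean(axis=0)                       # center each column
    _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    scores = Xc @ Vt[:n_components].T             # coordinates in PC space
    explained = S**2 / np.sum(S**2)               # variance fraction per PC
    return scores, explained[:n_components]

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 7))
scores, ratio = pca_project(X_demo)
print(scores.shape, ratio.round(3))
```

With unscaled wt% data, oxides with large means and variances dominate the leading components. In this notebook those are SiO2 and CaO, which also separate carbonate from the siliciclastic groups. Standardizing columns first is a common alternative. This notebook does not state which convention `ScatterPlotVisualizer` uses.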

In [None]:
# Mean compositions (wt%) of common sedimentary rocks and the upper continental crust;
# the "shale" entry represents the shale composites (PAAS, NASC)

# Order:
# [SiO2, Al2O3, Fe2O3, MgO, CaO, Na2O, K2O]

means = {
    "shale": np.array([60.0, 17.0, 7.0, 3.0, 2.0, 1.5, 3.5]),
    "ucc": np.array([66.0, 15.0, 5.0, 2.5, 4.0, 3.0, 2.8]),
    "greywacke": np.array([68.0, 14.0, 4.5, 2.5, 3.0, 2.5, 2.5]),
    "carbonate": np.array([5.0, 1.0, 0.5, 5.0, 45.0, 0.2, 0.2])
}
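
Because only seven major oxides are modeled, the component means need not sum to 100 wt%. The quick check below runs on its own because the `means` dictionary is copied into it. The gap is small for the silicate compositions and large for carbonate, where the remainder would typically be volatiles such as CO2 that this model does not include.

```python
import numpy as np

# Copied from the cell above so this block runs on its own
means = {
    "shale": np.array([60.0, 17.0, 7.0, 3.0, 2.0, 1.5, 3.5]),
    "ucc": np.array([66.0, 15.0, 5.0, 2.5, 4.0, 3.0, 2.8]),
    "greywacke": np.array([68.0, 14.0, 4.5, 2.5, 3.0, 2.5, 2.5]),
    "carbonate": np.array([5.0, 1.0, 0.5, 5.0, 45.0, 0.2, 0.2]),
}

# Sum of the seven modeled oxides for each rock type (wt%)
totals = {name: float(vec.sum()) for name, vec in means.items()}
for name, total in totals.items():
    print(f"{name:10s} {total:6.1f} wt%  (unmodeled remainder: {100 - total:5.1f} wt%)")
```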


In [None]:
# 🔹 Covariance matrices (Σ)
# Units: (wt%)²
# Constructed to reflect:

# Shale → moderate variance, positive Al-K correlation
# UCC → lower variance
# Greywacke → greater spread in SiO2
# Carbonate → high variance in MgO, low in CaO relative to its mean
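
The comments above describe the matrices qualitatively, in terms of variances and correlations. A covariance can be converted to a correlation matrix to check such claims. Conversely, one common way to build a covariance is Σ = D R D, where D = diag(standard deviations) and R is a correlation matrix. This notebook does not say how its matrices were built, and the helpers below are hypothetical. They use the Al2O3-K2O block of the shale matrix as an example. A hand-written R is not guaranteed to yield a positive definite Σ, so the result still needs checking.

```python
import numpy as np

def cov_to_corr(cov):
    """Hypothetical helper: convert a covariance matrix to a correlation matrix."""
    cov = np.asarray(cov, dtype=float)
    sd = np.sqrt(np.diag(cov))
    return cov / np.outer(sd, sd)

def corr_to_cov(sd, corr):
    """Hypothetical helper: build a covariance matrix as D @ R @ D, D = diag(sd)."""
    D = np.diag(np.asarray(sd, dtype=float))
    return D @ np.asarray(corr, dtype=float) @ D

# Al2O3-K2O block of the shale covariance: variances 4.0 and 1.5, covariance 1.2
al_k = np.array([[4.0, 1.2],
                 [1.2, 1.5]])
r = cov_to_corr(al_k)[0, 1]    # about 0.49, i.e. a positive Al-K correlation
rebuilt = corr_to_cov(np.sqrt(np.diag(al_k)), cov_to_corr(al_k))
print(round(r, 3))
```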

In [None]:
cov_shale = np.array([
    [12.0, -3.0, -2.0, -1.0, -1.0, -0.5, -1.5],
    [-3.0, 4.0, 1.5, 0.8, 0.5, 0.3, 1.2],
    [-2.0, 1.5, 3.0, 1.0, 0.5, 0.2, 0.6],
    [-1.0, 0.8, 1.0, 2.5, 0.8, 0.1, 0.5],
    [-1.0, 0.5, 0.5, 0.8, 3.5, 0.1, 0.2],
    [-0.5, 0.3, 0.2, 0.1, 0.1, 0.8, 0.2],
    [-1.5, 1.2, 0.6, 0.5, 0.2, 0.2, 1.5]
])

cov_ucc = np.array([
    [10.0, -2.5, -1.5, -0.8, -1.2, -0.5, -1.0],
    [-2.5, 3.0, 1.0, 0.5, 0.3, 0.2, 0.8],
    [-1.5, 1.0, 2.0, 0.8, 0.4, 0.2, 0.5],
    [-0.8, 0.5, 0.8, 1.5, 0.6, 0.1, 0.3],
    [-1.2, 0.3, 0.4, 0.6, 2.5, 0.2, 0.2],
    [-0.5, 0.2, 0.2, 0.1, 0.2, 1.2, 0.3],
    [-1.0, 0.8, 0.5, 0.3, 0.2, 0.3, 1.2]
])

cov_greywacke = np.array([
    [15.0, -4.0, -2.0, -1.0, -1.0, -0.8, -1.5],
    [-4.0, 4.0, 1.2, 0.7, 0.5, 0.3, 1.0],
    [-2.0, 1.2, 2.5, 0.9, 0.4, 0.2, 0.6],
    [-1.0, 0.7, 0.9, 2.0, 0.7, 0.2, 0.4],
    [-1.0, 0.5, 0.4, 0.7, 2.8, 0.2, 0.2],
    [-0.8, 0.3, 0.2, 0.2, 0.2, 1.5, 0.4],
    [-1.5, 1.0, 0.6, 0.4, 0.2, 0.4, 1.8]
])

cov_carbonate = np.array([
    [5.0, 0.5, 0.2, -1.0, -2.5, 0.0, 0.0],
    [0.5, 0.5, 0.2, 0.2, -0.3, 0.0, 0.0],
    [0.2, 0.2, 0.3, 0.3, -0.2, 0.0, 0.0],
    [-1.0, 0.2, 0.3, 8.0, -3.0, 0.0, 0.0],
    [-2.5, -0.3, -0.2, -3.0, 6.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1]
])
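
Hand-entered covariance matrices can easily end up asymmetric or not positive definite. NumPy's `multivariate_normal` only warns about this by default, so an explicit check before sampling is worthwhile. The sketch below uses a hypothetical helper, `check_covariance`. It tests symmetry explicitly, because `np.linalg.cholesky` reads only one triangle of the matrix. It then attempts a Cholesky factorization, which succeeds only for positive definite matrices. `cov_carbonate` is copied from the cell above so the block is self-contained.

```python
import numpy as np

def check_covariance(cov, name="cov", atol=1e-10):
    """Hypothetical helper: True if cov is square, symmetric and positive definite."""
    cov = np.asarray(cov, dtype=float)
    if cov.ndim != 2 or cov.shape[0] != cov.shape[1]:
        print(f"{name}: not a square matrix")
        return False
    if not np.allclose(cov, cov.T, atol=atol):
        print(f"{name}: not symmetric")
        return False
    try:
        np.linalg.cholesky(cov)    # raises LinAlgError if not positive definite
    except np.linalg.LinAlgError:
        print(f"{name}: not positive definite")
        return False
    return True

cov_carbonate = np.array([
    [5.0, 0.5, 0.2, -1.0, -2.5, 0.0, 0.0],
    [0.5, 0.5, 0.2, 0.2, -0.3, 0.0, 0.0],
    [0.2, 0.2, 0.3, 0.3, -0.2, 0.0, 0.0],
    [-1.0, 0.2, 0.3, 8.0, -3.0, 0.0, 0.0],
    [-2.5, -0.3, -0.2, -3.0, 6.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1],
])
print(check_covariance(cov_carbonate, "cov_carbonate"))
print(np.linalg.eigvalsh(cov_carbonate).min())   # smallest eigenvalue
```

To check every component in the notebook, one could run `all(check_covariance(c, n) for n, c in zip(means, covs))` once `covs` has been stacked.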

In [None]:
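# Stack component parameters in a fixed order: shale, ucc, greywacke, carbonate
mus = np.stack([
    means["shale"],
    means["ucc"],
    means["greywacke"],
    means["carbonate"]
])  # shape (4, 7): one mean vector per component

covs = np.stack([
    cov_shale,
    cov_ucc,
    cov_greywacke,
    cov_carbonate
])  # shape (4, 7, 7): one covariance matrix per component

# Mixing weights, one per component in the same order; they must sum to 1
cluster_probs = np.array([0.35, 0.25, 0.25, 0.15])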
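
A Gaussian component has unbounded support. For oxides whose mean is small relative to its standard deviation, some generated samples will therefore be negative, which is not physically meaningful for wt% data. The marginal probability of a negative value follows from each component's mean and the diagonal of its covariance matrix. The helper below, `negative_fraction`, is a hypothetical sketch using the standard normal CDF written with `math.erf`. The shale and carbonate values are copied from the cells above.

```python
import math
import numpy as np

def negative_fraction(mus, variances):
    """Hypothetical helper: marginal P(x < 0) for each component and oxide.

    mus, variances: arrays of shape (K, D); variances are covariance diagonals.
    """
    mus = np.asarray(mus, dtype=float)
    sd = np.sqrt(np.asarray(variances, dtype=float))
    z = -mus / sd                                   # standardized position of 0
    erf = np.vectorize(math.erf)
    return 0.5 * (1.0 + erf(z / math.sqrt(2.0)))    # standard normal CDF at z

OXIDES = ["SiO2", "Al2O3", "Fe2O3", "MgO", "CaO", "Na2O", "K2O"]
# Shale and carbonate means with the diagonals of cov_shale and cov_carbonate
mus_demo = np.array([[60.0, 17.0, 7.0, 3.0, 2.0, 1.5, 3.5],
                     [5.0, 1.0, 0.5, 5.0, 45.0, 0.2, 0.2]])
var_demo = np.array([[12.0, 4.0, 3.0, 2.5, 3.5, 0.8, 1.5],
                     [5.0, 0.5, 0.3, 8.0, 6.0, 0.1, 0.1]])
p_neg = negative_fraction(mus_demo, var_demo)
for name, row in zip(["shale", "carbonate"], p_neg):
    print(name, dict(zip(OXIDES, np.round(row, 3))))
```

In the notebook itself, the same check for all four components would be `negative_fraction(mus, np.diagonal(covs, axis1=1, axis2=2))`. With these parameters, roughly a quarter of carbonate draws have negative Na2O and K2O. Clipping, rejection sampling, or a log-ratio transform may therefore be worth considering. The notebook does not state whether `SyntheticGMMGenerator` handles this case.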
