GMDA is a Python package for generative modeling with density alignment. This README provides instructions for installation, usage, and key features of the package.
Disclaimer: Installing this package will result in the installation of a specific version of PyTorch, which may not be compatible with every user's GPU driver. Before installation, please check the compatibility of the included PyTorch version with your GPU driver. If incompatible, you should create your Python environment with the PyTorch version best suited for your system. Visit the official PyTorch website to find the appropriate installation command for your setup.
Using venv and pip:
python -m venv env_gmda
source env_gmda/bin/activate
pip install .
Using conda:
conda create -n env_gmda python=3.9
conda activate env_gmda
pip install .[conda]
For development mode, use:
pip install -e .[conda]
GMDA provides flexible data processing capabilities through the DataProcessor class.
from gmda.data_utils import DataProcessor
# Define custom data loading and processing functions
def custom_data_loader(train: bool = True, **kwargs):
    # Your custom data loading logic here. Should return a tuple of tabular data (X, y).
    pass

def custom_data_processor(data, train: bool = True, **kwargs):
    # Your custom data processing logic here. Should return a tuple of tuples of
    # processed train and test data: ((X_train, y_train), (X_test, y_test)).
    pass
# Instantiate the DataProcessor
data_processor = DataProcessor(custom_data_loader, custom_data_processor)
# Create dataloaders
train_loader, val_loader, X, y = data_processor.create_dataloaders(batch_size=64, density=0.1)
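As an illustration of the two callables described above, here is a minimal sketch that uses random NumPy data in place of a real dataset. The function names, the `n_samples` and `n_features` keyword arguments, the 80/20 split, and the standardization step are all assumptions made for this example; only the return shapes, (X, y) and ((X_train, y_train), (X_test, y_test)), come from the comments above. How `DataProcessor` actually calls these functions is not documented here, so check the package source before relying on specific keyword arguments.

```python
import numpy as np

def example_data_loader(train: bool = True, **kwargs):
    """Hypothetical loader: returns random tabular data (X, y)."""
    n_samples = kwargs.get("n_samples", 200)   # assumed keyword, not a GMDA option
    n_features = kwargs.get("n_features", 5)   # assumed keyword, not a GMDA option
    rng = np.random.default_rng(0 if train else 1)
    X = rng.normal(size=(n_samples, n_features)).astype(np.float32)
    y = rng.integers(0, 2, size=n_samples)     # binary labels, for illustration only
    return X, y

def example_data_processor(data, train: bool = True, **kwargs):
    """Hypothetical processor: split 80/20, then standardize the features."""
    X, y = data
    n_train = int(0.8 * len(X))
    X_train, X_test = X[:n_train], X[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]
    # Fit the scaling statistics on the training split only to avoid leakage
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0) + 1e-8
    X_train = (X_train - mean) / std
    X_test = (X_test - mean) / std
    return (X_train, y_train), (X_test, y_test)

(X_train, y_train), (X_test, y_test) = example_data_processor(example_data_loader())
print(X_train.shape, X_test.shape)
```

Functions with these signatures could then be passed to `DataProcessor` in place of `custom_data_loader` and `custom_data_processor`.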
To train a GMDA model:
from gmda.models import GMDARunner
from gmda.models.gmda.tools import get_config
# Load configuration
config = get_config('path/to/config.json')
# Initialize and train the model
model = GMDARunner(config)
model.train(train_loader, val_loader, X, config['training'])
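The full schema of the configuration file is not described in this README. Judging from the calls shown in it, the file has a `training` section (containing at least `nb_nn_for_prec_recall`) and a `model` section (containing at least `device`). The sketch below writes and reloads such a skeleton with the standard `json` module. The values are placeholders, the file name is hypothetical, and a real GMDA config very likely needs additional keys, so treat this as a starting point and consult an example config from the repository.

```python
import json
import tempfile
from pathlib import Path

# Only these keys are confirmed by the usage shown in this README;
# the values are placeholders and other required keys are omitted.
skeleton_config = {
    "model": {"device": "cpu"},
    "training": {"nb_nn_for_prec_recall": 10},
}

config_path = Path(tempfile.mkdtemp()) / "example_config.json"  # hypothetical file name
config_path.write_text(json.dumps(skeleton_config, indent=2))

# Plain json loading is used here instead of gmda's get_config
# so that the example stays self-contained.
with config_path.open() as f:
    loaded = json.load(f)

print(loaded["model"]["device"], loaded["training"]["nb_nn_for_prec_recall"])
```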
From a Trained Model:
X_synthetic, y_synthetic = model.generate(y)
X_synthetic, y_synthetic = X_synthetic.numpy(), y_synthetic.numpy()
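Assuming `model.generate` returns PyTorch tensors (as the `.numpy()` calls suggest), note that `Tensor.numpy()` raises an error for tensors that live on a GPU or that require gradients. If that applies to your setup, a defensive conversion along these lines may help. `to_numpy` is a hypothetical helper, not part of GMDA; it uses duck typing so that it also accepts plain arrays and lists.

```python
import numpy as np

def to_numpy(value):
    """Convert a tensor-like or array-like object to a NumPy array."""
    # torch.Tensor provides detach() and cpu(): detach() drops the autograd
    # graph and cpu() copies GPU data to host memory before conversion.
    if hasattr(value, "detach") and hasattr(value, "cpu"):
        return value.detach().cpu().numpy()
    return np.asarray(value)

print(to_numpy([[1.0, 2.0], [3.0, 4.0]]).shape)
```

With this helper, the conversion step would read `X_synthetic, y_synthetic = to_numpy(X_synthetic), to_numpy(y_synthetic)`.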
From a Pretrained Model:
from gmda.models import generate_from_pretrained
X_synthetic, y_synthetic = generate_from_pretrained(
    y,
    config['model'],
    path_pretrained=model.checkpoint_dir,
    device=config['model']['device'],
    return_as_array=True
)
GMDA provides metrics to evaluate the quality of generated data:
from gmda.metrics import get_corr_error, get_precision_recall
import numpy as np
# Correlation Error
idx = np.random.choice(np.arange(len(X)), size=min(len(X), 1500), replace=False)
corr_error, corr_error_matrix = get_corr_error(X[idx], X_synthetic[idx])
# Precision/Recall
precision, recall = get_precision_recall(X, X_synthetic, nb_nn=config['training']['nb_nn_for_prec_recall'])
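The exact definitions behind `get_corr_error` and `get_precision_recall` are not given in this README. To convey the general idea, the sketch below shows one common way such metrics are defined: the mean absolute difference between feature correlation matrices, and a k-nearest-neighbor coverage estimate of precision and recall. All function names here are hypothetical, and the results are not guaranteed to match GMDA's own implementation.

```python
import numpy as np

def corr_error_sketch(real, synth):
    """Mean absolute difference between feature correlation matrices."""
    corr_real = np.corrcoef(real, rowvar=False)
    corr_synth = np.corrcoef(synth, rowvar=False)
    diff = np.abs(corr_real - corr_synth)
    return float(diff.mean()), diff

def _pairwise_dist(a, b):
    """Euclidean distances between every row of a and every row of b."""
    return np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)

def _coverage(reference, queries, k):
    """Fraction of queries inside the k-NN ball of at least one reference point."""
    # After sorting, column 0 is each point's distance to itself,
    # so column k holds the distance to its k-th nearest neighbor.
    radii = np.sort(_pairwise_dist(reference, reference), axis=1)[:, k]
    inside = _pairwise_dist(queries, reference) <= radii[None, :]
    return float(inside.any(axis=1).mean())

def precision_recall_sketch(real, synth, k=5):
    precision = _coverage(real, synth, k)  # synthetic points close to real data
    recall = _coverage(synth, real, k)     # real points close to synthetic data
    return precision, recall

rng = np.random.default_rng(0)
real = rng.normal(size=(300, 4))
synth = rng.normal(size=(300, 4))
error, _ = corr_error_sketch(real, synth)
print(error, precision_recall_sketch(real, synth))
```

The brute-force distance matrix grows quadratically with the number of rows, so subsampling before calling functions like these keeps memory use manageable.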
GMDA can be run from the command line:
python main.py --dataset '<DATASET>' \
    --path_train '<PATH/TO/TRAIN/CSV>' \
    --path_test '<PATH/TO/TEST/CSV>' \
    --device 'cuda:0' \
    --config '<PATH/TO/CONFIG/JSON>' \
    --output_dir '<PATH/TO/OUTPUT/RESULTS>' \
    --compute_metrics \
    --save_generated
For more details on command-line options, run:
python main.py --help
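To launch the same command from a Python script (for example, to loop over several datasets), the argument list can be assembled programmatically and passed to `subprocess.run`. The helper below is hypothetical and uses only the flags shown above; the dataset name and paths in the example call are placeholders.

```python
import sys

def build_gmda_command(dataset, path_train, path_test, config, output_dir,
                       device="cuda:0", compute_metrics=True, save_generated=True):
    """Assemble the main.py invocation shown above as an argument list."""
    cmd = [
        sys.executable, "main.py",
        "--dataset", dataset,
        "--path_train", path_train,
        "--path_test", path_test,
        "--device", device,
        "--config", config,
        "--output_dir", output_dir,
    ]
    # Boolean switches are appended only when enabled
    if compute_metrics:
        cmd.append("--compute_metrics")
    if save_generated:
        cmd.append("--save_generated")
    return cmd

cmd = build_gmda_command("my_dataset", "train.csv", "test.csv", "config.json", "results")
print(" ".join(cmd))
# To execute: import subprocess; subprocess.run(cmd, check=True)
```

Passing a list rather than a single shell string avoids quoting problems when paths contain spaces.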
We welcome contributions! Please contact me for more details.
This project is licensed under the MIT License.