# Example use of SyGNet

This notebook demonstrates the basic functionality of the *sygnet* package in Python.

To install the package, run `pip install sygnet` at the command line.

## Prerequisites

First, we will focus on a very simple case of learning a parametric relationship between numeric variables.

To start, we will define a data generating process:

```python
import pandas as pd
import numpy as np
from numpy.random import default_rng

from sygnet import SygnetModel

rng = default_rng()

def gen_sim_data(rng, n=100000):
    # Two independent uniform predictors on [0, 1]
    x1 = rng.uniform(low=0, high=1, size=n)
    x2 = rng.uniform(low=0, high=1, size=n)
    # x3 is a noisy copy of x1 + x2 (it does not enter the outcome)
    x3 = rng.normal(loc=x1 + x2, scale=0.1)
    # Outcome: linear in x1 and x2 plus unit-variance Gaussian noise
    y = rng.normal(loc=3*x1 + 2*x2 + 1, scale=1)
    # Columns are ordered y, x1, x2, x3 and stored as float32
    sim_data = np.column_stack((y, x1, x2, x3)).astype(np.float32)
    return pd.DataFrame(sim_data)

sim_data = gen_sim_data(rng)
```
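As a quick sanity check on this data generating process, an ordinary least-squares fit of `y` on `x1` and `x2` should recover an intercept near 1 and slopes near 3 and 2. The sketch below regenerates the data with a fixed seed so it runs on its own. `ols_coefficients` is a hypothetical helper, not part of **sygnet**, and it relies on the column order y, x1, x2, x3 used above.

```python
import numpy as np
import pandas as pd

def ols_coefficients(df):
    """Regress column 0 (y) on an intercept plus columns 1 and 2 (x1, x2)."""
    X = np.column_stack([np.ones(len(df)), df.iloc[:, 1], df.iloc[:, 2]])
    coef, *_ = np.linalg.lstsq(X, df.iloc[:, 0].to_numpy(dtype=float), rcond=None)
    return coef  # [intercept, slope on x1, slope on x2]

# Same process as gen_sim_data, with a fixed seed so the check is reproducible
rng = np.random.default_rng(42)
n = 100_000
x1 = rng.uniform(0, 1, n)
x2 = rng.uniform(0, 1, n)
x3 = rng.normal(x1 + x2, 0.1)
y = rng.normal(3*x1 + 2*x2 + 1, 1)
check_data = pd.DataFrame(np.column_stack((y, x1, x2, x3)).astype(np.float32))

# Expect values close to 1, 3 and 2
print(ols_coefficients(check_data).round(2))
```

The same helper can later be applied to a large synthetic sample to see whether the generator has preserved the linear relationship.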


With this data in hand, we can build our **SyGNet** model. The pipeline closely mirrors scikit-learn: users first instantiate a model, then fit it to the data, and finally sample (i.e. transform) from the fitted model.
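The instantiate, fit, sample pattern can be illustrated without training a GAN at all. The class below is a hypothetical toy baseline, not part of **sygnet**: `fit` only records the mean and covariance of the training data, and `sample` draws rows from a multivariate normal with those moments. Its method names mirror the `SygnetModel` calls used in this notebook, but the `rng` argument to `sample` is an addition of this sketch.

```python
import numpy as np
import pandas as pd

class GaussianSynth:
    """Hypothetical toy generator with a fit/sample interface.

    It matches only the mean and covariance of the training data, so it is
    a simple baseline, not a substitute for a GAN.
    """

    def fit(self, data):
        values = np.asarray(data, dtype=float)
        # Keep column labels when given a DataFrame, else use 0..k-1
        self.columns_ = getattr(data, "columns", range(values.shape[1]))
        self.mean_ = values.mean(axis=0)
        self.cov_ = np.cov(values, rowvar=False)
        return self

    def sample(self, nobs, rng=None):
        rng = np.random.default_rng() if rng is None else rng
        draws = rng.multivariate_normal(self.mean_, self.cov_, size=nobs)
        return pd.DataFrame(draws, columns=self.columns_)

# Instantiate, fit, then sample
rng = np.random.default_rng(0)
real = pd.DataFrame(rng.normal(size=(1000, 3)), columns=["a", "b", "c"])
synth = GaussianSynth().fit(real).sample(nobs=100, rng=rng)
print(synth.head())
```

Because it captures only second-moment (linear) structure, a baseline like this can serve as a point of comparison for a GAN trained on the same data.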

In this first example, we set `mode = "wgan"` to use the Wasserstein GAN architecture, which we recommend when not generating conditional labels:

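```python
# Instantiate a Wasserstein GAN model
model = SygnetModel(mode="wgan")
# Train for 10 epochs on the simulated data
model.fit(data=sim_data, epochs=10)
# Draw 100 synthetic rows from the fitted generator
synth_data = model.sample(nobs=100)

synth_data.head()
```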
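With only 100 sampled rows, `synth_data.head()` says little about whether the generator has learned the joint distribution. A rough check is to compare column moments and correlation matrices of the real and synthetic data. The helpers below are hypothetical, not part of **sygnet**, and assume `synth_data` comes back as a pandas DataFrame with the same column labels as `sim_data`; sampling more rows (a larger `nobs`) makes the comparison less noisy.

```python
import numpy as np
import pandas as pd

def compare_moments(real, synth):
    """Side-by-side column means and standard deviations, matched by column label."""
    return pd.DataFrame({
        "real_mean": real.mean(),
        "synth_mean": synth.mean(),
        "real_std": real.std(),
        "synth_std": synth.std(),
    })

def max_corr_gap(real, synth):
    """Largest absolute difference between the two Pearson correlation matrices."""
    return float((real.corr() - synth.corr()).abs().to_numpy().max())

# Stand-in data so the sketch runs without a trained model;
# in the notebook, pass sim_data and synth_data instead.
rng = np.random.default_rng(0)
real = pd.DataFrame(rng.normal(size=(500, 3)))
synth = pd.DataFrame(rng.normal(size=(100, 3)))
print(compare_moments(real, synth).round(2))
print(max_corr_gap(real, synth))
```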

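For readers unfamiliar with the Wasserstein GAN mentioned above: instead of a discriminator that outputs probabilities, a *critic* assigns real-valued scores, and training is driven by the gap between the mean score on real rows and the mean score on generated rows. The function below computes the losses of the original WGAN formulation from a batch of critic scores. It is a minimal illustration only; `wgan_losses` is a hypothetical name, and **sygnet**'s internal training loop may differ (for example by adding a gradient penalty).

```python
import numpy as np

def wgan_losses(critic_real, critic_fake):
    """Return (critic_loss, generator_loss) for the original WGAN objective.

    critic_real / critic_fake: critic scores for a batch of real / generated rows.
    """
    critic_real = np.asarray(critic_real, dtype=float)
    critic_fake = np.asarray(critic_fake, dtype=float)
    # The critic maximises E[D(real)] - E[D(fake)], i.e. minimises its negation
    critic_loss = critic_fake.mean() - critic_real.mean()
    # The generator tries to raise the critic's scores on generated rows
    generator_loss = -critic_fake.mean()
    return float(critic_loss), float(generator_loss)

# Illustrative scores, not produced by sygnet
print(wgan_losses([2.0, 1.5, 2.5], [1.0, 0.5, 0.0]))  # (-1.5, -0.5)
```

A strongly negative critic loss means the critic separates real from generated rows easily; as the generator improves, the two mean scores move closer together.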

## GPU support

**sygnet** is written using **pytorch**, so it can take advantage of GPU computation. To train the synthetic generator on the GPU, pass the parameter `device = 'cuda'` when fitting the model:

```python
model_gpu = SygnetModel(mode="wgan", mixed_activation=True)
model_gpu.fit(data=sim_data, epochs=10, device="cuda")
synth_data = model_gpu.sample(nobs=100)

synth_data.head()
```

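The GPU example assumes a CUDA-capable device is present. A small helper can choose the device at run time instead. `pick_device` is a hypothetical name, not part of **sygnet**; it relies on PyTorch's `torch.cuda.is_available()` and assumes that `fit` also accepts `device='cpu'`, which this notebook does not show.

```python
def pick_device():
    """Return "cuda" when PyTorch can see a GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed: fall back to the CPU
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"

print(pick_device())
```

In the notebook this would be used as `model_gpu.fit(data=sim_data, epochs=10, device=pick_device())`.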