In [None]:
# Note: For optimal performance use A100 GPU
!pip install synthcity
!pip install pycox
!pip install be_great

In [None]:
!pip uninstall -y torchaudio torchdata
from synthcity.metrics import Metrics
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.plugins import Plugins
import numpy as np
import pandas as pd
import logging
from sklearn import datasets
import matplotlib.pyplot as plt
from pycox import datasets

In [None]:
from be_great import GReaT

In [None]:
# Any survival dataset can be substituted here; the rest of the notebook reads
# its 'duration' and 'event' columns.
# Alternative source (local CSV):
# data=pd.read_csv('/path/flchain_final.csv')
# data=data.drop('Unnamed: 0',axis=1)
# data = data[data['duration'] != 0]

# Load the GBSG frame and keep only the rows whose duration is non-zero.
gbsg_raw = datasets.gbsg.read_df()
nonzero_duration = gbsg_raw['duration'] != 0
data = gbsg_raw[nonzero_duration]

In [None]:
# Preview the first rows of the filtered frame as a quick sanity check.
data.head()

In [None]:
# from peft import LoraConfig
# Configure the GReaT generator; each argument is explained inline.
# NOTE(review): "trainer_iris" looks like a directory name left over from another example — confirm it is intended.
great = GReaT("distilgpt2",                  # Name of the large language model used (see HuggingFace for more options)
              epochs=1000,                   # Number of epochs to train
              save_steps=10000,               # Save model weights every x steps
              logging_steps=500,             # Log the loss and learning rate every x steps
              experiment_dir="trainer_iris", # Name of the directory where all intermediate steps are saved
              batch_size=32,                 # Batch Size
              # efficient_finetuning='lora'
              #lr_scheduler_type="constant", # Specify the learning rate scheduler
              #learning_rate=5e-5            # Set the initial learning rate
             )

In [None]:
# Fine-tune the model on the survival frame; the returned trainer is kept so its log history can be inspected afterwards.
trainer = great.fit(data)

In [None]:
# Persist the fitted model.
# NOTE(review): '/path/...' looks like a placeholder — point it at a real, writable location before running.
great.save('/path/Great_gbsg_1000')

In [None]:
# Plot the training-loss curve recorded in the trainer's log history.
# Only entries that actually carry a "loss" value are used, so the closing
# summary record (and any other entry without "loss") is skipped wherever it
# sits in the log, and an empty log simply yields an empty plot.
loss_hist = [entry for entry in trainer.state.log_history if "loss" in entry]
loss = [entry["loss"] for entry in loss_hist]
epochs = [entry["epoch"] for entry in loss_hist]

plt.plot(epochs, loss)
plt.xlabel("Epoch")
plt.ylabel("Training loss")
plt.title("GReaT fine-tuning loss")
plt.show()

In [None]:
# Generate unconditional samples
# Request as many synthetic rows as there are rows in the real frame.
n_samples = len(data)
# NOTE(review): k is presumably the generation batch size — confirm against the be_great documentation.
samples = great.sample(n_samples, k=50)

In [None]:
# Preview the first unconditionally generated rows.
samples.head()

In [None]:
# Benchmark the unconditionally generated data against the real data.

# Wrap both frames as survival loaders: "event" is the target column and
# "duration" the time-to-event column.
loader1 = SurvivalAnalysisDataLoader(
    data,
    target_column="event",
    time_to_event_column="duration",
)
loader2 = SurvivalAnalysisDataLoader(
    samples,
    target_column="event",
    time_to_event_column="duration",
)

# Metric names requested from synthcity, grouped by metric family.
unconditional_metrics = {
    'stats': [
        'jensenshannon_dist',
        'chi_squared_test',
        'feature_corr',
        'inv_kl_divergence',
        'max_mean_discrepancy',
        'wasserstein_dist',
        'survival_km_distance',
    ],
    'performance': ['linear_model', 'mlp', 'xgb'],
}

met_df = Metrics.evaluate(
    X_gt=loader1,
    X_syn=loader2,
    task_type='survival_analysis',
    metrics=unconditional_metrics,
    use_cache=False,
)

In [None]:
# Display the benchmark table returned by Metrics.evaluate.
met_df

In [None]:
#visualization: joint t-SNE embedding of the real and synthetic covariates

from sklearn.manifold import TSNE
import seaborn as sns

# t-SNE cannot handle missing values, so incomplete rows are removed from both frames.
real_data=data.dropna()
synthetic=samples.dropna()

# Real rows are stacked first, so the first n_real embedded points are the real
# ones and everything after them is synthetic.
n_real = len(real_data)
combined_data = pd.concat([real_data, synthetic], ignore_index=True)

# Embed the covariates only; the survival outcome columns are left out.
covariates = combined_data.drop(['duration', 'event'], axis=1)

tsne = TSNE(n_components=2, random_state=0)
tsne_result = tsne.fit_transform(covariates)
event_type = combined_data['event']
tsne_df = pd.DataFrame(data={'TSNE1': tsne_result[:, 0], 'TSNE2': tsne_result[:, 1], 'Event_Type': event_type})

# Split at the actual number of real rows rather than a hard-coded row count,
# so the two point clouds stay correctly separated for any dataset size.
sns.scatterplot(x='TSNE1', y='TSNE2', hue='Event_Type', data=tsne_df.iloc[:n_real], palette='viridis')
sns.scatterplot(x='TSNE1', y='TSNE2', hue='Event_Type', data=tsne_df.iloc[n_real:])
plt.title('t-SNE Plot of Covariates based on Event Type (Original and Synthetic)')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Overlay the histograms of covariate 'x0' for the real and the synthetic
# frame on a single set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(data['x0'], bins=50, kde=False, label='Original Data', color='blue', ax=ax)
sns.histplot(samples['x0'], bins=50, kde=False, label='Synthetic Data', color='orange', ax=ax)
ax.set_title('Comparison of Covariate Distributions')
ax.set_xlabel('Covariate Values')
ax.set_ylabel('Frequency')
ax.legend()
plt.show()

In [None]:
from sklearn.mixture import BayesianGaussianMixture

def fit_dpmm_and_sample(data, sample_size, integer_sampling=False, bandwidth='scott'):
    """
    Fit a DPMM (scikit-learn ``BayesianGaussianMixture``) to 1-D data and sample from it.

    Parameters
    ----------
    data : array-like
        One-dimensional observations. Anything ``np.asarray`` accepts works
        (NumPy array, list, pandas Series); it is reshaped to a single column.
    sample_size : int
        Number of samples to return.
    integer_sampling : bool, default=False
        If True, return strictly positive integers; if False, return
        non-negative continuous values.
    bandwidth : str or float, default='scott'
        Not used for DPMM but kept for API consistency.

    Returns
    -------
    numpy.ndarray
        One-dimensional array of ``sample_size`` values drawn from the fitted
        mixture (absolute values are taken, so nothing is negative).

    Raises
    ------
    ValueError
        If ``data`` is empty, or if integer sampling leaves no strictly
        positive value to draw from. scikit-learn may also raise ValueError
        for inputs it cannot fit.

    Notes
    -----
    The mixture itself is seeded (``random_state=42``), but the integer path
    sub-/re-samples with NumPy's global RNG; call ``np.random.seed`` first if
    exact reproducibility is needed.
    """
    # Accept any array-like and turn it into the (n_observations, 1) layout the mixture expects.
    observations = np.asarray(data).reshape(-1, 1)
    n_observations = observations.shape[0]
    if n_observations == 0:
        raise ValueError("data must contain at least one observation")

    # scikit-learn refuses to fit when there are fewer observations than
    # components, so the component cap is lowered for small inputs.
    dpmm = BayesianGaussianMixture(
        n_components=min(10, n_observations),  # Max number of components
        weight_concentration_prior=1.0,
        random_state=42
    )
    dpmm.fit(observations)

    if not integer_sampling:
        # Continuous case: draw directly and fold negatives onto the positive axis.
        return np.abs(dpmm.sample(sample_size)[0].reshape(-1))

    # Integer case: draw extra points because rounding and the zero filter
    # below can discard some of them.
    oversampling_factor = 1.5
    raw_draws = dpmm.sample(int(sample_size * oversampling_factor))[0].reshape(-1)

    # Fold onto the positive axis, round to the nearest integer, drop zeros.
    integer_draws = np.round(np.abs(raw_draws)).astype(int)
    samples = integer_draws[integer_draws > 0]

    if len(samples) == 0:
        raise ValueError(
            "integer sampling produced no strictly positive values to draw from"
        )

    if len(samples) > sample_size:
        # More than needed: pick a random subset without repeats.
        samples = np.random.choice(samples, size=sample_size, replace=False)
    elif len(samples) < sample_size:
        # Fewer than needed: top up by resampling with replacement.
        samples = np.random.choice(samples, size=sample_size, replace=True)

    return samples

In [None]:
# Sample durations (t) and event indicators (e) using the DPMM.

# The mixtures must be fitted on the real frame. `df` cannot be used here: it
# is only produced later, by conditional generation that itself consumes the
# z and p built in this cell.
survival_df = data
event_0_data = survival_df[survival_df['event'] == 0]['duration'].values
event_1_data = survival_df[survival_df['event'] == 1]['duration'].values

# Draw as many durations per event value as the real frame has rows with that value.
sample_size_0 = len(event_0_data)
sample_size_1 = len(event_1_data)

# Fit one DPMM per event value and sample integer durations from each.
sample_event_0 = fit_dpmm_and_sample(event_0_data, sample_size_0, integer_sampling=True)
sample_event_1 = fit_dpmm_and_sample(event_1_data, sample_size_1, integer_sampling=True)

# z: sampled durations (event == 0 block first, then event == 1);
# p: the matching event labels in the same order.
z = np.concatenate([sample_event_0, sample_event_1])
x = np.zeros(len(sample_event_0))
y = np.ones(len(sample_event_1))
p = np.concatenate([x, y])

## uncomment below to sample t and e empirically instead
# survival_df=data
# event_0_data = survival_df[survival_df['event'] == 0]['duration']
# event_1_data = survival_df[survival_df['event'] == 1]['duration']

# sample_size_0 = len(event_0_data)
# sample_size_1= len(event_1_data)
# sample_event_0 = np.random.choice(event_0_data, size=sample_size_0)
# sample_event_1 = np.random.choice(event_1_data, size=sample_size_1)

# z=np.concatenate([sample_event_0,sample_event_1])
# x=np.zeros(len(sample_event_0))
# y=np.ones(len(sample_event_1))
# p=np.concatenate([x,y])

In [None]:
import pandas as pd
# Conditioning frame: one row per real record (same index as `data`) holding
# only the sampled duration and event label, in that column order.
# It is built directly from z and p so it does not depend on any particular
# covariate column (such as 'x0') being present in the dataset.
new_df = pd.DataFrame({'duration': z, 'event': p}, index=data.index)

# Generate the sentences used for conditioning

def dataframe_to_text_df(df):
    """Render every row of ``df`` as one comma-separated sentence.

    Each cell becomes ``"<column> is <value>"`` and the cells of a row are
    joined with commas (no trailing comma). Rows are read with
    ``DataFrame.iterrows``, so values are formatted the way they appear in the
    per-row Series.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose rows should be converted.

    Returns
    -------
    tuple
        ``(text_df, text_data)``: a one-column frame named ``'text'`` and the
        plain list of the same sentences, in row order.
    """
    text_data = [
        ",".join(f"{col_name} is {col_value}" for col_name, col_value in record.items())
        for _, record in df.iterrows()
    ]
    return pd.DataFrame({'text': text_data}), text_data

# Convert each row of new_df into a "column is value,..." sentence; str_list is the plain list of those sentences.
resulting_text_df, str_list = dataframe_to_text_df(new_df)

In [None]:
# Conditional generation: every sentence in str_list is passed to the model as a starting prompt.
# NOTE(review): presumably the model fills in the remaining columns for each prompt — confirm against the be_great documentation.
df=great.great_sample(
    starting_prompts=str_list,
    # The commented lines below look like signature hints for optional arguments, not assignments:
    # temperature: float = 0.7,
    # max_length: int = 100,
    # device: str = 'cuda'
)

In [None]:
# Preview the first conditionally generated rows before any cleaning.
df.head()

In [None]:
# Clean the conditionally generated frame, then trim the real frame to the same length.

# Drop every row in which any cell holds the literal 'placeholder'.
# NOTE(review): presumably the marker be_great leaves for a column it could not parse — confirm.
has_placeholder = df.isin(['placeholder']).any(axis=1)
df = df.loc[~has_placeholder].copy()

# Parse numeric columns leniently *before* casting to the real frame's dtypes.
# Generated cells may still be text (e.g. '1.0'); casting such text straight to
# an integer dtype raises and would abort the whole cell. Values that cannot be
# parsed become NaN and their rows are dropped.
numeric_columns = [
    column for column in data.columns
    if column == 'event' or pd.api.types.is_numeric_dtype(data[column])
]
for column in numeric_columns:
    df[column] = pd.to_numeric(df[column], errors='coerce')
df = df.dropna(subset=numeric_columns)

# Snap the event indicator to a whole number, align all columns with the real
# frame's dtypes, then store the indicator as a nullable integer.
df['event'] = df['event'].round()
df = df.astype(data.dtypes)
df['event'] = df['event'].astype('Int64')

# Remove as many randomly chosen real rows as generated rows were discarded, so
# both frames end up the same length. The draw is seeded to keep the benchmark
# reproducible. Note that this overwrites `data`.
diff = len(data) - len(df)
drop_indices = data.sample(diff, random_state=42).index
data = data.drop(drop_indices)

In [None]:
# Display the conditionally generated frame after cleaning.
df

In [None]:
# Benchmark the conditionally generated data against the real data.

# Wrap both frames as survival loaders: "event" is the target column and
# "duration" the time-to-event column.
loader1 = SurvivalAnalysisDataLoader(
    data,
    target_column="event",
    time_to_event_column="duration",
)
loader2 = SurvivalAnalysisDataLoader(
    df,
    target_column="event",
    time_to_event_column="duration",
)

# Metric names requested from synthcity, grouped by metric family.
conditional_metrics = {
    'stats': [
        'jensenshannon_dist',
        'chi_squared_test',
        'feature_corr',
        'inv_kl_divergence',
        'ks_test',
        'max_mean_discrepancy',
        'wasserstein_dist',
        'prdc',
        'alpha_precision',
        'survival_km_distance',
    ],
    'performance': ['linear_model', 'mlp', 'xgb', 'feat_rank_distance'],
}

met_df = Metrics.evaluate(
    X_gt=loader1,
    X_syn=loader2,
    task_type='survival_analysis',
    metrics=conditional_metrics,
    use_cache=False,
)

met_df