# Import packages

In [None]:
import pandas as pd
from sdmetrics.reports.utils import get_column_plot, get_column_pair_plot
from sdmetrics.reports.single_table import QualityReport
from sdv import Metadata
import seaborn as sns
import matplotlib.pyplot as plt

# Load data

In [None]:
# Locations of the real dataset and of the generated (synthetic) dataset.
real_csv_path = 'data/olympics.csv'
fake_csv_path = 'generations/olympics_generation.csv'

# Read both CSV files into DataFrames; the real one is read first.
df_real = pd.read_csv(real_csv_path)
df_fake = pd.read_csv(fake_csv_path)

# Define column types

In [None]:
# Columns cast to integers vs. columns cast to pandas categoricals.
continuous_columns = ['Age', 'Height', 'Weight']
categorical_columns = [
    'Sex', 'Year', 'Season', 'City', 'Sport', 'Medal', 'AOS', 'AOE',
]

# Apply the same dtype conversion to the real frame, then to the fake one,
# so both datasets end up with matching column types.
for frame in (df_real, df_fake):
    frame[continuous_columns] = frame[continuous_columns].astype('int64')
    frame[categorical_columns] = frame[categorical_columns].astype('category')

# Add unique key for real and fake data

In [None]:
# reset_index() turns the current row index into a column; that column is
# then renamed from 'index' to 'key' so it can act as a per-row identifier.
key_column_rename = {'index': 'key'}

df_real_plot = df_real.reset_index().rename(columns=key_column_rename)
df_fake_plot = df_fake.reset_index().rename(columns=key_column_rename)

# Set up metadata

In [None]:
# Register the real data as a single table whose primary key is the 'key'
# column, then rebind `metadata` to that table's own metadata, which is what
# the cells below pass to the report and plotting helpers.
table_name = 'olympic'

metadata = Metadata()
metadata.add_table(name=table_name, data=df_real_plot, primary_key='key')
metadata = metadata.get_table_meta(table_name)

# Generate evaluation report

In [None]:
# Build the quality report comparing the synthetic data against the real data.
my_report = QualityReport()
my_report.generate(
    real_data=df_real_plot,
    synthetic_data=df_fake_plot,
    metadata=metadata,
)

# Keep the report's score in a variable for later use.
score = my_report.get_score()

# Visualization

In [None]:
# Plot the 'Height' column of the real data against the synthetic data.
height_figure = get_column_plot(
    column_name='Height',
    real_data=df_real_plot,
    synthetic_data=df_fake_plot,
    metadata=metadata,
)
height_figure  # left as the final expression so the cell still shows the result

In [None]:
# Plot the 'Sex' column of the real data against the synthetic data.
sex_figure = get_column_plot(
    column_name='Sex',
    real_data=df_real_plot,
    synthetic_data=df_fake_plot,
    metadata=metadata,
)
sex_figure  # left as the final expression so the cell still shows the result

In [None]:
# Plot the 'City' column of the real data against the synthetic data.
city_figure = get_column_plot(
    column_name='City',
    real_data=df_real_plot,
    synthetic_data=df_fake_plot,
    metadata=metadata,
)
city_figure  # left as the final expression so the cell still shows the result

In [None]:
# Plot the 'AOE' column of the real data against the synthetic data.
aoe_figure = get_column_plot(
    column_name='AOE',
    real_data=df_real_plot,
    synthetic_data=df_fake_plot,
    metadata=metadata,
)
aoe_figure  # left as the final expression so the cell still shows the result

In [None]:
# Ask the report for its visualization of the 'Column Pair Trends' property.
pair_trends_figure = my_report.get_visualization(
    property_name='Column Pair Trends',
)
pair_trends_figure  # left as the final expression so the cell still shows the result