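# DSDDPM: Dual-Scale Denoising Diffusion Probabilistic Model
This notebook runs the DSDDPM tabular data generation model on Google Colab: it clones the repository, installs the dependencies, trains on a small dummy dataset, samples synthetic rows, and inspects the results.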

In [None]:
# 1. Clone the Repository (Force Clean Start)
# Removes existing folder to ensure we get the latest code fixes
import os
if os.path.exists('Tabular-Data-Generation'):
    !rm -rf Tabular-Data-Generation

!git clone https://github.com/MuniSurya18/Tabular-Data-Generation.git
%cd Tabular-Data-Generation

In [None]:
# 2. Install Dependencies
!pip install -r requirements.txt

In [None]:
# 3. Generate Dummy Data (for testing)
# Runs the repo's generator script. The training and sampling cells below read
# data/dummy.csv, which presumably this script writes — verify in
# scripts/generate_synthetic.py.
!python scripts/generate_synthetic.py

In [None]:
# 4. Train the Model
# Reduced epochs for demonstration. set epochs=100 for real training.
!python -m src.train --data data/dummy.csv --epochs 50 --batch_size 64

In [None]:
# 5. Generate Synthetic Data
!python -m src.sample --model checkpoints/model_final.pt --data data/dummy.csv --output generated_data.csv --num_samples 1000

In [None]:
# 6. Verify Results (Stats)
# Prints a preview and summary statistics of the generated CSV, or an error
# message if the sampling step did not produce it. Note that
# DataFrame.describe() with default arguments summarizes numeric columns only.
import os

import pandas as pd

generated_path = "generated_data.csv"

if not os.path.exists(generated_path):
    print("Error: generated_data.csv not found. Check the generation step output for errors.")
else:
    df = pd.read_csv(generated_path)
    print("First 5 rows of generated data:")
    print(df.head())
    print("\nStatistical Summary:")
    print(df.describe())

In [None]:
# 7. Visualize Results
# Plots the distributions of four generated columns in a 2x2 grid.
# The CSV is re-read here, and `os` imported, so this cell does not depend on
# names left behind by earlier cells. Otherwise, after a kernel restart or if
# cell 6 was skipped, `df` would be undefined or stale.
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hard-coded column names; presumably they match the schema of data/dummy.csv.
# Verify before reusing this cell with a different dataset.
PLOT_COLUMNS = ['Age', 'Income', 'Gender', 'Churn']

if os.path.exists("generated_data.csv"):
    df = pd.read_csv("generated_data.csv")

    # Report a clear message instead of letting seaborn raise on a bad column.
    missing = [col for col in PLOT_COLUMNS if col not in df.columns]
    if missing:
        print(f"Cannot plot: generated data is missing columns {missing}.")
    else:
        # Set up the visual style
        sns.set(style="whitegrid")

        # Create a figure with subplots
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Synthetic Data Distribution', fontsize=16)

        # 1. Numerical: Age Distribution
        sns.histplot(data=df, x='Age', kde=True, ax=axes[0, 0], color='skyblue')
        axes[0, 0].set_title('Age Distribution (Numerical)')

        # 2. Numerical: Income Distribution
        sns.histplot(data=df, x='Income', kde=True, ax=axes[0, 1], color='orange')
        axes[0, 1].set_title('Income Distribution (Numerical)')

        # NOTE(review): seaborn >= 0.13 emits a FutureWarning when `palette` is
        # passed without `hue`. On 0.13+, use hue=<same column>, legend=False.
        # 3. Categorical: Gender Count
        sns.countplot(data=df, x='Gender', ax=axes[1, 0], palette='viridis')
        axes[1, 0].set_title('Gender Count (Categorical)')

        # 4. Categorical: Churn Count
        sns.countplot(data=df, x='Churn', ax=axes[1, 1], palette='pastel')
        axes[1, 1].set_title('Churn Count (Categorical)')

        plt.tight_layout()
        plt.show()
else:
    print("Data not found, cannot plot.")