In [17]:
import torch
import pandas as pd
import numpy as np
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from scipy.stats import wasserstein_distance
from sklearn.metrics import mutual_info_score
import warnings
from datetime import datetime

# Differential Privacy
class DifferentialPrivacy:
    def __init__(self, epsilon=1.0, delta=1e-5):
        self.epsilon = epsilon
        self.delta = delta

    def add_noise(self, tensor, sensitivity):
        sigma = np.sqrt(2 * np.log(1.25 / self.delta)) * sensitivity / self.epsilon
        noise = torch.normal(0, sigma, tensor.shape, device=tensor.device)
        return tensor + noise

# Dataset
class HealthcareDataset(Dataset):
    def __init__(self, patients, observations, medications=None, conditions=None, allergies=None):
        self.patients = patients
        self.observations = observations
        self.medications = medications
        self.conditions = conditions
        self.allergies = allergies
        self.encoding_mappings = {}  # Store mappings for reverse transformation
        self.data = self._merge_and_preprocess()
        self.tensor_data = torch.tensor(self.data.values, dtype=torch.float32)

    def _merge_and_preprocess(self):
        """Merge and preprocess data from all sources."""
        # Merge patients with observations on patient ID
        merged_data = pd.merge(
            self.patients,
            self.observations,
            left_on="ID",
            right_on="PATIENT",
            how="inner",
            suffixes=("_patients", "_observations")
        )

        # Optionally merge with medications
        if self.medications is not None:
            merged_data = pd.merge(
                merged_data,
                self.medications,
                on="PATIENT",
                how="left",
                suffixes=("", "_medications")
            )

        # Optionally merge with conditions
        if self.conditions is not None:
            merged_data = pd.merge(
                merged_data,
                self.conditions,
                on="PATIENT",
                how="left",
                suffixes=("", "_conditions")
            )

        # Optionally merge with allergies
        if self.allergies is not None:
            merged_data = pd.merge(
                merged_data,
                self.allergies,
                on="PATIENT",
                how="left",
                suffixes=("", "_allergies")
            )

        numeric_cols = merged_data.select_dtypes(include=[np.number]).columns
        categorical_cols = merged_data.select_dtypes(include=['object']).columns

        for col in numeric_cols:
            merged_data[col] = merged_data[col].fillna(merged_data[col].mean())
        for col in categorical_cols:
            merged_data[col] = merged_data[col].fillna(merged_data[col].mode()[0])

        # Encode categorical variables and store mappings
        for col in categorical_cols:
            dummies = pd.get_dummies(merged_data[col], prefix=col)
            self.encoding_mappings[col] = {i: category for i, category in enumerate(merged_data[col].unique())}
            merged_data = pd.concat([merged_data, dummies], axis=1)
            merged_data.drop(columns=[col], inplace=True)

        # Normalize numerical columns
        for col in merged_data.columns:
            if merged_data[col].std() != 0:
                merged_data[col] = (merged_data[col] - merged_data[col].mean()) / merged_data[col].std()
            else:
                merged_data[col] = 0

        return merged_data
    
    def reverse_one_hot(self, synthetic_df):
        """Reverse one-hot encoding to original categories."""
        reversed_data = synthetic_df.copy()
        for col, mapping in self.encoding_mappings.items():
            # Combine one-hot columns into original category
            one_hot_cols = [c for c in synthetic_df.columns if c.startswith(f"{col}_")]
            reversed_data[col] = synthetic_df[one_hot_cols].idxmax(axis=1).map(lambda x: x.split("_")[-1])
            reversed_data.drop(columns=one_hot_cols, inplace=True)
        return reversed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        return self.tensor_data[idx]

# TGAN Model
class HealthcareTGAN(nn.Module):
    def __init__(self, dataset, hidden_dim=128, device="cuda"):
        super(HealthcareTGAN, self).__init__()
        self.dataset = dataset
        self.hidden_dim = hidden_dim
        self.device = device
        input_dim = dataset.data.shape[1]
        self.generator = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, input_dim),
            nn.Tanh()
        ).to(device)
        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 2, 1),
            nn.Sigmoid()
        ).to(device)
        self.g_optimizer = optim.Adam(self.generator.parameters(), lr=0.001)
        self.d_optimizer = optim.Adam(self.discriminator.parameters(), lr=0.001)
        self.criterion = nn.BCELoss()

    def generate(self, num_samples):
        noise = torch.randn(num_samples, self.hidden_dim).to(self.device)
        with torch.no_grad():
            synthetic_data = self.generator(noise).cpu()
        return pd.DataFrame(synthetic_data.numpy(), columns=self.dataset.data.columns)

# Enhanced TGAN with Differential Privacy
class EnhancedHealthcareTGAN(HealthcareTGAN):
    def __init__(self, dataset, hidden_dim=128, device="cuda", epsilon=1.0):
        super().__init__(dataset, hidden_dim, device)
        self.dp = DifferentialPrivacy(epsilon=epsilon)

    def train(self, epochs=10, batch_size=32):
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        
        for epoch in range(epochs):
            total_d_loss = 0
            total_g_loss = 0
            batches = 0
            
            for batch in dataloader:
                batch = batch.to(self.device)
                batch_size = batch.size(0)
                noisy_batch = self.dp.add_noise(batch, sensitivity=1.0)

                # Train Discriminator
                self.d_optimizer.zero_grad()
                real_validity = self.discriminator(noisy_batch)
                
                # Generate fake data
                noise = torch.randn(batch_size, self.hidden_dim, device=self.device)
                fake_data = self.generator(noise).detach()  # Detach to avoid computing gradients for generator
                fake_validity = self.discriminator(fake_data)
                
                # Calculate discriminator loss
                d_loss = (self.criterion(real_validity, torch.ones_like(real_validity)) +
                         self.criterion(fake_validity, torch.zeros_like(fake_validity))) / 2
                d_loss.backward()
                self.d_optimizer.step()

                # Train Generator
                self.g_optimizer.zero_grad()
                noise = torch.randn(batch_size, self.hidden_dim, device=self.device)
                fake_data = self.generator(noise)
                fake_validity = self.discriminator(fake_data)
                g_loss = self.criterion(fake_validity, torch.ones_like(fake_validity))
                g_loss.backward()
                self.g_optimizer.step()

                # Track losses
                total_d_loss += d_loss.item()
                total_g_loss += g_loss.item()
                batches += 1

            # Print epoch statistics
            avg_d_loss = total_d_loss / batches
            avg_g_loss = total_g_loss / batches
            print(f"[Epoch {epoch+1}/{epochs}] [D loss: {avg_d_loss:.4f}] [G loss: {avg_g_loss:.4f}]")

    def generate(self, num_samples):
        self.generator.eval()  # Set to evaluation mode
        with torch.no_grad():
            noise = torch.randn(num_samples, self.hidden_dim, device=self.device)
            synthetic_data = self.generator(noise).cpu()
        self.generator.train()  # Set back to training mode
        return pd.DataFrame(synthetic_data.numpy(), columns=self.dataset.data.columns)

class HealthcareDataValidator:
    def __init__(self, original_dataset, synthetic_data):
        self.original = original_dataset
        self.synthetic = synthetic_data
        self.original_stats = None
        self.synthetic_stats = None
        
    def _identify_column_groups(self):
        """Identify columns that belong to the same one-hot encoded group"""
        column_groups = {}
        for col in self.synthetic.columns:
            # Split column name by underscore to identify the base feature
            parts = col.split('_')
            if len(parts) > 1:
                base_feature = parts[0]
                if base_feature not in column_groups:
                    column_groups[base_feature] = []
                column_groups[base_feature].append(col)
        return column_groups
    
    def reverse_transform(self):
        """Reverse the normalization and one-hot encoding to get interpretable data"""
        reversed_data = pd.DataFrame()
        column_groups = self._identify_column_groups()
        
        # Process each group of one-hot encoded columns
        processed_columns = set()
        for base_feature, columns in column_groups.items():
            if len(columns) > 1:  # This is a one-hot encoded group
                # Get the most likely category for each row
                group_data = self.synthetic[columns]
                max_indices = group_data.idxmax(axis=1)
                # Extract the original category from the column name
                reversed_data[base_feature] = max_indices.apply(lambda x: '_'.join(x.split('_')[1:]))
                processed_columns.update(columns)
        
        # Process remaining numerical columns
        for col in self.synthetic.columns:
            if col not in processed_columns:
                # Check if this is a standalone column (not part of a one-hot group)
                if not any(col in group for group in column_groups.values()):
                    reversed_data[col] = self.synthetic[col]
                    
                    # Reverse normalization if the column exists in original data
                    if col in self.original.data.columns:
                        orig_std = self.original.data[col].std()
                        orig_mean = self.original.data[col].mean()
                        if orig_std != 0:
                            reversed_data[col] = (reversed_data[col] * orig_std) + orig_mean
        
        return reversed_data
    
    def compute_basic_stats(self):
        """Compute basic statistical measures for both datasets"""
        # Only compute stats for numerical columns
        numerical_cols = self.synthetic.select_dtypes(include=[np.number]).columns
        
        self.original_stats = {
            'means': self.original.data[numerical_cols].mean(),
            'stds': self.original.data[numerical_cols].std(),
            'mins': self.original.data[numerical_cols].min(),
            'maxs': self.original.data[numerical_cols].max()
        }
        
        self.synthetic_stats = {
            'means': self.synthetic[numerical_cols].mean(),
            'stds': self.synthetic[numerical_cols].std(),
            'mins': self.synthetic[numerical_cols].min(),
            'maxs': self.synthetic[numerical_cols].max()
        }
        
        return self.original_stats, self.synthetic_stats
    
    def validate(self):
        """Run all validations and print a summary report"""
        print("=== Healthcare Data Validation Report ===\n")
        
        # Get reversed data first
        reversed_data = self.reverse_transform()
        print("\nSample of reversed synthetic data:")
        print(reversed_data.head())
        
        # Compute basic stats for numerical columns
        orig_stats, synth_stats = self.compute_basic_stats()
        print("\nNumerical Column Statistics:")
        for stat_type in ['means', 'stds']:
            print(f"\n{stat_type.capitalize()}:")
            print(f"{'Column':<30} {'Original':>12} {'Synthetic':>12} {'Diff %':>10}")
            print("-" * 64)
            for col in orig_stats[stat_type].index:
                orig_val = orig_stats[stat_type][col]
                synth_val = synth_stats[stat_type][col]
                diff_pct = abs((synth_val - orig_val) / orig_val * 100) if orig_val != 0 else 0
                print(f"{col:<30} {orig_val:>12.4f} {synth_val:>12.4f} {diff_pct:>10.2f}%")
        
        # Categorical column analysis
        categorical_cols = [col for col in reversed_data.columns 
                          if reversed_data[col].dtype == 'object']
        if categorical_cols:
            print("\nCategorical Column Value Distributions:")
            for col in categorical_cols:
                print(f"\n{col}:")
                value_counts = reversed_data[col].value_counts(normalize=True)
                print(value_counts.head())
        
        return reversed_data

# Example usage:
def generate_and_validate(original_dataset, num_samples=100):
    # Generate synthetic data
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tgan = EnhancedHealthcareTGAN(original_dataset, device=device, epsilon=1.0)
    tgan.train(epochs=10, batch_size=16)
    synthetic_data = tgan.generate(num_samples)
    
    # Validate and reverse transform
    validator = HealthcareDataValidator(original_dataset, synthetic_data)
    reversed_data = validator.validate()
    
    return reversed_data, synthetic_data
    
# Usage Example
patients = pd.read_csv("cleaned_patients.csv")
observations = pd.read_csv("cleaned_observations.csv")
medications = pd.read_csv("cleaned_medications.csv")
conditions = pd.read_csv("cleaned_conditions.csv")
allergies = pd.read_csv("cleaned_allergies.csv")

data = HealthcareDataset(patients, observations, medications, conditions, allergies)
reversed_data, synthetic_data = generate_and_validate(data, num_samples=100)

# print("Data shape:", data.data.shape)
# print("Data types:\n", data.data.dtypes)
# print("Sample of preprocessed data:\n", data.data.head())
# print("Any null values:", data.data.isnull().sum().sum())
# device = "cuda" if torch.cuda.is_available() else "cpu"
# tgan = EnhancedHealthcareTGAN(data, device=device, epsilon=1.0)
# tgan.train(epochs=10, batch_size=16)
# synthetic_data = tgan.generate(100)
# print(synthetic_data)
# validator = HealthcareDataValidator(data, synthetic_data)
# reversed_data = validator.reverse_transform()

# print("\nSample of reversed synthetic data:")
# print(reversed_data.head())

# # Run validation
# validator.validate()

# # Show distributions
# validator.column_distributions()


[Epoch 1/10] [D loss: 0.9197] [G loss: 0.7146]
[Epoch 2/10] [D loss: 0.7481] [G loss: 0.6849]
[Epoch 3/10] [D loss: 0.6431] [G loss: 0.6649]
[Epoch 4/10] [D loss: 0.6524] [G loss: 0.6354]
[Epoch 5/10] [D loss: 0.7938] [G loss: 0.5918]
[Epoch 6/10] [D loss: 0.6121] [G loss: 0.5759]
[Epoch 7/10] [D loss: 0.5531] [G loss: 0.5398]
[Epoch 8/10] [D loss: 0.6234] [G loss: 0.5235]
[Epoch 9/10] [D loss: 0.6061] [G loss: 0.5169]
[Epoch 10/10] [D loss: 0.6435] [G loss: 0.5108]
=== Healthcare Data Validation Report ===


Sample of reversed synthetic data:
         CODE                                        ENCOUNTER  \
0  conditions   allergies_22bd855f-3897-4f05-a526-67ce6d05b755   
1  conditions  conditions_22bd855f-3897-4f05-a526-67ce6d05b755   
2  conditions   allergies_22bd855f-3897-4f05-a526-67ce6d05b755   
3  conditions   allergies_22bd855f-3897-4f05-a526-67ce6d05b755   
4  conditions   allergies_22bd855f-3897-4f05-a526-67ce6d05b755   

                                         DESCRIPTION 

In [23]:
data.data.head()

Unnamed: 0,CODE_medications,REASONCODE,CODE_conditions,CODE_allergies,ID_d3ae6cdf-6637-4386-98d3-217933da5606,BIRTHDATE_1945-04-27,DEATHDATE_1958-04-18,SSN_999-99-3396,DRIVERS_Unknown,PASSPORT_Unknown,...,START_conditions_1958-04-18,STOP_conditions_1970-01-01,ENCOUNTER_conditions_22bd855f-3897-4f05-a526-67ce6d05b755,DESCRIPTION_conditions_Coronary Heart Disease,DESCRIPTION_conditions_History of myocardial infarction (situation),DESCRIPTION_conditions_Myocardial Infarction,START_allergies_1946-11-04,STOP_allergies_1970-01-01,ENCOUNTER_allergies_22bd855f-3897-4f05-a526-67ce6d05b755,DESCRIPTION_allergies_Allergy to nut
0,-0.666667,0.988648,-0.577973,0,0,0,0,0,0,0,...,-1.333333,0,0,1.333333,-0.666667,-0.666667,0,0,0,0
1,-0.666667,0.988648,-0.751587,0,0,0,0,0,0,0,...,0.666667,0,0,-0.666667,-0.666667,1.333333,0,0,0,0
2,-0.666667,0.988648,1.32956,0,0,0,0,0,0,0,...,0.666667,0,0,-0.666667,1.333333,-0.666667,0,0,0,0
3,1.333333,0.280443,-0.577973,0,0,0,0,0,0,0,...,-1.333333,0,0,1.333333,-0.666667,-0.666667,0,0,0,0
4,1.333333,0.280443,-0.751587,0,0,0,0,0,0,0,...,0.666667,0,0,-0.666667,-0.666667,1.333333,0,0,0,0


In [18]:
reversed_data.head()

Unnamed: 0,CODE,ENCOUNTER,DESCRIPTION,START,STOP,REASONDESCRIPTION,REASONCODE
0,conditions,allergies_22bd855f-3897-4f05-a526-67ce6d05b755,medications_Nitroglycerin 0.4 MG/ACTUAT [Nitro...,conditions_1950-10-08,1970-01-01,Streptococcal sore throat (disorder),-0.620452
1,conditions,conditions_22bd855f-3897-4f05-a526-67ce6d05b755,Myocardial Infarction,conditions_1950-10-08,1970-01-01,Streptococcal sore throat (disorder),0.158414
2,conditions,allergies_22bd855f-3897-4f05-a526-67ce6d05b755,medications_Nitroglycerin 0.4 MG/ACTUAT [Nitro...,conditions_1950-10-08,1970-01-01,Streptococcal sore throat (disorder),-0.064614
3,conditions,allergies_22bd855f-3897-4f05-a526-67ce6d05b755,medications_Nitroglycerin 0.4 MG/ACTUAT [Nitro...,conditions_1950-10-08,1970-01-01,Streptococcal sore throat (disorder),-0.133546
4,conditions,allergies_22bd855f-3897-4f05-a526-67ce6d05b755,Myocardial Infarction,conditions_1950-10-08,1970-01-01,Streptococcal sore throat (disorder),-0.353891


In [19]:
synthetic_data.head()

Unnamed: 0,CODE_medications,REASONCODE,CODE_conditions,CODE_allergies,ID_d3ae6cdf-6637-4386-98d3-217933da5606,BIRTHDATE_1945-04-27,DEATHDATE_1958-04-18,SSN_999-99-3396,DRIVERS_Unknown,PASSPORT_Unknown,...,START_conditions_1958-04-18,STOP_conditions_1970-01-01,ENCOUNTER_conditions_22bd855f-3897-4f05-a526-67ce6d05b755,DESCRIPTION_conditions_Coronary Heart Disease,DESCRIPTION_conditions_History of myocardial infarction (situation),DESCRIPTION_conditions_Myocardial Infarction,START_allergies_1946-11-04,STOP_allergies_1970-01-01,ENCOUNTER_allergies_22bd855f-3897-4f05-a526-67ce6d05b755,DESCRIPTION_allergies_Allergy to nut
0,0.433044,-0.620452,0.698137,-0.612222,-0.501839,0.556021,-0.563152,-0.587445,0.606779,-0.209792,...,-0.764753,-0.513152,0.6484,-0.518882,-0.523514,0.667271,-0.17651,0.23706,0.717101,0.38398
1,0.374203,0.158414,0.647591,-0.149268,-0.4262,0.507329,-0.210579,-0.513919,0.546388,-0.20899,...,-0.622305,-0.587035,0.367713,-0.370619,-0.538767,0.515501,0.107849,0.216738,0.197382,0.383246
2,0.011909,-0.064614,0.766207,-0.765667,-0.411061,0.632425,-0.348173,-0.413556,0.680989,-0.273299,...,-0.642442,-0.584715,0.482374,-0.410582,-0.555542,0.332694,-0.041504,0.396341,0.530132,0.21054
3,0.162046,-0.133546,0.538093,-0.616856,-0.556826,0.074198,-0.326798,-0.378714,0.312878,-0.371086,...,-0.244394,-0.611244,0.101555,-0.629571,-0.325339,0.377207,0.20247,0.222506,0.247876,0.331277
4,0.366202,-0.353891,0.753415,-0.500958,-0.226927,0.343386,-0.166499,-0.578429,0.366865,-0.364635,...,-0.350285,-0.40074,0.017225,-0.598301,-0.365458,0.382079,-0.0205,0.348594,0.619125,0.174156
