In [5]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn

from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

In [4]:
df = pd.read_csv('/content/drive/MyDrive/PROJECTS & WORKS/other_projects/Follow-up_Records.csv')
print(df.head())

   patient_id  visit_date  age_years  weight_kg   bmi  systolic_bp_mmHg  \
0  P-2025-001  2024-02-15         52       83.7  28.3               138   
1  P-2025-001  2024-03-15         52       83.4  28.2               147   
2  P-2025-001  2024-04-15         52       83.1  28.1               140   
3  P-2025-001  2024-05-15         52       83.0  28.1               136   
4  P-2025-001  2024-06-15         52       82.6  27.9               133   

   diastolic_bp_mmHg  heart_rate_bpm  body_temp_C  fasting_glucose_mg_dL  ...  \
0                 86              80         36.8                    137  ...   
1                 89              80         37.0                    140  ...   
2                 84              76         36.8                    122  ...   
3                 88              77         36.8                    112  ...   
4                 88              78         36.8                    101  ...   

   diet_quality_score_0_100  sleep_hours  exercise_sessions_pe

In [7]:
num_cols = df.select_dtypes(include=['int64',  'float64']).columns
cat_cols = df.select_dtypes(include=['object']).columns

encoder = OneHotEncoder(sparse_output=False)
cat_encoded = encoder.fit_transform(df[cat_cols])

scaler = MinMaxScaler(feature_range=(-1, 1))
num_scaled = scaler.fit_transform(df[num_cols])

# combining processed data
processed_data = np.hstack([cat_encoded, num_scaled])

In [9]:
data_dim = processed_data.shape[1] # total features=60
# print(data_dim)
latent_dim = 64 # size of random noise innput

# generator
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()
    self.model = nn.Sequential(
        nn.Linear(latent_dim, 128),
        nn.ReLU(0.2),
        nn.Linear(128, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, data_dim),
        nn.Tanh() # output in range [-1, 1]
    )
  def forward(self, z):
    return self.model(z)

# discriminator
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    self.model = nn.Sequential(
        nn.Linear(data_dim, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 128),
        nn.LeakyReLU(0.2),
        nn.Linear(128, 1),
        nn.Sigmoid() # probability of real/fake
    )
  def forward(self, x):
    return self.model(x)

### Discriminator: Learns to classify real vs fake correctly
### Generator: Learns to fool the Discriminator

In [15]:
# converting dataset to PyTorch tensors
real_data = torch.tensor(processed_data, dtype=torch.float32)
dataset = TensorDataset(real_data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# initializing models
generator = Generator()
discriminator = Discriminator()

# loss function and optimizers
criterion = nn.BCELoss()

# optimizers
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# epochs
epochs = 2000

for epoch in range(epochs):
  for real_batch, in dataloader:
    batch_size = real_batch.size(0)

    # lables for real/fake data
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)

    # training the discriminator
    z = torch.randn(batch_size, latent_dim)
    fake_data = generator(z)
    real_loss = criterion(discriminator(real_batch), real_labels)
    fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
    discriminator_loss = (real_loss + fake_loss) / 2

    discriminator_optimizer.zero_grad()
    discriminator_loss.backward()
    discriminator_optimizer.step()

    # training the genrator
    z = torch.randn(batch_size, latent_dim)
    fake_data = generator(z)
    generator_loss = criterion(discriminator(fake_data), real_labels)  # want fake data to be real

    generator_optimizer.zero_grad()
    generator_loss.backward()
    generator_optimizer.step()

if (epoch+1) % 200 == 0:
  print(f'Epoch [{epoch+1}/{epochs}], Discriminator Loss: {discriminator_loss.item():.4f}, Generator Loss: {generator_loss.item():.4f}')

Epoch [2000/2000], Discriminator Loss: 0.1006, Generator Loss: 2.8885


### Generating synthetic medical records


In [18]:
# generating new synthetic data
z = torch.randn(20, latent_dim) # 10 samples
synthetic_data_scaled = generator(z).detach().numpy()

# inverse transform
num_synthetic = scaler.inverse_transform(synthetic_data_scaled[:, :len(num_cols)])
cat_synthetic = encoder.inverse_transform(synthetic_data_scaled[:, len(num_cols):])

# converting to dataframe
synthetic_df = pd.DataFrame(num_synthetic, columns=num_cols)
synthetic_df[cat_cols] = cat_synthetic
print(synthetic_df)

    age_years  weight_kg        bmi  systolic_bp_mmHg  diastolic_bp_mmHg  \
0   52.967430  82.233566  27.822952        129.303986          80.543808   
1   52.986538  82.319733  27.794527        127.744698          80.082924   
2   52.992363  82.275078  27.787752        126.686508          79.308777   
3   52.986526  82.247513  27.792105        126.332832          79.790604   
4   52.971813  82.172165  27.822163        128.537018          80.676384   
5   52.998501  82.400047  27.802914        124.919685          78.439056   
6   52.980053  82.139427  27.819618        128.095505          80.922668   
7   52.988945  81.980652  27.796934        126.394012          81.526253   
8   52.983162  82.272873  27.813095        128.301056          80.134224   
9   52.994514  82.314636  27.778120        126.963585          80.154869   
10  52.999603  82.415115  27.782461        123.896835          78.145203   
11  52.985085  82.083366  27.790266        127.575897          81.062492   
12  52.96972