In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder

import warnings
warnings.filterwarnings("ignore")

torch.set_default_dtype(torch.float32)

In [2]:
class BrainCancerDataset(Dataset):
    def __init__(self, csv_file):
        # Load the data
        self.data = pd.read_csv(csv_file)
        
        # Drop rows with missing OS and the ID column
        self.data['OS'] = pd.to_numeric(self.data['OS'], errors='coerce')
        self.data = self.data.dropna(subset=['OS'])
        self.data = self.data.drop('ID', axis=1)
        
        self.column_order = self.data.columns
        
        # Separate columns into numeric and categorical
        #self.feature_columns = [col for col in self.data.columns if col not in ['status', 'OS']]
        self.feature_columns = [col for col in self.data.columns if col not in ['OS']]
        numeric_cols = self.data[self.feature_columns].select_dtypes(include=['number']).columns.tolist()
        categorical_cols = self.data[self.feature_columns].select_dtypes(include=['object']).columns.tolist()
        
        self.feature_columns = numeric_cols
        self.categorical_columns = categorical_cols
        
        # Initialize scalers and encoders as attributes
        self.scaler = MinMaxScaler()
        self.encoder = OneHotEncoder(sparse_output=False)
        
        # Normalize numeric features using MinMaxScaler
        self.features = pd.DataFrame(
            self.scaler.fit_transform(self.data[numeric_cols]),
            columns=numeric_cols
        ).astype(np.float32)
        
        # Normalize the OS column using MinMaxScaler
        self.targets = pd.Series(
            self.scaler.fit_transform(self.data[['OS']]).flatten(),  # Normalize OS using scaler
            name='OS'
        ).astype(np.float32)
        
        # One-hot encode categorical features
        if categorical_cols:
            self.categorical_features = self.encoder.fit_transform(self.data[categorical_cols])
        else:
            self.categorical_features = np.zeros((len(self.data), 0))  # Handle cases with no categorical features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # Numeric features
        features = self.features.iloc[idx].to_numpy(dtype=np.float32)
        
        # Categorical features (one-hot encoded)
        categorical_features = self.categorical_features[idx]
        
        # Target (normalized OS)
        target = self.targets.iloc[idx]
        
        return features, categorical_features, target


In [3]:
class Generator(nn.Module):
    def __init__(self, input_dim, feature_dim, categorical_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim + categorical_dim + 1, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim),
            nn.Tanh()
        )

    def forward(self, noise, condition, categorical_features):
        condition = condition.view(-1, 1)  # Ensure condition has two dimensions
        x = torch.cat((noise, categorical_features, condition), dim=1).float()
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, feature_dim, categorical_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(feature_dim + categorical_dim + 1, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, features, condition, categorical_features):
        condition = condition.view(-1, 1)  # Ensure condition has two dimensions
        x = torch.cat((features, categorical_features, condition), dim=1).float()
        return self.model(x)

In [4]:
def train_cgan(generator, discriminator, dataloader, gtv_index, num_epochs=5000, lr=0.0002):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=lr)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for features, categorical_features, os_value in dataloader:
            batch_size = features.size(0)
            features, categorical_features, os_value = (
                features.to(device).float(),
                categorical_features.to(device).float(),
                os_value.to(device).float()
            )

            # Generate noise and fake features
            noise = torch.randn(batch_size, 10).to(device)
            fake_features = generator(noise, os_value.unsqueeze(1), categorical_features)

            # Labels for real and fake data
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # -----------------------
            # Update Discriminator
            # -----------------------
            optimizer_d.zero_grad()
            real_output = discriminator(features, os_value.unsqueeze(1), categorical_features)
            fake_output = discriminator(fake_features.detach(), os_value.unsqueeze(1), categorical_features)

            # Validity penalty for invalid GTV values in fake features
            validity_penalty_d = torch.mean(torch.relu(-fake_features[:, gtv_index]))

            real_loss = criterion(real_output, real_labels)
            fake_loss = criterion(fake_output, fake_labels)

            # Total discriminator loss
            d_loss = real_loss + fake_loss + validity_penalty_d
            d_loss.backward()  # Backpropagation for discriminator
            optimizer_d.step()

            # -----------------------
            # Update Generator
            # -----------------------
            optimizer_g.zero_grad()
            fake_features = generator(noise, os_value.unsqueeze(1), categorical_features)  # Recreate fake features
            fake_output = discriminator(fake_features, os_value.unsqueeze(1), categorical_features)

            # Validity penalty for invalid GTV values in fake features
            validity_penalty_g = torch.mean(torch.relu(-fake_features[:, gtv_index]))

            g_loss = criterion(fake_output, real_labels) + validity_penalty_g
            g_loss.backward(retain_graph=False)  # Backpropagation for generator
            optimizer_g.step()

        # Logging progress
        if epoch % 500 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    print("Training finished.")

    # Save the generator and discriminator weights
    torch.save(generator.state_dict(), 'Weights/generator_weights.pth')
    torch.save(discriminator.state_dict(), 'Weights/discriminator_weights.pth')


In [5]:
def train_cgan_balanced(generator, discriminator, dataloader, num_epochs=5000, lr=0.0002, alpha=0.1, patience=50, eval_interval=100):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=lr)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)

    best_balance_score = float('inf')
    no_improve_epochs = 0

    for epoch in range(num_epochs):
        epoch_d_loss = 0.0
        epoch_g_loss = 0.0
        num_batches = 0

        for features, categorical_features, os_value in dataloader:
            batch_size = features.size(0)
            features, categorical_features, os_value = (
                features.to(device).float(),
                categorical_features.to(device).float(),
                os_value.to(device).float()
            )

            # Generate fake features
            noise = torch.randn(batch_size, 10).to(device)
            fake_features = generator(noise, os_value.unsqueeze(1), categorical_features)

            # Real and fake labels
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # Discriminator Loss
            real_loss = criterion(discriminator(features, os_value.unsqueeze(1), categorical_features), real_labels)
            fake_loss = criterion(discriminator(fake_features.detach(), os_value.unsqueeze(1), categorical_features), fake_labels)

            # Validity penalty for invalid GTV values in fake features
            validity_penalty_d = torch.mean(torch.relu(-fake_features[:, 1]))

            d_loss = real_loss + fake_loss + validity_penalty_d

            optimizer_d.zero_grad()
            d_loss.backward(retain_graph=True)
            optimizer_d.step()

            # Validity penalty for invalid GTV values in fake features
            validity_penalty_g = torch.mean(torch.relu(-fake_features[:, 1]))

            # Generator Loss
            g_loss = criterion(discriminator(fake_features, os_value.unsqueeze(1), categorical_features), real_labels) + validity_penalty_g

            optimizer_g.zero_grad()
            g_loss.backward() 
            optimizer_g.step()

            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            num_batches += 1

        # Calculate average losses
        avg_d_loss = epoch_d_loss / num_batches
        avg_g_loss = epoch_g_loss / num_batches

        # Calculate balance score (difference between generator and discriminator loss)
        balance_score = abs(avg_d_loss - avg_g_loss)

        if epoch % 500 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] | D Loss: {avg_d_loss:.4f} | G Loss: {avg_g_loss:.4f} | Balance Score: {balance_score:.4f}")

        # Save models if balance improves
        if balance_score < best_balance_score:
            best_balance_score = balance_score
            no_improve_epochs = 0  # Reset patience counter
            torch.save(generator.state_dict(), 'Weights/generator_weights.pth')
            torch.save(discriminator.state_dict(), 'Weights/discriminator_weights.pth')
        else:
            no_improve_epochs += 1

        # Early stopping condition
        if no_improve_epochs >= patience:
            print(f"Early stopping at epoch {epoch}. Best Balance Score: {best_balance_score:.4f}")
            break

    print("Training finished.")

In [6]:
def validate_correlation(generator, dataset):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)

    original_features = []
    generated_features = []
    os_values = []

    for features, categorical_features, os_value in dataset:
        features, categorical_features, os_value = features.clone().detach().to(device), categorical_features.clone().detach().to(device), os_value.clone().detach().to(device)
        noise = torch.randn(features.size(0), 10).to(device)

        with torch.no_grad():
            fake_features = generator(noise, os_value.unsqueeze(1), categorical_features)

        generated_features.append(fake_features.cpu().numpy())
        original_features.append(features.cpu().numpy())
        os_values.append(os_value.cpu().numpy())

    original_features = np.concatenate(original_features)
    generated_features = np.concatenate(generated_features)
    os_values = np.concatenate(os_values)

    correlation_real = np.corrcoef(original_features[:, -1], os_values)[0, 1]
    correlation_fake = np.corrcoef(generated_features[:, -1], os_values)[0, 1]

    print(f"Real Data Correlation with OS: {correlation_real:.4f}")
    print(f"Synthetic Data Correlation with OS: {correlation_fake:.4f}")


In [7]:
class DataRecovery:
    def __init__(self, generator, dataset, encoder, scaler):
        self.generator = generator
        self.dataset = dataset
        self.encoder = encoder  # Fitted OneHotEncoder instance
        self.scaler = scaler    # Fitted MinMaxScaler instance
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.generator.to(self.device)

    def custom_round(self,number):

        base = (number // 10) * 10  # Nearest lower multiple of 10
        remainder = number % 10

        # Check the specific cases for midpoints
        if remainder == 1:  # e.g., 11
            return base  # Round down
        elif remainder == 6:  # e.g., 26
            return base + 10  # Round up
        else:
            # Standard rounding
            return base + (10 if remainder >= 5 else 0)


    def generate_rows(self, num_rows):
            self.generator.eval()
            all_generated = []

            for _ in range(num_rows):
                idx = np.random.randint(len(self.dataset))
                features, categorical_features, os_value = self.dataset[idx]

                features = torch.tensor(features, dtype=torch.float32).to(self.device)
                categorical_features = torch.tensor(categorical_features, dtype=torch.float32).to(self.device)
                os_value = torch.tensor(os_value, dtype=torch.float32).to(self.device)

                noise = torch.randn(1, 10).to(self.device)

                with torch.no_grad():
                    generated_features = self.generator(noise, os_value.view(1, -1), categorical_features.unsqueeze(0))

                generated_features_np = generated_features.cpu().numpy().flatten()
                categorical_features_np = categorical_features.cpu().numpy().reshape(1, -1)
                os_value_np = np.array([[os_value.cpu().numpy()]])

                # Reverse one-hot encoding
                original_categorical = self.encoder.inverse_transform(categorical_features_np)

                # Denormalize 'OS'
                original_os = self.scaler.inverse_transform(os_value_np)

                # Denormalize ['KI','GTV']
                status = generated_features_np[2]
                original_generated_features = self.scaler.inverse_transform(generated_features_np.reshape(1, -1))
                original_generated_features[0,2] = np.round(status,0)

                original_generated_features[0,0] = self.custom_round(original_generated_features[0,0])

                all_generated.append(
                    np.concatenate((
                        original_generated_features.flatten(),
                        original_categorical.flatten(),
                        original_os.flatten()
                    ))
                )

            self.generator.train()
            return np.array(all_generated)

    def save_generated(self, num_rows, column_order, filename="GeneratedData/Generated_Data_BrainCancer.csv"):
        generated_data = self.generate_rows(num_rows)

        if hasattr(self.dataset, "feature_columns") and hasattr(self.dataset, "categorical_columns"):
            columns = self.dataset.feature_columns + self.dataset.categorical_columns + ['OS']
        else:
            raise AttributeError("Dataset must have 'feature_columns' and 'categorical_columns' attributes.")

        if generated_data.shape[1] != len(columns):
            raise ValueError(
                f"Mismatch between data shape ({generated_data.shape[1]} columns) and column names ({len(columns)})."
            )

        df = pd.DataFrame(generated_data, columns=columns)
        df = df[column_order]
        df.to_csv(filename, index=False)
        print(f"Generated data saved to {filename}")

    def view_generated(self, num_rows):
        generated_data = self.generate_rows(num_rows)

        if hasattr(self.dataset, "feature_columns") and hasattr(self.dataset, "categorical_columns"):
            columns = self.dataset.feature_columns + self.dataset.categorical_columns + ['OS']
        else:
            raise AttributeError("Dataset must have 'feature_columns' and 'categorical_columns' attributes.")

        if generated_data.shape[1] != len(columns):
            raise ValueError(
                f"Mismatch between data shape ({generated_data.shape[1]} columns) and column names ({len(columns)})."
            )

        df = pd.DataFrame(generated_data, columns=columns)
        return df


In [8]:
csv_file = "Data/BrainCancer.csv"
dataset = BrainCancerDataset(csv_file)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

feature_dim = len(dataset[0][0])
categorical_dim = dataset.categorical_features.shape[1]

generator = Generator(input_dim=10, feature_dim=feature_dim, categorical_dim=categorical_dim)
discriminator = Discriminator(feature_dim=feature_dim, categorical_dim=categorical_dim)

In [9]:
# Train the model
#train_cgan(generator, discriminator, dataloader, gtv_index=1, num_epochs=10000)
train_cgan_balanced(generator, discriminator, dataloader,  num_epochs=100000, patience=10000)

Epoch [0/100000] | D Loss: 1.3863 | G Loss: 0.6713 | Balance Score: 0.7150
Epoch [500/100000] | D Loss: 1.3130 | G Loss: 0.7820 | Balance Score: 0.5310
Epoch [1000/100000] | D Loss: 1.1894 | G Loss: 0.9201 | Balance Score: 0.2693
Epoch [1500/100000] | D Loss: 1.4025 | G Loss: 0.7148 | Balance Score: 0.6877
Epoch [2000/100000] | D Loss: 1.4442 | G Loss: 0.7261 | Balance Score: 0.7180
Epoch [2500/100000] | D Loss: 1.2439 | G Loss: 0.8496 | Balance Score: 0.3943
Epoch [3000/100000] | D Loss: 1.2063 | G Loss: 0.9288 | Balance Score: 0.2775
Epoch [3500/100000] | D Loss: 1.2484 | G Loss: 0.8592 | Balance Score: 0.3892
Epoch [4000/100000] | D Loss: 1.1949 | G Loss: 0.9146 | Balance Score: 0.2802
Epoch [4500/100000] | D Loss: 1.1883 | G Loss: 0.7921 | Balance Score: 0.3962
Epoch [5000/100000] | D Loss: 1.2064 | G Loss: 0.9318 | Balance Score: 0.2747
Epoch [5500/100000] | D Loss: 1.1742 | G Loss: 0.9125 | Balance Score: 0.2617
Epoch [6000/100000] | D Loss: 1.1591 | G Loss: 0.8844 | Balance Scor

In [11]:
 # Load the saved weights (optional, if you want to reuse the models)
generator.load_state_dict(torch.load('best_generator_weights.pth'))
discriminator.load_state_dict(torch.load('best_discriminator_weights.pth'))

# Validate the model
validate_correlation(generator, dataloader)

Real Data Correlation with OS: -0.4022
Synthetic Data Correlation with OS: -0.4917


In [13]:
# Data recovery
data_recovery = DataRecovery(generator, dataset, dataset.encoder, dataset.scaler)

# Generate synthetic rows
num_rows = 50000
generated_rows = data_recovery.generate_rows(num_rows)
#print("Generated Rows:\n", generated_rows)

# Save synthetic rows to a CSV file
data_recovery.save_generated(num_rows=num_rows, column_order=dataset.column_order , filename="GeneratedData/Generated_Data_BrainCancer.csv")

# View synthetic rows as a pandas DataFrame
df = data_recovery.view_generated(num_rows=num_rows)
print("Generated DataFrame:\n", df.head())


Generated data saved to GeneratedData/generated_data.csv
Generated DataFrame:
      KI        GTV status     Sex   Diagnosis        Location  \
0  60.0  48.656315    1.0    male  meningioma  supratentorial   
1  60.0  23.922226    0.0    male   lg_glioma  supratentorial   
2  60.0  13.117501    0.0    male  meningioma  supratentorial   
3  50.0  18.215059    1.0  female   lg_glioma  supratentorial   
4  80.0  61.237827    0.0    male       other  supratentorial   

  Stereotactic methods         OS  
0                  srt  31.147541  
1                  srt  35.934425  
2                  srt  67.377045  
3                  srt   47.80328  
4                  srt  19.737705  
