<a href="https://colab.research.google.com/github/amazing-lucky/HAB_Detection/blob/main/Synthethic_HAB_GAN_Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive when running inside Colab. Outside Colab the
# `google.colab` package is absent; skip the mount instead of failing the
# whole notebook on Restart & Run All.
# NOTE(review): no visible cell reads from or writes to /content/drive (the
# CSV below is written to the current working directory) — confirm whether
# the mount is still needed.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Seed the NumPy global RNG (used to sample the seed dataset) and the torch
# RNG (used for latent noise and DataLoader shuffling) so a fresh
# Restart & Run All reproduces the same results.
np.random.seed(42)
torch.manual_seed(42)

# (min, max) bounds per feature. The insertion order of this mapping fixes
# the column order of every DataFrame / tensor derived from it.
param_ranges = dict([
    ('Bloom_Index', (-1.0, 1.0)),
    ('Rolling_Chlorophyll_Anomaly', (-2.0, 5.0)),
    ('Rolling_SST_Anomaly', (-3.0, 3.0)),
    ('Surface_Chlorophyll', (0.0, 10.0)),
    ('Sea_Surface_Temperature', (20.0, 35.0)),
    ('Dissolved_Oxygen', (0.0, 12.0)),
    ('pH', (6.5, 9.5)),
    ('Total_Nitrogen', (0.0, 10.0)),
    ('Total_Phosphorus', (0.0, 1.0)),
])

# Feature count: the generator's output width and the discriminator's input width
num_params = len(param_ranges)

# Define Generator
class Generator(nn.Module):
    """MLP mapping latent noise vectors to normalized feature vectors.

    Args:
        latent_dim: width of the input noise vector. Defaults to 100, matching
            the ``torch.randn(batch, 100)`` noise drawn in this notebook.
        out_dim: number of output features. Defaults to the module-level
            ``num_params`` (one output per entry of ``param_ranges``).

    The final ``Tanh`` bounds every output to [-1, 1], the same range that
    ``normalize_data`` maps the real data into.
    """

    def __init__(self, latent_dim=100, out_dim=None):
        super().__init__()
        if out_dim is None:
            # Resolved at construction time so callers may override it.
            out_dim = num_params
        self.latent_dim = latent_dim
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch, latent_dim) to (batch, out_dim)."""
        return self.model(z)

# Define Discriminator
class Discriminator(nn.Module):
    """MLP scoring a feature vector with the probability that it is real.

    Args:
        in_dim: number of input features. Defaults to the module-level
            ``num_params`` (one input per entry of ``param_ranges``).

    The final ``Sigmoid`` yields one probability in (0, 1) per row, which is
    what ``nn.BCELoss`` (the criterion used in this notebook) expects.
    """

    def __init__(self, in_dim=None):
        super().__init__()
        if in_dim is None:
            # Resolved at construction time so callers may override it.
            in_dim = num_params
        self.model = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, in_dim) to probabilities of shape (batch, 1)."""
        return self.model(x)

# Create a small seed dataset with realistic correlations
def create_seed_dataset(n_samples=500, ranges=None, rng=None):
    """Build a small "real" dataset with hand-coded correlations.

    Each feature is first drawn uniformly from its (min, max) bounds. Then:

    1. Nitrogen and phosphorus raise surface chlorophyll, clipped to its bounds.
    2. The adjusted chlorophyll raises pH and lowers dissolved oxygen.
    3. Temperature above 20 further lowers dissolved oxygen.
    4. Every column is clipped back into its bounds.

    Args:
        n_samples: number of rows to generate.
        ranges: mapping column -> (min, max). Defaults to the module-level
            ``param_ranges``. Must contain every key used in the
            correlation step below.
        rng: optional ``numpy.random.Generator`` (or any object with a
            ``uniform(low, high, size)`` method). When omitted, the global
            ``np.random`` state seeded at the top of the notebook is used.

    Returns:
        pandas.DataFrame with one column per key of ``ranges``, in that order.
    """
    if ranges is None:
        ranges = param_ranges
    sampler = np.random if rng is None else rng

    # Generate base random values (same draw order as the keys of `ranges`)
    data = {
        param: sampler.uniform(min_val, max_val, n_samples)
        for param, (min_val, max_val) in ranges.items()
    }

    # Nutrients -> chlorophyll. Applied FIRST (and clipped) so that the pH and
    # dissolved-oxygen shifts below see the final, in-range chlorophyll.
    chl_min, chl_max = ranges['Surface_Chlorophyll']
    data['Surface_Chlorophyll'] = np.clip(
        data['Surface_Chlorophyll']
        + data['Total_Nitrogen'] * 0.2
        + data['Total_Phosphorus'] * 1.5,
        chl_min,
        chl_max,
    )

    # Chlorophyll -> pH and dissolved oxygen; warmer water -> less oxygen
    data['pH'] = data['pH'] + data['Surface_Chlorophyll'] * 0.05
    data['Dissolved_Oxygen'] = (
        data['Dissolved_Oxygen']
        - data['Surface_Chlorophyll'] * 0.2
        - (data['Sea_Surface_Temperature'] - 20) * 0.1
    )

    # Clip values to ensure they stay within ranges
    for param, (min_val, max_val) in ranges.items():
        data[param] = np.clip(data[param], min_val, max_val)

    return pd.DataFrame(data)

# Draw the 500-row correlated seed set that the GAN will learn to imitate
seed_df = create_seed_dataset(n_samples=500)

# Normalize data to [-1, 1] range for GAN training
def normalize_data(df, ranges=None):
    """Linearly rescale each bounded column from its (min, max) to [-1, 1].

    This matches the [-1, 1] output range of the generator's final ``Tanh``.

    Args:
        df: DataFrame containing every column named in ``ranges``. Other
            columns are copied through unchanged.
        ranges: mapping column -> (min, max). Defaults to ``param_ranges``.

    Returns:
        A new DataFrame; ``df`` itself is not modified.

    Raises:
        ValueError: if any range has ``max == min``. That case would otherwise
            silently produce NaN/inf values.
    """
    if ranges is None:
        ranges = param_ranges
    normalized_df = df.copy()
    for param, (min_val, max_val) in ranges.items():
        span = max_val - min_val
        if span == 0:
            raise ValueError(
                f"degenerate range for {param!r}: ({min_val}, {max_val})"
            )
        normalized_df[param] = 2 * (df[param] - min_val) / span - 1
    return normalized_df

# Denormalize data back to original ranges
def denormalize_data(df, ranges=None):
    """Map each bounded column from [-1, 1] back to its original (min, max).

    This is the inverse of ``normalize_data``.

    Args:
        df: DataFrame containing every column named in ``ranges``. Other
            columns are copied through unchanged.
        ranges: mapping column -> (min, max). Defaults to ``param_ranges``.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    """
    if ranges is None:
        ranges = param_ranges
    denormalized_df = df.copy()
    for param, (min_val, max_val) in ranges.items():
        denormalized_df[param] = (df[param] + 1) * (max_val - min_val) / 2 + min_val
    return denormalized_df

# Bring the seed data into the generator's [-1, 1] output space, then wrap
# it as a float32 tensor for PyTorch.
normalized_seed_df = normalize_data(seed_df)
real_data = torch.tensor(normalized_seed_df.to_numpy(), dtype=torch.float32)

# Mini-batches of 32 rows, shuffled on every pass over the data.
dataset = TensorDataset(real_data)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

# Build both networks. The generator is constructed first so that weight
# initialization consumes the seeded torch RNG in a fixed order.
generator, discriminator = Generator(), Discriminator()

# Identical Adam settings for both networks.
adam_settings = {'lr': 0.0002, 'betas': (0.5, 0.999)}
g_optimizer = optim.Adam(generator.parameters(), **adam_settings)
d_optimizer = optim.Adam(discriminator.parameters(), **adam_settings)

# Binary cross-entropy on the discriminator's sigmoid probabilities
criterion = nn.BCELoss()

# Training loop
# Alternating GAN updates: for every mini-batch, one discriminator step and
# then one generator step. With the 500-row seed set and batch size 32 this
# is 16 mini-batches per epoch.
num_epochs = 5000
print("Starting GAN training...")

for epoch in range(num_epochs):
    # Each item from this TensorDataset-backed DataLoader is a sequence holding
    # one tensor; data[0] is a (batch, num_params) slice of normalized seed data.
    for i, data in enumerate(dataloader):
        # Use the actual batch size: the last batch of an epoch is short
        # because 500 rows do not divide evenly by 32.
        batch_size = data[0].size(0)

        # Real data
        real_data_batch = data[0]
        real_labels = torch.ones(batch_size, 1)

        # Fake data: noise width 100 must match the generator's first Linear layer
        z = torch.randn(batch_size, 100)
        fake_data_batch = generator(z)
        fake_labels = torch.zeros(batch_size, 1)

        # Train discriminator
        # zero_grad also clears the gradients that the previous generator step's
        # backward pass accumulated in the discriminator's parameters.
        discriminator.zero_grad()

        # Real data loss
        real_outputs = discriminator(real_data_batch)
        d_loss_real = criterion(real_outputs, real_labels)

        # Fake data loss; detach() stops this loss from backpropagating into
        # the generator.
        fake_outputs = discriminator(fake_data_batch.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # Train generator: re-score the same (non-detached) fake batch with the
        # just-updated discriminator and label it "real". Minimizing this loss
        # pushes D(G(z)) toward 1 (the non-saturating generator loss).
        generator.zero_grad()
        fake_outputs = discriminator(fake_data_batch)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

    # Print progress. The losses shown come from the epoch's LAST mini-batch;
    # they are not epoch averages.
    # NOTE(review): in the recorded run d_loss falls toward 0 while g_loss
    # climbs to ~9. This suggests the discriminator is overpowering the
    # generator; consider label smoothing or fewer discriminator updates.
    if (epoch + 1) % 500 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")

print("GAN training complete!")

# Generate synthetic data
def generate_synthetic_data(num_samples=2000, apply_correlations=False):
    """Sample ``num_samples`` labelled rows from the trained global ``generator``.

    The generator was trained on ``seed_df``, which already has the
    hand-coded correlation shifts from ``create_seed_dataset``. Its samples
    therefore already reflect those relationships. Re-applying the same
    shifts would apply every correlation twice (for example, pH raised and
    dissolved oxygen lowered a second time), so that step is opt-in.

    Args:
        num_samples: number of synthetic rows to draw.
        apply_correlations: if True, additionally apply the hand-coded
            correlation shifts to the generator's denormalized output.

    Returns:
        DataFrame with one column per key of ``param_ranges`` (in original
        units, clipped to their bounds). It also has an integer
        ``HAB_Present`` column: 1 when all three bloom thresholds below are
        exceeded, 0 otherwise.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_samples, 100)
            samples = generator(z).numpy()
    finally:
        # Restore the caller's mode rather than leaving the model in eval().
        generator.train(was_training)

    # Convert to DataFrame and map from [-1, 1] back to original units
    synthetic_df = denormalize_data(
        pd.DataFrame(samples, columns=list(param_ranges))
    )

    if apply_correlations:
        synthetic_df['pH'] += synthetic_df['Surface_Chlorophyll'] * 0.05
        synthetic_df['Dissolved_Oxygen'] -= synthetic_df['Surface_Chlorophyll'] * 0.2
        synthetic_df['Dissolved_Oxygen'] -= (synthetic_df['Sea_Surface_Temperature'] - 20) * 0.1
        synthetic_df['Surface_Chlorophyll'] += synthetic_df['Total_Nitrogen'] * 0.2 + synthetic_df['Total_Phosphorus'] * 1.5

    # The Tanh output already lands inside each range after denormalization.
    # Clipping guards against floating-point overshoot and the optional
    # shifts above.
    for param, (min_val, max_val) in param_ranges.items():
        synthetic_df[param] = np.clip(synthetic_df[param], min_val, max_val)

    # HAB label: all three bloom indicators must exceed their thresholds
    synthetic_df['HAB_Present'] = (
        (synthetic_df['Bloom_Index'] > 0.2)
        & (synthetic_df['Rolling_Chlorophyll_Anomaly'] > 1.0)
        & (synthetic_df['Surface_Chlorophyll'] > 3.0)
    ).astype(int)

    return synthetic_df

# Draw 2000 labelled samples from the trained generator
synthetic_data = generate_synthetic_data(num_samples=2000)

# Per-column summary statistics
print("\nSynthetic Data Summary:")
summary_stats = synthetic_data.describe()
print(summary_stats)

# Class balance of the HAB label
hab_count = synthetic_data['HAB_Present'].value_counts()
no_hab, hab_present = hab_count.get(0, 0), hab_count.get(1, 0)
print(f"\nHAB Distribution:\nNo HAB: {no_hab}\nHAB Present: {hab_present}")

# Persist to CSV in the current working directory
output_path = 'synthetic_hab_data_gan.csv'
synthetic_data.to_csv(output_path, index=False)
print(f"\nSynthetic data saved to '{output_path}'")


Starting GAN training...
Epoch [500/5000], d_loss: 0.9936, g_loss: 1.5032
Epoch [1000/5000], d_loss: 0.2025, g_loss: 4.0647
Epoch [1500/5000], d_loss: 0.0996, g_loss: 6.7787
Epoch [2000/5000], d_loss: 0.1536, g_loss: 9.1018
Epoch [2500/5000], d_loss: 0.0045, g_loss: 8.9463
Epoch [3000/5000], d_loss: 0.2497, g_loss: 5.0255
Epoch [3500/5000], d_loss: 0.0529, g_loss: 4.1512
Epoch [4000/5000], d_loss: 0.0657, g_loss: 4.3284
Epoch [4500/5000], d_loss: 0.1239, g_loss: 4.1466
Epoch [5000/5000], d_loss: 0.4182, g_loss: 3.4691
GAN training complete!

Synthetic Data Summary:
       Bloom_Index  Rolling_Chlorophyll_Anomaly  Rolling_SST_Anomaly  \
count  2000.000000                  2000.000000          2000.000000   
mean     -0.067774                     1.456813            -0.722058   
std       0.392716                     1.157638             1.178126   
min      -0.876400                    -1.771426            -2.919351   
25%      -0.288655                     0.482426            -1.433621