## Setup

In [None]:
# Utils
import os
from dateutil.relativedelta import relativedelta  # calendar-month arithmetic for the simulation windows

# Data
import numpy as np
import pandas as pd

# Viz
import seaborn as sns
import matplotlib.pyplot as plt

# Global seaborn theme applied to every figure drawn in this notebook
sns.set_theme(context='paper', style='ticks')

# JSON helpers from a `utils` module (presumably project-local; not shown here)
from utils import load_json, save_json

# --------------------------
# Helper Functions
# --------------------------

def round_to_first_of_next_month(dt):
    """
    Return midnight on the first day of the calendar month after `dt`.

    `dt` only needs `.year` and `.month` attributes (datetime.date,
    datetime.datetime, pandas.Timestamp, ...). The result is always in the
    *following* month, even when `dt` is already the first of a month, and
    December rolls over into January of the next year.

    Returns:
        pandas.Timestamp
    """
    # year * 12 + month is the 0-based index of the *next* month, so a single
    # divmod handles the December -> January rollover.
    next_year, next_month_zero_based = divmod(dt.year * 12 + dt.month, 12)
    return pd.to_datetime(f'{next_year}-{next_month_zero_based + 1:02d}-01')


def find_minimum_date_for_endpoint(df, endpoint, trend, ltv, htv, min_entries=100):
    """
    Return the date of the first qualifying measurement of `endpoint`.

    Each molecule (SMILES) is represented by its earliest row that actually
    has a value for `endpoint`; these records are ordered chronologically.
    The first `min_entries` records are skipped, and the date of the first
    remaining record whose value satisfies the optimum condition is returned.

    Parameters:
        df (pd.DataFrame): Contains 'SMILES', 'DATE', and endpoint columns.
        endpoint (str): Endpoint column name.
        trend (str): Optimum condition:
            'H' -> value >= ltv, 'L' -> value <= htv, 'V' -> ltv <= value <= htv.
        ltv (float): Lower threshold value.
        htv (float): Higher threshold value.
        min_entries (int): Number of molecules to skip before checking the
            condition (0 checks every molecule; negative values act as 0).

    Returns:
        pd.Timestamp or None: The qualifying date, or None when no molecule
        past the first `min_entries` satisfies the condition (also None for an
        unrecognized `trend`).
    """
    records = (
        df[['SMILES', 'DATE', endpoint]]
        # Drop unmeasured rows *before* de-duplicating: otherwise a molecule
        # whose earliest row lacks this endpoint would be discarded entirely,
        # even if it was measured later.
        .dropna()
        # Stable sort keeps ties on DATE in their original order (deterministic)
        .sort_values('DATE', kind='stable')
        .drop_duplicates('SMILES')
        .reset_index(drop=True)
    )
    candidates = records.iloc[max(min_entries, 0):]
    values = candidates[endpoint]

    if trend == 'H':
        in_range = values >= ltv
    elif trend == 'L':
        in_range = values <= htv
    elif trend == 'V':
        in_range = (values >= ltv) & (values <= htv)
    else:
        return None

    hits = candidates.loc[in_range, 'DATE']
    return hits.iloc[0] if not hits.empty else None


def get_bs(nb_available_to_select, batch_size=0.5):
    """
    Compute how many items to pick from a pool of available items.

    Parameters:
        nb_available_to_select (int): Size of the pool.
        batch_size (float or int): Interpreted in one of two modes:
            - fraction, when 0 <= batch_size < 1: round(batch_size * pool)
              using Python's round() (round-half-to-even), but never below 1
              for a non-empty pool (so batch_size=0 still picks one item);
            - absolute count, when batch_size >= 1: int(batch_size), capped
              at the pool size.

    Returns:
        int: Number of items to pick (0 for an empty pool in fraction mode).

    Raises:
        ValueError: If batch_size is not >= 0 (negative or NaN).
    """
    # Written as "not >= 0" so NaN is rejected as well
    if not batch_size >= 0:
        raise ValueError("batch_size must be >= 0")

    if batch_size >= 1:
        # Absolute mode: fixed count, never more than what is available
        return min(nb_available_to_select, int(batch_size))

    # Fractional mode
    if not nb_available_to_select > 0:
        return 0
    return max(1, round(batch_size * nb_available_to_select))

## Code

### Data loading

In [None]:
# Determine the dataset name from the current working directory.
# os.path.basename is separator-agnostic (unlike splitting on '/'), so this
# also works on Windows.
DATASET = os.path.basename(os.getcwd())

# Define base path for dataset
data_path = f'../../data/{DATASET}'


def _load_sorted_by_date(csv_path):
    """Read `csv_path`, parse its 'DATE' column and return it sorted chronologically.

    DATE is parsed *before* sorting so the order is chronological rather than
    lexicographic on the raw strings (the two only coincide for ISO dates).
    """
    frame = pd.read_csv(csv_path)
    frame['DATE'] = pd.to_datetime(frame['DATE'])
    return frame.sort_values(by='DATE', kind='stable').reset_index(drop=True)


# Load the main and aggregated data
data_df = _load_sorted_by_date(f'{data_path}/data.csv')
df = _load_sorted_by_date(f'{data_path}/data_aggregated.csv')

# Load blueprint data (columns read below: PROPERTIES, TREND, LTV, HTV, WEIGHT)
bp_df = pd.read_csv(f'{data_path}/blueprint.csv')

# Endpoint column names, as listed in the blueprint's PROPERTIES column
endpoints = bp_df['PROPERTIES'].values

# Load simulation configuration (shared by all datasets)
common_config = load_json('../../data/common/datasets_config.json')

# Display config info if available
if DATASET in common_config:
    cfg = common_config[DATASET]
    print(f"{DATASET} already in the configuration file, with the following parameters :")
    print(f"    dataset      : {DATASET}")
    print(f"    initial_date : {cfg['initial_date']}")
    print(f"    final_date   : {cfg['final_date']}")
    print(f"    timestep     : {cfg['timestep']}")

# ⚠️⚠️⚠️ If manual override is needed to determine minimum date, put it in the next line under the YYYY-MM-DD format, else None ⚠️⚠️⚠️
MINIMUM_DATE = None

# Plot distribution of experimental assays over time
fig, ax = plt.subplots(figsize=(9, 3))
sns.histplot(data_df['DATE'], bins=50, ax=ax)
ax.set_title(f"{DATASET} | Distribution of Experimental Assays")

if MINIMUM_DATE is not None:
    # Convert MINIMUM_DATE to a pandas Timestamp (later code calls .date() on it)
    MINIMUM_DATE = pd.to_datetime(MINIMUM_DATE)
    ax.axvline(x=MINIMUM_DATE, color='red', linestyle='--', label=f'Minimum date : {MINIMUM_DATE.date()}')
    ax.legend()
plt.tight_layout()

### Defining simulation parameters

In [None]:
# --------------------------
# Constants and Parameters
# --------------------------
TIMESTEP = 1   # in months
BATCH_SIZE = 0.5  # fraction (0 <= x < 1) or absolute count (>= 1), see get_bs
MIN_MOLECULES = 100  # minimum number of documented molecules required per endpoint

# --------------------------
# Compute Minimum Valid Dates for Each Endpoint
# --------------------------
minimum_dates = {}

for endpoint in endpoints:
    try:
        # Retrieve the blueprint row(s) for the endpoint
        endpoint_row = bp_df.loc[bp_df['PROPERTIES'] == endpoint]

        if endpoint_row.empty:
            print(f"Warning: No properties found for endpoint '{endpoint}' in bp_df. Skipping.")
            continue
        if len(endpoint_row) > 1:
            # .squeeze() would return a DataFrame here and make the trend
            # comparisons ambiguous; the first row is used explicitly instead.
            print(f"Warning: Multiple blueprint rows for endpoint '{endpoint}'. Using the first one.")

        props = endpoint_row[['TREND', 'LTV', 'HTV', 'WEIGHT']].iloc[0]
        trend = props['TREND']
        lower_acceptable_value = props['LTV']
        higher_acceptable_value = props['HTV']

        # First date, past the first MIN_MOLECULES documented molecules, with a
        # value in the optimum range (computed on the full data_df)
        min_date = find_minimum_date_for_endpoint(
            data_df, endpoint, trend, lower_acceptable_value, higher_acceptable_value,
            min_entries=MIN_MOLECULES,
        )

        if min_date is not None:
            minimum_dates[endpoint] = min_date
        else:
            print(f"Warning: Could not determine minimum date for endpoint '{endpoint}'")

    except Exception as e:
        # Best-effort: report the failing endpoint and continue with the others
        print(f"Error processing endpoint '{endpoint}': {str(e)}")

# Create a DataFrame from the dictionary for inspection if needed
minimum_dates_df = (
    pd.DataFrame.from_dict(minimum_dates, orient='index', columns=['MINIMUM_DATE'])
    .reset_index()
    .rename(columns={'index': 'ENDPOINT'})  # Using uppercase for consistency
)

# --------------------------
# Determine Simulation Dates
# --------------------------
# Choose the main endpoint: the one with the most non-null values in the aggregated data
main_endpoint = df[endpoints].count().idxmax()
print(f"Selected '{main_endpoint}' as the main endpoint based on data completeness")

if main_endpoint not in minimum_dates:
    raise ValueError(f"Main endpoint '{main_endpoint}' does not have a minimum date")

minimum_date_main = minimum_dates[main_endpoint]

# Start on the first day of the month following the main endpoint's minimum date
starting_date = round_to_first_of_next_month(minimum_date_main)

# The manual override can only push the start later. It is normalized locally so
# the .date() call below also works if MINIMUM_DATE is still a 'YYYY-MM-DD' string.
minimum_date_override = None if MINIMUM_DATE is None else pd.to_datetime(MINIMUM_DATE)
if minimum_date_override is not None and starting_date < minimum_date_override:
    print(f"⚠️ Starting date {starting_date.date()} is set to the minimum date {minimum_date_override.date()}")
    starting_date = minimum_date_override
else:
    print(f"Simulation starting date set to: {starting_date}")

# The simulation ends at the latest date in df.
ending_date = df['DATE'].max()
print(f"Simulation ending date set to: {ending_date.date()}")

# --------------------------
# Filter Endpoints by Pre-Simulation Data
# --------------------------
# Endpoints with at least MIN_MOLECULES *distinct* molecules (SMILES) documented
# before starting_date. Distinct SMILES are counted rather than rows (data_df may
# hold several rows per molecule), consistent with find_minimum_date_for_endpoint,
# which de-duplicates on SMILES.
pre_simulation_data = data_df[data_df['DATE'] < starting_date]
endpoint_counts = pd.Series(
    {
        ep: pre_simulation_data.loc[pre_simulation_data[ep].notna(), 'SMILES'].nunique()
        for ep in endpoints
    },
    dtype='int64',
)
endpoints_with_sufficient_data = endpoint_counts[endpoint_counts >= MIN_MOLECULES].index
nb_endpoints_with_sufficient_data = len(endpoints_with_sufficient_data)

In [None]:
# --------------------------
# Determine Endpoints with Compounds in Optimum Range
# --------------------------
# Pre-compute filtered dataframe once (more efficient)
pre_simulation_df = data_df[data_df['DATE'] < starting_date].copy()

valid_endpoints = []
for endpoint in endpoints_with_sufficient_data:
    try:
        endpoint_row = bp_df.loc[bp_df['PROPERTIES'] == endpoint]

        if endpoint_row.empty:
            print(f"Warning: No properties found for endpoint '{endpoint}' in bp_df. Skipping.")
            continue

        # First row taken explicitly: .squeeze() would yield a DataFrame for
        # duplicate blueprint rows and make the trend comparisons ambiguous.
        props = endpoint_row[['TREND', 'LTV', 'HTV', 'WEIGHT']].iloc[0]
        trend = props['TREND']
        lower_acceptable_value = props['LTV']  # Lower Acceptable Value
        higher_acceptable_value = props['HTV']  # Higher Acceptable Value

        # Check if at least one pre-simulation compound is in optimum range.
        # min_entries=0 so that *every* molecule is checked; the default (100)
        # would skip the first 100 molecules and only test later ones.
        if find_minimum_date_for_endpoint(pre_simulation_df, endpoint, trend,
                                          lower_acceptable_value, higher_acceptable_value,
                                          min_entries=0) is not None:
            valid_endpoints.append(endpoint)
    except Exception as e:
        # Best-effort: report the failing endpoint and continue with the others
        print(f"Error processing endpoint '{endpoint}': {str(e)}")

nb_valid_endpoints = len(valid_endpoints)
print(f"Found {nb_valid_endpoints} endpoints with at least one compound in optimum range")

# --------------------------
# Monthly Molecule Counts After Simulation Start
# --------------------------
# Number of rows of the aggregated data per calendar month ('ME' = month-end
# bins), restricted to dates on or after the simulation start
monthly_counts = (
    df.loc[df['DATE'] >= starting_date, ['SMILES', 'DATE']]
    .set_index('DATE')
    .resample('ME')
    .size()
)
print(f"Created monthly counts for {len(monthly_counts)} months after simulation start")

# --------------------------
# Compute Time Buckets and Prepare Simulation
# --------------------------
# Months are numbered as year * 12 + month so calendar differences are plain subtraction
start_month_num = starting_date.year * 12 + starting_date.month
end_month_num = ending_date.year * 12 + ending_date.month

# Number of calendar months covered by the simulation, counting BOTH the starting
# month and the (possibly partial) ending month. ending_date is the date of the
# latest compound, so leaving its month out would drop at least that compound
# from every bucket and from the simulation loop. Clamped at 0 in case the
# starting date falls after the last data point.
simulation_months = max(0, end_month_num - start_month_num + 1)

# Ceiling division: a trailing partial timestep still gets its own iteration
full_steps, leftover_months = divmod(simulation_months, TIMESTEP)
nb_iterations = full_steps + (1 if leftover_months else 0)

# Filter data for simulation period
simulation_df = df[df['DATE'] >= starting_date].copy()
nb_initial = df.shape[0] - simulation_df.shape[0]

# Create month-based integer representation for bucketing
simulation_df['year_month_int'] = simulation_df['DATE'].dt.year * 12 + simulation_df['DATE'].dt.month

# Create time buckets based on TIMESTEP (bucket 0 starts at the starting month)
simulation_df['time_bucket'] = (simulation_df['year_month_int'] - start_month_num) // TIMESTEP

# Count compounds per time bucket; every bucket exists, empty ones hold 0
full_buckets = range(nb_iterations)
timestep_counts = simulation_df.groupby('time_bucket').size().reindex(full_buckets, fill_value=0)

# Human-readable 'YYYY-MM' label of the first month of each bucket
labels = []
for bucket in timestep_counts.index:
    year, month_zero_based = divmod(start_month_num - 1 + bucket * TIMESTEP, 12)
    labels.append(f"{year}-{month_zero_based + 1:02d}")
timestep_counts.index = labels

# Calculate batch sizes for each time period
simulation_bs = [get_bs(count, batch_size=BATCH_SIZE) for count in timestep_counts]

# Sanity check: one count (and one batch size) per iteration
assert len(timestep_counts) == nb_iterations, f"Expected {nb_iterations} timesteps but got {len(timestep_counts)}"

# --------------------------
# Simulation: Tracking Compound Selection
# --------------------------
# Tag 0 = compound present before the simulation starts; the loop below
# overwrites it with the iteration whose time window contains the compound.
df['iteration'] = 0

iterations = []                  # 1-based iteration numbers
selected_at_iteration_list = []  # compounds picked at each iteration
already_selected_list = []       # running total of picked compounds
already_selected = 0

# Walk forward one half-open window [current_date, next_date) of TIMESTEP months at a time
current_date = starting_date
for iteration in range(1, nb_iterations + 1):
    iterations.append(iteration)
    next_date = current_date + relativedelta(months=TIMESTEP)

    in_window = df['DATE'].ge(current_date) & df['DATE'].lt(next_date)
    df.loc[in_window, 'iteration'] = iteration

    # Compounds released since the start, and how many of those are not yet picked
    released = df['DATE'].ge(starting_date) & df['DATE'].lt(next_date)
    available_since_start = int(released.sum())
    available_for_selection = available_since_start - already_selected

    # Per-iteration pick count, precomputed in simulation_bs
    selected = simulation_bs[iteration - 1]
    selected_at_iteration_list.append(selected)
    already_selected += selected
    already_selected_list.append(already_selected)

    current_date = next_date

# Final statistics. already_selected equals the last running total, or 0 when
# no iteration ran (i.e. when already_selected_list is empty).
nb_final_selected = already_selected
nb_initially = df.loc[df['DATE'] < starting_date, 'SMILES'].nunique()
nb_total_smiles = df['SMILES'].nunique()

print(f"Initially available compounds: {nb_initially}")
print(f"Total unique compounds: {nb_total_smiles}")
print(f"Final selected compounds: {nb_final_selected}")


In [None]:
# --------------------------
# Summary Output
# --------------------------
def safe_date_format(date_obj):
    """Return date_obj.date(), or "Not available" when date_obj is None/NaT."""
    if pd.isna(date_obj):
        return "Not available"
    return date_obj.date()


def _print_name_list(names, max_shown=10):
    """Print sorted `names`: all of them, or the first `max_shown` followed by '...'."""
    names = sorted(str(name) for name in names)
    if not names:
        return
    if len(names) <= max_shown:
        print(f"  List: {', '.join(names)}")
    else:
        print(f"  First {max_shown}: {', '.join(names[:max_shown])}...")


def _format_percent(part, whole):
    """Return part/whole as a rounded percentage string, or 'n/a' when whole is 0."""
    if not whole:
        return "n/a"
    return f"{round(100 * part / whole)}%"


try:
    # Create visual separators
    separator = "-" * 80

    # Project overview section
    print(separator)
    print("PROJECT OVERVIEW")
    print(separator)
    print(f"Number of molecules in project: {nb_total_smiles:,}")
    print(f"The main endpoint is: {main_endpoint}")
    # minimum_date_main is the first date, past the first 100 documented
    # molecules, at which a main-endpoint value satisfies the optimum condition.
    print(f"First date with a main-endpoint value in the optimum range "
          f"(after the first 100 documented molecules): {safe_date_format(minimum_date_main)}")

    # Simulation parameters
    print(f"\n{separator}")
    print("SIMULATION PARAMETERS")
    print(separator)
    print(f"Simulation starting date: {safe_date_format(starting_date)}")
    print(f"Simulation ending date: {safe_date_format(ending_date)}")
    print(f"Timestep: {TIMESTEP} month(s)")
    print(f"Batch size: {BATCH_SIZE}")

    # Pre-simulation statistics
    print(f"\n{separator}")
    print("PRE-SIMULATION STATISTICS")
    print(separator)
    print(f"Molecules available before simulation start: {nb_initially:,}")

    # Endpoints with sufficient data
    print(f"Endpoints with ≥100 molecules documented: {nb_endpoints_with_sufficient_data}")
    _print_name_list(endpoints_with_sufficient_data)

    # Endpoints with sufficient data and at least one molecule in the optimum range
    print(f"Endpoints with ≥100 molecules documented AND at least one in optimum range: "
          f"{nb_valid_endpoints}")
    _print_name_list(valid_endpoints)

    # Monthly statistics
    print(f"\n{separator}")
    print("MONTHLY MOLECULE STATISTICS")
    print(separator)

    if not monthly_counts.empty:
        print(f"Median molecules per month: {round(monthly_counts.median()):,}")
        print(f"Mean molecules per month: {round(monthly_counts.mean()):,}")
        # std() is NaN for a single month; round(NaN) would raise and abort the summary
        monthly_std = monthly_counts.std()
        if pd.notna(monthly_std):
            print(f"Standard deviation: {round(monthly_std):,}")
        else:
            print("Standard deviation: n/a (single month)")
    else:
        print("No monthly data available.")

    # Simulation results
    print(f"\n{separator}")
    print("SIMULATION RESULTS")
    print(separator)
    print(f"Number of iterations (TIMESTEP = {TIMESTEP} month(s)): {nb_iterations}")

    # Exploration percentages relative to all unique molecules ('n/a' if there are none)
    print(f"Percentage of molecules explored with batch size {BATCH_SIZE}:")
    print(f"  Initial:  {_format_percent(nb_initially, nb_total_smiles)}")
    print(f"  Selected: {_format_percent(nb_final_selected, nb_total_smiles)}")
    print(f"  Total:    {_format_percent(nb_final_selected + nb_initially, nb_total_smiles)}")

    print(separator)

except Exception as e:
    # Best-effort report: show what failed without stopping the notebook
    print(f"Error generating summary output: {str(e)}")
    import traceback
    traceback.print_exc()


In [None]:
# --------------------------
# Plotting Results
# --------------------------
# Two side-by-side panels of equal width
fig, ax = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: project timeline, with the simulated period shaded
sns.histplot(df['DATE'], bins=100, ax=ax[0], zorder=1)
ax[0].axvspan(starting_date, ending_date, color='green', alpha=0.33, label='Simulation period', zorder=0)

if 'ANN' in df.columns:
    # One dashed line per date on which each ANN value occurs. Colours cycle so
    # that more than three ANN values no longer raise an IndexError.
    ann_colors = ['red', 'green', 'blue']
    for i, ann in enumerate(df['ANN'].dropna().unique()):
        ann_dates = df.loc[df['ANN'] == ann, 'DATE'].unique()
        for j, ann_date in enumerate(ann_dates):
            ax[0].axvline(
                ann_date, linestyle='--', color=ann_colors[i % len(ann_colors)],
                # Only the first line of each ANN value gets a legend entry
                label=f'ANN : {ann}' if j == 0 else '_nolegend_',
                zorder=2,
            )
ax[0].legend()
ax[0].set_title(f"{DATASET} | Project timeline")

# Right panel: cumulative compounds per iteration.
# Compounds entering the project at each simulated iteration (iteration 0 = initial
# pool, excluded). Reindexing on `iterations` keeps the iterations in order, gives
# empty iterations a 0 instead of a gap, and avoids a KeyError when no compound is
# tagged with iteration 0.
compounds_per_iteration = df['iteration'].value_counts().reindex(iterations, fill_value=0)
sns.lineplot(x=iterations, y=compounds_per_iteration.cumsum().to_numpy(), linewidth=2,
             label='Cumsum compounds in the project (excluding the initial ones)', ax=ax[1])
sns.lineplot(x=iterations, y=already_selected_list, linewidth=2,
             label='Cumsum selected compounds in a simulation', ax=ax[1])

# NOTE(review): 24 looks like a per-iteration selection cap assumed elsewhere
# (e.g. by the downstream simulation) — TODO confirm and keep in sync.
MAX_SELECTED_PER_ITERATION = 24
sns.lineplot(x=iterations, y=np.array(iterations) * MAX_SELECTED_PER_ITERATION,
             linestyle='--', linewidth=2,
             label='Theoretical maximum cumulative sum of selected compound', ax=ax[1])

ax[1].set_title(f"{DATASET} : Cumulative sum of compounds (TIMESTEP = {TIMESTEP}, BATCH SIZE = {BATCH_SIZE})")

plt.tight_layout()

In [None]:
# Serialise the simulation window as 'YYYY-MM-DD' strings (time part dropped)
parameters = {
    'initial_date': str(starting_date).split(' ')[0],
    'final_date': str(ending_date).split(' ')[0],
    'timestep': TIMESTEP,
}
parameter_keys = ('initial_date', 'final_date', 'timestep')

print(f"🆕 Parameters :")
for key in parameter_keys:
    # Keys are left-aligned in a 12-character column, then " : value"
    print(f"    {key:<12} : {parameters[key]}")

# Warn when the stored configuration for this dataset differs from the new one
if DATASET in common_config:
    old_parameters = common_config[DATASET]

    if old_parameters != parameters:
        print(f"⚠️ {DATASET} already in the configuration file, with different parameters :")
        for key in parameter_keys:
            print(f"    {key:<12} : {common_config[DATASET][key]}")

        print(f"\n⚠️ If you run the following cell, it will save the new parameters ! ")


In [None]:
# Uncomment and run this cell to persist the parameters computed above:
# it replaces the DATASET entry of common_config and writes the whole
# configuration back with save_json.
# # Replace the config with newly defined parameters
# common_config[DATASET] = parameters
# save_json(common_config, f'../../data/common/datasets_config.json')