In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
from torchsummary import summary
import torch.nn.functional as F

import gpytorch

#from cuml.ensemble import RandomForestRegressor as cuRF  # not working on windows

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, DotProduct, WhiteKernel
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.linear_model import BayesianRidge
#from sklearn.metrics import r2_score  # Use a torch version instead.

from scipy.stats import norm
from dataloader import load_gdsc, prepare_features, split_data
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

import random

import os

import pandas as pd

import warnings
from sklearn.exceptions import ConvergenceWarning

In [2]:
config = {
    'Name':                'Yucheng Shao',

    'random_seed':         16,
    'PCA_component':       70,        # in the inspection file, aim for 90% variance

    'surrogate':           'bayes_linear',      # 'gp', 'rf', 'bayes_linear', 'mc_dropout'
                                      # for Gaussian Process, Random Forest, and Bayesian Ridge
    
    'training_iter':       20,        # GP cannot run efficient enough, train for some iterations
    'max_train_size':      1024,      # GP cannot run efficient enough, train from a subset instead
    
    'subsample_pool_size': 4096,      # GP cannot run efficient enough, draw from a subpool instead

    'n_estimators':        100,       # RF if not efficient enough, lower this value
    
    'mc_droupout_T':       20,        # number of dropout models used, reduce for efficiency
    
    'acquisition':         'metropolis',    # 'passive', 'greedy', 'mcmc', 'thompson', 'metropolis'
    'initial_ratio':       0.01,

    'mcmc_steps':          1000,      # number of steps to average for mcmc
    
    'metropolis_steps':    10000,     # number of steps to do metropolis select

    'val_ratio':           0.2,

    'batch_size':          2048,
    'epochs':              10,
    'lr':                  0.0012,
    'weight_decay':        1e-4,
    'dropout':             0.4
}

In [3]:
config['acquisition'] = config['acquisition'].lower()
valid_acq = {'passive', 'greedy', 'mcmc', 'thompson', 'metropolis'}
if config['acquisition'] not in valid_acq:
    raise ValueError(f"Unknown acquisition type '{config['acquisition']}'. Choose from {valid_acq}")

config['surrogate'] = config['surrogate'].lower()
valid_sur = {'gp', 'rf', 'bayes_linear', 'mc_dropout'}
if config['surrogate'] not in valid_sur:
    raise ValueError(f"Unknown surrogate type '{config['surrogate']}'. Choose from {valid_sur}")


assert config['subsample_pool_size'] > config['batch_size']

In [4]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Set seed once at the beginning
set_seed(config['random_seed'])

In [5]:
excluded_columns = ['LN_IC50', 'AUC', 'Z_SCORE', 'DRUG_ID', 'COSMIC_ID', 'DRUG_NAME', 'CELL_LINE_NAME']
df = load_gdsc(excluded_columns=excluded_columns)   # With Drop NaN & Exclude Outlier with IQR

# Create dummy variables for categorical features and split the data into training and testing sets with default test size of 0.2
X_dummy, y = prepare_features(df, encode_dummies=True)
X_label, _ = prepare_features(df, encode_dummies=False)

In [6]:
# incase the sparse input under-perform
X_dummy_pca = PCA(n_components=config['PCA_component']).fit_transform(X_dummy)
# Convert back to DataFrame
X_dummy = pd.DataFrame(X_dummy_pca, columns=[f'PC{i+1}' for i in range(X_dummy_pca.shape[1])])

In [7]:
Xd_tr, Xd_te, Xl_tr, Xl_te, y_tr, y_te = split_data(X_dummy, X_label, y) 
# Xd: X features with dummy variables
# Xl: X features with label encoding (e.g. Turn R, G, B into 0, 1, 2)

In [8]:
class GPyTorchGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPyTorchGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

def fit_gpytorch_and_predict(X_train, y_train, X_pool, seed, training_iter=20, 
                             max_train_size=config['max_train_size']):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if len(X_train) > max_train_size:
        np.random.seed(seed)
        sub_idx = np.random.choice(len(X_train), max_train_size, replace=False)
        X_train = X_train[sub_idx]
        y_train = y_train[sub_idx]

    X_train = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train = torch.tensor(y_train, dtype=torch.float32).view(-1).to(device)
    X_pool = torch.tensor(X_pool, dtype=torch.float32).to(device)

    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    model = GPyTorchGPModel(X_train, y_train, likelihood).to(device)

    model.train()
    likelihood.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for _ in range(training_iter):
        optimizer.zero_grad()
        output = model(X_train)
        loss = -mll(output, y_train)
        loss.backward()
        optimizer.step()

    model.eval()
    likelihood.eval()
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        preds = model(X_pool)
        mu = preds.mean.cpu().numpy()
        sigma = preds.stddev.cpu().numpy()

    return mu, sigma


def fit_rf_sklearn_and_predict(X_train, y_train, X_pool, seed, n_estimators=config['n_estimators'], 
                               max_train_size=config['max_train_size']):
    if len(X_train) > max_train_size:
        np.random.seed(seed)
        sub_idx = np.random.choice(len(X_train), max_train_size, replace=False).astype(int)
        X_train = X_train[sub_idx]
        y_train = y_train[sub_idx]

    rf = RandomForestRegressor(n_estimators=n_estimators, n_jobs=-1)
    rf.fit(X_train, y_train.ravel())

    # Collect predictions from all trees
    preds = np.stack([tree.predict(X_pool) for tree in rf.estimators_], axis=0)
    
    mu = preds.mean(axis=0)
    sigma = preds.std(axis=0)

    return mu, sigma


''' Does not work on windows!
def fit_cuml_rf_and_predict(X_train_np, y_train_np, X_pool_np, n_estimators=100):
    rf_model = cuRF(n_estimators=n_estimators)
    rf_model.fit(X_train_np, y_train_np)

    # predict from each tree for uncertainty estimation
    preds = np.stack([tree.predict(X_pool_np) for tree in rf_model.base_models_], axis=0)
    mu = preds.mean(axis=0)
    sigma = preds.std(axis=0)
    return mu, sigma
'''


def fit_bayes_linear_and_predict(X_train, y_train, X_pool, seed, 
                                 max_train_size=config['max_train_size']):
    if len(X_train) > max_train_size:
        np.random.seed(seed)
        sub_idx = np.random.choice(len(X_train), max_train_size, replace=False).astype(int)
        X_train = X_train[sub_idx]
        y_train = y_train[sub_idx]
    
    model = BayesianRidge()
    model.fit(X_train, y_train.ravel())
    mu, sigma = model.predict(X_pool, return_std=True)
    return mu, sigma


def enable_dropout(nn_model):
    """Enable dropout during inference."""
    for m in nn_model.modules():
        if isinstance(m, nn.Dropout):
            m.train()


def fit_mc_dropout_and_predict(nn_model, X, T=config['mc_droupout_T']):
    nn_model.eval()
    enable_dropout(nn_model)

    device = next(nn_model.parameters()).device

    X = torch.tensor(X, dtype=torch.float32).to(device)

    preds = []

    with torch.no_grad():
        for _ in range(T):
            pred = nn_model(X)
            preds.append(pred.cpu().numpy())

    preds = np.stack(preds, axis=0)  # shape: (T, batch_size, 1)
    mu = preds.mean(axis=0).squeeze()   # mean prediction
    sigma = preds.std(axis=0).squeeze() # uncertainty (std)

    return mu, sigma

#mu, sigma = fit_gpytorch_and_predict(X_train, y_train, X_pool)

In [9]:
class RegressionDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32).view(-1, 1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


'''Define Acquisition Methods'''
# Passive acquisition (random sampling)
def select_next_batch_passive(X_pool, batch_size):
    selected = np.random.choice(len(X_pool), batch_size, replace=False)
    return selected
    

# Greedy acquisition
def select_next_batch_greedy(X_pool, y_pool, mu, sigma, batch_size):
    best_y = np.min(y_pool)
    improvement = best_y - mu
    Z = improvement / (sigma + 1e-8)
    ei = improvement * norm.cdf(Z) + sigma * norm.pdf(Z)
    selected = np.argsort(ei)[-batch_size:]
    return selected
    

# MCMC acquisition
def select_next_batch_mcmc(X_pool, mu, sigma, batch_size, steps=config['mcmc_steps']):
    scores = mu + np.random.normal(0, sigma, size=(steps, len(X_pool)))   # sample from N steps times
    selected = np.argsort(np.mean(scores, axis=0))[-batch_size:]
    return selected


# Thompson Sampling
def select_next_batch_thompson(X_pool, mu, sigma, batch_size):
    samples = np.random.normal(mu, sigma)
    selected = np.argsort(samples)[-batch_size:]
    return selected
    

# Metropolis-Hastings acquisition
def select_next_batch_metropolis(X_pool, mu, sigma, batch_size, steps=config['metropolis_steps']):
    n = len(mu)
    selected = set()
    available_indices = set(range(n))

    # Initialize with a random index
    current_idx = np.random.choice(list(available_indices))
    selected.add(current_idx)
    available_indices.remove(current_idx)

    for _ in range(steps):
        if len(selected) >= batch_size:
            break
        if not available_indices:
            break  # pool exhausted

        candidate_idx = np.random.choice(list(available_indices))

        sample_current = np.random.normal(mu[current_idx], sigma[current_idx])
        sample_candidate = np.random.normal(mu[candidate_idx], sigma[candidate_idx])

        # Favor candidates with lower sample values
        p_accept = min(1, (sample_current + 1e-6) / (sample_candidate + 1e-6))

        if np.random.rand() < p_accept:
            selected.add(candidate_idx)
            available_indices.remove(candidate_idx)
            current_idx = candidate_idx  # move to accepted

    return np.array(list(selected))


'''Define Learning Loop'''
def gp_active_learning_loop(X, y, seed, nn_model, initial_ratio=0.01, batch_size=256, acquisition=config['acquisition']):
    n_samples = X.shape[0]
    initial_size = int(initial_ratio * n_samples)
    all_indices = np.arange(n_samples)
    np.random.shuffle(all_indices)
    selected_indices = list(all_indices[:initial_size])
    remaining_indices = all_indices[initial_size:]

    total_batches = len(remaining_indices) // batch_size

    with tqdm(total=total_batches, desc=f"Active Learning ({acquisition})") as pbar:
    #for _ in range(total_batches):
        while len(remaining_indices) >= batch_size:
            subsample_size = min(config.get('subsample_pool_size', 10000), len(remaining_indices))
            sub_pool_rel_indices = np.random.choice(len(remaining_indices), subsample_size, replace=False)
            sub_pool_abs_indices = remaining_indices[sub_pool_rel_indices]

            X_pool = X[sub_pool_abs_indices]
            y_pool = y[sub_pool_abs_indices]
            if len(remaining_indices) < batch_size:
                break

            X_train, y_train = X[selected_indices], y[selected_indices]
            # define GP model here for acquisiton methods
            if acquisition != 'passive':
                if config['surrogate'] == 'gp':
                    mu, sigma = fit_gpytorch_and_predict(X_train, y_train, X_pool, seed,
                                                         training_iter=config['training_iter'],
                                                         max_train_size=config['max_train_size'])

                elif config['surrogate'] == 'rf':
                    mu, sigma = fit_rf_sklearn_and_predict(X_train, y_train, X_pool, seed,
                                                           n_estimators=config['n_estimators'],
                                                           max_train_size=config['max_train_size'])

                elif config['surrogate'] == 'bayes_linear':
                    mu, sigma = fit_bayes_linear_and_predict(X_train, y_train, X_pool, seed,
                                                             max_train_size=config['max_train_size'])

                elif config['surrogate'] == 'mc_dropout':  # must use X_pool here
                    mu, sigma = fit_mc_dropout_and_predict(nn_model, X_pool, T=config['mc_droupout_T'])

                else:
                    raise ValueError(f"Unknown surrogate model: {config['surrogate']}")


            if acquisition == 'passive':
                rel_indices = select_next_batch_passive(X_pool, batch_size)  # no need gp_model
            elif acquisition == 'greedy':
                rel_indices = select_next_batch_greedy(X_pool, y_pool, mu, sigma, batch_size)
            elif acquisition == 'mcmc':
                rel_indices = select_next_batch_mcmc(X_pool, mu, sigma, batch_size)
            elif acquisition == 'thompson':
                rel_indices = select_next_batch_thompson(X_pool, mu, sigma, batch_size)
            elif acquisition == 'metropolis':
                rel_indices = select_next_batch_metropolis(X_pool, mu, sigma, batch_size)
                
            else:
                raise ValueError('unknown acquisiton type')
            #print(len(rel_indices))
            if len(rel_indices) < config['batch_size']:
                print(f"warning, undersampling detected: {len(rel_indices)} number sampled")

            #print("X_pool.shape:", X_pool.shape)
            #print("mu.shape:", mu.shape)
            #print("sigma.shape:", sigma.shape)
            #print("rel_indices.max():", rel_indices.max())
        
            abs_indices = sub_pool_abs_indices[rel_indices]

            selected_indices.extend(abs_indices)
            remaining_indices = np.setdiff1d(remaining_indices, abs_indices)

            pbar.update(1)

    X_final = X[selected_indices]
    y_final = y[selected_indices]
    dataset = RegressionDataset(X_final, y_final)

    # we do not want to shuffle in this case since we just made and ordered the batches based on acquisition
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [10]:
# Simple sparse-aware MLP
class SparseRegressor(nn.Module):
    def __init__(self, input_dim, dropout_rate=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.net(x)

In [11]:
# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Convert to tensors and normalize
Xd_tr_tensor = torch.tensor(Xd_tr.to_numpy(), dtype=torch.float32)
Xd_te_tensor = torch.tensor(Xd_te.to_numpy(), dtype=torch.float32)
y_tr_tensor = torch.tensor(y_tr.to_numpy(), dtype=torch.float32).unsqueeze(1)
y_te_tensor = torch.tensor(y_te.to_numpy(), dtype=torch.float32).unsqueeze(1)

scaler = StandardScaler()
Xd_tr_tensor = torch.tensor(scaler.fit_transform(Xd_tr_tensor), dtype=torch.float32)
Xd_te_tensor = torch.tensor(scaler.transform(Xd_te_tensor), dtype=torch.float32)

# Dataset and DataLoader
train_dataset = TensorDataset(Xd_tr_tensor, y_tr_tensor)
test_dataset = TensorDataset(Xd_te_tensor, y_te_tensor)

# Split into train and val
val_ratio = config['val_ratio']
val_size = int(len(train_dataset) * val_ratio)
train_size = len(train_dataset) - val_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Unpack tensors from train_dataset (already normalized + tensorized)
train_X_tensor, train_y_tensor = zip(*train_dataset)

# Stack and convert to NumPy arrays for GP input
Xd_tr_np = torch.stack(train_X_tensor).cpu().numpy()
y_tr_np = torch.stack(train_y_tensor).squeeze().cpu().numpy()

print(Xd_tr_np.shape)

#train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
# Do not shuffle, the randomness came from inside the loop.
# Need to reload the model each epoch so the batch order is not part of the learning.
'''train_loader = gp_active_learning_loop(Xd_tr_np, y_tr_np,  # ← Xd_tr and y_tr as NumPy arrays
                                       config['random_seed'],
                                       initial_ratio=config['initial_ratio'],
                                       batch_size=config['batch_size'],
                                       acquisition=config['acquisition'])'''

val_loader   = DataLoader(val_dataset, batch_size=config['batch_size'])
test_loader  = DataLoader(test_dataset, batch_size=config['batch_size'])

(106652, 70)


In [12]:
'''
ConvergenceWarning: The optimal value found for dimension 0 of parameter length_scale is close to the specified lower bound 1e-10. 
Decreasing the bound and calling fit again may find a better value.
'''



In [13]:
# Model setup
input_dim = Xd_tr.shape[1]
nn_model = SparseRegressor(input_dim, dropout_rate=config['dropout']).to(device)
optimizer = torch.optim.Adam(nn_model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
criterion = nn.MSELoss()

summary(nn_model, input_size=(input_dim,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 256]          18,176
       BatchNorm1d-2                  [-1, 256]             512
              GELU-3                  [-1, 256]               0
           Dropout-4                  [-1, 256]               0
            Linear-5                  [-1, 128]          32,896
       BatchNorm1d-6                  [-1, 128]             256
              GELU-7                  [-1, 128]               0
           Dropout-8                  [-1, 128]               0
            Linear-9                   [-1, 64]           8,256
             GELU-10                   [-1, 64]               0
           Linear-11                    [-1, 1]              65
Total params: 60,161
Trainable params: 60,161
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/ba

In [14]:
def compute_rmse(pred, true):  # do not use this since loss is already MSE
    return torch.sqrt(F.mse_loss(pred, true)).item()


def compute_r2(pred, true):
    ss_res = torch.sum((true - pred) ** 2)
    ss_tot = torch.sum((true - torch.mean(true)) ** 2)
    r2 = 1 - ss_res / ss_tot
    return r2.item()
    

def train_model(nn_model, val_loader, test_loader, epochs=30):
    
    train_losses, val_losses, test_losses = [], [], []
    train_r2s, val_r2s, test_r2s = [], [], []

    for epoch in range(epochs):

        seed = config['random_seed'] + epoch  # different seed per epoch

        train_loader = gp_active_learning_loop(
            Xd_tr_np, y_tr_np,
            seed,
            nn_model,
            initial_ratio=config['initial_ratio'],
            batch_size=config['batch_size'],
            acquisition=config['acquisition']
        )
        
        train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training", leave=False)
        for x_batch, y_batch in train_loop:
            nn_model.train()
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            preds = nn_model(x_batch)
            loss = criterion(preds, y_batch)
            loss.backward()
            optimizer.step()

            r2 = compute_r2(preds, y_batch)
            train_losses.append(loss.item())
            train_r2s.append(r2)
            train_loop.set_postfix(loss=loss.item(), r2=r2)

            # Validation (per batch)
            nn_model.eval()
            #val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Validation", leave=False)
            total_val_loss = 0.0
            total_val_r2 = 0.0
            with torch.no_grad():
                #for x_val, y_val in val_loop:
                for x_val, y_val in val_loader:
                    x_val, y_val = x_val.to(device), y_val.to(device)
                    preds = nn_model(x_val)
                    loss = criterion(preds, y_val)
                    r2 = compute_r2(preds, y_val)
                    total_val_loss += loss.item() * x_val.size(0)
                    total_val_r2 += r2 * x_val.size(0)

            avg_val_loss = total_val_loss / len(val_loader.dataset)
            avg_val_r2 = total_val_r2 / len(val_loader.dataset)
            val_losses.append(avg_val_loss)
            val_r2s.append(avg_val_r2)
            #val_loop.set_postfix(loss=loss.item(), r2=r2)

            # Test (per batch)
            #test_loop = tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs} - Test", leave=False)
            total_test_loss = 0.0
            total_test_r2 = 0.0
            with torch.no_grad():
                #for x_test, y_test in test_loop:
                for x_test, y_test in test_loader:
                    x_test, y_test = x_test.to(device), y_test.to(device)
                    preds = nn_model(x_test)
                    loss = criterion(preds, y_test)
                    r2 = compute_r2(preds, y_test)
                    total_test_loss += loss.item() * x_test.size(0)
                    total_test_r2 += r2 * x_test.size(0)

            avg_test_loss = total_test_loss / len(test_loader.dataset)
            avg_test_r2 = total_test_r2 / len(test_loader.dataset)
            test_losses.append(avg_test_loss)
            test_r2s.append(avg_test_r2)
            #test_loop.set_postfix(loss=loss.item(), r2=r2)

        print(f"Epoch {epoch+1:02d} | "
              f"Last Batch -> Train Loss: {train_losses[-1]:.4f}, R^2: {train_r2s[-1]:.4f} | "
              f"Val Loss: {val_losses[-1]:.4f}, R^2: {val_r2s[-1]:.4f} | "
              f"Test Loss: {test_losses[-1]:.4f}, R^2: {test_r2s[-1]:.4f}")

    return {
        "train_loss": train_losses, "val_loss": val_losses, "test_loss": test_losses,
        "train_r2": train_r2s, "val_r2": val_r2s, "test_r2": test_r2s
    }

In [None]:
results = train_model(nn_model, val_loader, test_loader, epochs=config['epochs'])

Active Learning (metropolis): 100%|████████████████████████████████████████████████████| 51/51 [01:13<00:00,  1.45s/it]
Epoch 1/10 - Training:  27%|█████████▉                           | 14/52 [00:23<00:58,  1.54s/it, loss=8.86, r2=-0.306]

In [None]:
# each epoch has 52 batches for batch size 2048

os.makedirs(f"results_{config['random_seed']}", exist_ok=True)

np.savez(f"results_{config['random_seed']}/history_{config['surrogate']}_{config['acquisition']}_rnd{config['random_seed']}.npz", **results)

In [None]:
torch.save(nn_model.state_dict(), f"results_{config['random_seed']}/model_{config['surrogate']}_{config['acquisition']}_rnd{config['random_seed']}.pt")

'''model = SparseRegressor(input_dim).to(device)
model.load_state_dict(torch.load(f"checkpoints/model_{config['acquisition']}_seed{config['random_seed']}.pt"))
model.eval()'''

In [None]:
loaded = np.load(f"results_{config['random_seed']}/history_{config['surrogate']}_{config['acquisition']}_rnd{config['random_seed']}.npz")
print(loaded["train_r2"].shape)

results = {k: list(loaded[k]) for k in loaded.files}

In [None]:
steps = list(range(1, len(results["train_loss"]) + 1))

# Get min values and corresponding steps
def get_min_metric(metric):
    value = min(metric)
    steps = metric.index(value) + 1
    return value, steps


# Get max values and corresponding steps
def get_max_metric(metric):
    value = max(metric)
    steps = metric.index(value) + 1
    return value, steps


min_train_loss, ep_train_loss = get_min_metric(results["train_loss"])
min_val_loss, ep_val_loss     = get_min_metric(results["val_loss"])
min_test_loss, ep_test_loss   = get_min_metric(results["test_loss"])

max_train_r2, ep_train_r2 = get_max_metric(results["train_r2"])
max_val_r2, ep_val_r2     = get_max_metric(results["val_r2"])
max_test_r2, ep_test_r2   = get_max_metric(results["test_r2"])

plt.figure(figsize=(16, 12))

# --- Loss Plot ---
plt.subplot(2, 1, 1)
plt.plot(steps, results["train_loss"], label="Train")
plt.plot(steps, results["val_loss"], label="Val")
plt.plot(steps, results["test_loss"], label="Test")

# Mark min points for Loss
plt.scatter(ep_train_loss, min_train_loss, color='green', label=f"Min Train: {min_train_loss:.4f}")
plt.text(ep_train_loss, min_train_loss, f"E{ep_train_loss}", color='green', va='top', ha='right')

plt.scatter(ep_val_loss, min_val_loss, color='red', label=f"Min Val: {min_val_loss:.4f}")
plt.text(ep_val_loss, min_val_loss, f"E{ep_val_loss}", color='red', va='top', ha='right')

plt.scatter(ep_test_loss, min_test_loss, color='blue', label=f"Min Test: {min_test_loss:.4f}")
plt.text(ep_test_loss, min_test_loss, f"E{ep_test_loss}", color='blue', va='top', ha='right')

plt.title(f"{config['surrogate']} {config['acquisition']} rnd {config['random_seed']} Loss (MSE)")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.legend()

# --- R^2 Plot ---
plt.subplot(2, 1, 2)
plt.plot(steps, results["train_r2"], label="Train")
plt.plot(steps, results["val_r2"], label="Val")
plt.plot(steps, results["test_r2"], label="Test")

# Mark max points for R^2
plt.scatter(ep_train_r2, max_train_r2, color='green', label=f"Max Train: {max_train_r2:.4f}")
plt.text(ep_train_r2, max_train_r2, f"E{ep_train_r2}", color='green', va='top', ha='right')

plt.scatter(ep_val_r2, max_val_r2, color='red', label=f"Max Val: {max_val_r2:.4f}")
plt.text(ep_val_r2, max_val_r2, f"E{ep_val_r2}", color='red', va='top', ha='right')

plt.scatter(ep_test_r2, max_test_r2, color='blue', label=f"Max Test: {max_test_r2:.4f}")
plt.text(ep_test_r2, max_test_r2, f"E{ep_test_r2}", color='blue', va='top', ha='right')

plt.title(f"{config['surrogate']} {config['acquisition']} rnd {config['random_seed']} R^2")
plt.xlabel("Step")
plt.ylabel("R^2")
plt.legend()

plt.tight_layout()
plt.show()