<a href="https://colab.research.google.com/github/MZiaAfzal71/Average_Weighted_Path_Vector/blob/main/Data%20Files/Models/ChemBERTaFiLMFusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/MZiaAfzal71/Average_Weighted_Path_Vector.git
%cd Average_Weighted_Path_Vector/Data\ Files

In [None]:
!pip install osfclient
from osfclient.api import OSF
import os
from subprocess import run

# Replace with your OSF project ID
project_id = "p5ga2"   # e.g. from https://osf.io/abcd3/
osf = OSF()
project = osf.project(project_id)
store = project.storage("osfstorage")

# Locate the "Descriptors Data" folder among the storage's folders.
desc_folder = None
for fold in store.folders:
    if fold.path.strip("/") == "Descriptors Data":
        desc_folder = fold
        break

# Fail with a clear message instead of an AttributeError on `[].files`.
if desc_folder is None:
    raise FileNotFoundError(
        f"Folder 'Descriptors Data' was not found in OSF project {project_id!r}"
    )

# Download all files and keep folder structure
for f in desc_folder.files:
    local_path = f.path.strip("/")            # keep folders
    local_dir = os.path.dirname(local_path)   # extract dir
    if local_dir:
        os.makedirs(local_dir, exist_ok=True) # create dirs if missing
    with open(local_path, "wb") as out:
        f.write_to(out)
    if local_path.endswith(".zip"):
        # Pass an argument list (no shell): paths containing quotes or other
        # shell metacharacters are handed to unzip verbatim.
        result = run(["unzip", local_path, "-d", local_dir or "."])
        if result.returncode == 0:
            print(f"\nUnzipped {local_path} -> {local_dir}")
        else:
            # Previously a failed unzip was still reported as success.
            print(f"\nunzip failed for {local_path} (exit code {result.returncode})")
        continue
    print(f"Downloaded {f.path} -> {local_path}")

In [None]:
# ============================
# ChemBERTa + Descriptor Gated Training Pipeline
# ============================
from __future__ import annotations
import os, random, math, json, gc
from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Tuple

import csv
from pathlib import Path
import pandas as pd
import numpy as np

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

from tqdm.auto import tqdm

In [None]:
# ----------------------------
# Config
# ----------------------------
@dataclass
class Config:
    """Hyper-parameters and output locations for one training run of ChemBERTaFusion."""
    # Hugging Face identifier handed to AutoTokenizer / AutoModel.from_pretrained.
    model_name: str = "seyonec/ChemBERTa-zinc-base-v1"
    # Suffix of the per-run CSV log; the runner prefixes it with "<prop>_<desc>_".
    log_path: str = "training_log.csv"
    # Directory receiving scaler arrays, checkpoint, logs and prediction files.
    output_dir: str = "./chemberta_gated_out"
    # Checkpoint file name inside output_dir.
    # NOTE(review): the same name is used for every (property, descriptor) run.
    save_path: str = "best_model.pt"
    # Tokenizer truncation / padding length.
    max_length: int = 128
    batch_size: int = 16
    epochs: int = 5
    # Learning rate of the parameter group holding the backbone weights.
    lr_backbone: float = 1e-5
    # Learning rate of the parameter group holding all non-backbone weights.
    lr_heads: float = 1e-4
    # NOTE(review): not currently forwarded to the optimizer setup in this file.
    weight_decay: float = 0.01
    # Seed passed to set_seed() at the start of each run.
    seed: int = 42
    # Hidden width of the gate MLP.
    hidden_fuse: int = 512
    # Evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    proj_dim: int = 128     # descriptor projection dimension before fusion
    dropout: float = 0.1
    # Intended meaning: unfreeze last N transformer blocks; 0 = all frozen.
    # NOTE(review): the freezing code in ChemBERTaFusion.__init__ is commented
    # out, so this value is forwarded to the model but has no effect there.
    train_layers: int = 2
    # Divisor applied to the gate logits before the sigmoid.
    gate_temp: float = 1.0
    # Per-sample probability used by the train-time modality dropout.
    p_moddrop: float = 0.2
    # Fraction of total optimizer steps used for linear warm-up.
    warmup_ratio: float = 0.1
    # Max gradient norm for clipping; a falsy value disables clipping.
    grad_clip: float = 1.0
    # Weight of the two auxiliary (CLS-only, descriptor-only) MSE losses.
    lambda_aux: float = 0.2
    # Weight of the cosine-similarity term in the loss.
    lambda_div: float = 0.05
    # Forwarded to predict(): return a NumPy array instead of a torch.Tensor.
    return_numpy: bool = True

# ----------------------------
# Utils
# ----------------------------
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        # Seed the CUDA generators as well when a GPU is present.
        torch.cuda.manual_seed_all(seed)

def rmse(y_true, y_pred):
    """Return the root-mean-squared error of ``y_pred`` against ``y_true`` as a float."""
    mse = mean_squared_error(y_true, y_pred)
    root = np.sqrt(mse)
    return float(root)

def ensure_dir(path: str):
    """Create directory ``path`` (including parents) unless it already exists."""
    if os.path.isdir(path):
        return
    os.makedirs(path, exist_ok=True)

# ----------------------------
# Model (gated fusion)
# ----------------------------

class ChemBERTaFusion(nn.Module):
    """
    Regressor fusing a ChemBERTa embedding with a projected descriptor vector.

    Building blocks:
      - Modality dropout: at train time one branch may be zeroed per sample
      - FiLM: the descriptor branch yields (gamma, beta) that rescale and
        shift the CLS vector
      - Vector gate g = sigmoid(MLP([cls_mod, desc_h]) / gate_temp) that mixes
        the two branches element-wise
      - Auxiliary linear heads predicting from each branch on its own
      - A cosine-similarity term between the branches, added to the loss
    """

    def __init__(self, model_name: str, n_desc: int,
                 proj_dim: int = 128, hidden_fuse: int = 512,
                 dropout: float = 0.1, train_layers: int = 0,
                 gate_temp: float = 1.0,
                 p_moddrop: float = 0.2):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        width = self.backbone.config.hidden_size
        self.n_desc = n_desc
        self.gate_temp = gate_temp
        self.p_moddrop = p_moddrop

        # NOTE: `train_layers` is accepted but not used. The freeze/unfreeze
        # logic below is disabled, so the backbone's requires_grad flags are
        # left exactly as loaded.
        # for p in self.backbone.parameters():
        #     p.requires_grad = False
        # if train_layers and hasattr(self.backbone, "encoder"):
        #     for layer in self.backbone.encoder.layer[-train_layers:]:
        #         for p in layer.parameters():
        #             p.requires_grad = True

        # Descriptors: n_desc -> proj_dim -> backbone hidden size.
        projection_layers = [
            nn.Linear(n_desc, proj_dim),
            nn.LayerNorm(proj_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(proj_dim, width),
            nn.LayerNorm(width),
            nn.ReLU(),
        ]
        self.desc_proj = nn.Sequential(*projection_layers)

        # Gate network: concatenated [cls_mod, desc_h] -> per-feature logits.
        gate_layers = [
            nn.Linear(2 * width, hidden_fuse),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_fuse, width),
        ]
        self.gate_mlp = nn.Sequential(*gate_layers)

        # FiLM generator: desc_h -> concatenated (gamma, beta).
        self.film = nn.Sequential(nn.Linear(width, 2 * width))

        self.dropout = nn.Dropout(dropout)

        # Regression head applied to the fused representation.
        head_layers = [
            nn.Linear(width, width // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(width // 2, 1),
        ]
        self.head_main = nn.Sequential(*head_layers)

        # Single-branch heads used for the auxiliary losses.
        self.head_cls = nn.Linear(width, 1)
        self.head_desc = nn.Linear(width, 1)

    def _modality_dropout(self, cls, desc_h, batch_size):
        """Zero one branch per sample while training; otherwise pass both through.

        A single uniform draw u per sample decides: u < p drops the CLS
        branch, u > 1 - p drops the descriptor branch.
        """
        if not (self.training and self.p_moddrop > 0.0):
            return cls, desc_h
        u = torch.empty(batch_size, 1, device=cls.device).uniform_(0, 1)
        cls_dropped = (u < self.p_moddrop).float()
        desc_dropped = (u > 1 - self.p_moddrop).float()
        return cls * (1.0 - cls_dropped), desc_h * (1.0 - desc_dropped)

    @staticmethod
    def _loss_and_diag(y_main, y_cls, y_desc, cls_mod, desc_h, g,
                       targets, lambda_aux, lambda_div):
        """Combine main MSE, auxiliary MSEs and the cosine term; collect scalars."""
        targets = targets.float()
        l_main = F.mse_loss(y_main, targets)
        l_cls = F.mse_loss(y_cls, targets)
        l_desc = F.mse_loss(y_desc, targets)

        # Batch-mean cosine similarity between the two branch representations.
        eps = 1e-8
        unit_cls = F.normalize(cls_mod, dim=-1, eps=eps)
        unit_desc = F.normalize(desc_h, dim=-1, eps=eps)
        c_sim = F.cosine_similarity(unit_cls, unit_desc, dim=-1).mean()

        loss = l_main + lambda_aux * (l_cls + l_desc) + lambda_div * c_sim

        scalars = (
            ("loss_main", l_main),
            ("loss_cls", l_cls),
            ("loss_desc", l_desc),
            ("cos_sim", c_sim),
            ("gate_mean", g.mean()),
        )
        diag = {key: value.detach().item() for key, value in scalars}
        return loss, diag

    def forward(self, input_ids, attention_mask, descriptors,
                targets=None, lambda_aux=0.2, lambda_div=0.05):
        """
        Returns a 3-tuple:
          pred: [B] predictions of the main head
          loss: combined scalar loss, or None when no targets are given
          diag: dict of float diagnostics, empty when no targets are given
        """
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        cls = encoded.last_hidden_state[:, 0, :]   # first-position vector, [B,H]

        desc_h = self.desc_proj(descriptors)       # [B,H]

        cls, desc_h = self._modality_dropout(cls, desc_h, input_ids.size(0))

        # FiLM: split the generator output into scale and shift halves.
        gamma, beta = self.film(desc_h).chunk(2, dim=-1)
        cls_mod = (1.0 + gamma) * cls + beta

        # Element-wise gate between the modulated CLS and descriptor branches.
        gate_input = torch.cat([cls_mod, desc_h], dim=-1)
        g = torch.sigmoid(self.gate_mlp(gate_input) / self.gate_temp)   # [B,H]

        fused = self.dropout(g * cls_mod + (1.0 - g) * desc_h)

        y_main = self.head_main(fused).squeeze(-1)
        y_cls = self.head_cls(cls_mod).squeeze(-1)
        y_desc = self.head_desc(desc_h).squeeze(-1)

        if targets is None:
            return y_main, None, {}

        loss, diag = self._loss_and_diag(
            y_main, y_cls, y_desc, cls_mod, desc_h, g,
            targets, lambda_aux, lambda_div,
        )
        return y_main, loss, diag


# ----------------------------
# Dataset / Collate
# ----------------------------
class SmiDescDataset(Dataset):
    """Map-style dataset of SMILES strings with optional targets and descriptor rows.

    Each item is a dict holding the tokenizer outputs (leading batch axis
    removed) plus "labels" when targets were supplied and "descriptors" when
    a descriptor matrix was supplied.
    """

    def __init__(self, smiles: List[str], targets: Optional[np.ndarray],
                 tokenizer: AutoTokenizer, max_length: int,
                 descriptors: Optional[np.ndarray] = None):
        self.smiles = list(smiles)
        self.tok = tokenizer
        self.max_length = max_length
        # Both optional arrays are stored as float32 (or None when absent).
        self.targets = np.asarray(targets, dtype=np.float32) if targets is not None else None
        self.desc = np.asarray(descriptors, dtype=np.float32) if descriptors is not None else None

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, i):
        encoded = self.tok(
            self.smiles[i],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        sample = {}
        for name, tensor in encoded.items():
            sample[name] = tensor.squeeze(0)

        if self.targets is not None:
            sample["labels"] = torch.tensor(self.targets[i], dtype=torch.float32)

        if self.desc is not None:
            row = self.desc[i]
            # NaN and +/-inf entries are replaced by 0.0 before conversion.
            if not np.isfinite(row).all():
                row = np.nan_to_num(row, nan=0.0, posinf=0.0, neginf=0.0)
            sample["descriptors"] = torch.tensor(row, dtype=torch.float32)

        return sample

def collate_stack(batch):
    """Stack the per-sample tensors of ``batch`` key by key.

    Keys are taken from the first sample; "labels", when present there, is
    emitted last.
    """
    first = batch[0]
    keys = [key for key in first if key != "labels"]
    if "labels" in first:
        keys.append("labels")
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}


def make_loaders(df: pd.DataFrame, target_col: str, tokenizer: AutoTokenizer,
                 cfg: Config, scaler: Optional[StandardScaler],
                 desc_cols: Optional[List[str]]) -> Tuple[DataLoader, DataLoader, Optional[np.ndarray]]:
    """Build the training and test DataLoaders from ``df``.

    Rows are assigned by the "Training/Test" column (whitespace-stripped,
    case-insensitive). When ``desc_cols`` is non-empty the descriptor columns
    are standardised with ``scaler``; if ``scaler`` is None a new
    StandardScaler is fitted on the training rows only.

    Returns (train_loader, test_loader, scaler.mean_ or None).
    """
    split = df["Training/Test"].str.strip().str.lower()
    train_df = df[split == "training"].reset_index(drop=True)
    test_df = df[split == "test"].reset_index(drop=True)

    # Standardise descriptors; the scaler never sees test rows while fitting.
    train_desc, test_desc = None, None
    if desc_cols:
        train_raw = train_df[desc_cols].to_numpy(dtype=np.float32)
        test_raw = test_df[desc_cols].to_numpy(dtype=np.float32)
        if scaler is None:
            scaler = StandardScaler().fit(train_raw)
        train_desc = scaler.transform(train_raw)
        test_desc = scaler.transform(test_raw)

    def _build(frame, descriptors, shuffle):
        dataset = SmiDescDataset(
            frame["SMILES"].tolist(),
            frame[target_col].to_numpy(dtype=np.float32),
            tokenizer, cfg.max_length, descriptors,
        )
        return DataLoader(dataset, batch_size=cfg.batch_size, shuffle=shuffle,
                          collate_fn=collate_stack)

    train_loader = _build(train_df, train_desc, True)
    test_loader = _build(test_df, test_desc, False)

    means = scaler.mean_ if desc_cols else None
    return train_loader, test_loader, means


def setup_optimizer_scheduler(model, train_dataloader, epochs, lr=2e-5, lr_head=1e-3,
                              warmup_ratio=0.1, weight_decay=0.01):
    """Create an AdamW optimizer with two parameter groups and a linear warm-up schedule.

    Args:
      model: module whose trainable parameters are optimised; parameters
        living under a sub-module named "backbone" go to group 0, all
        others to group 1.
      train_dataloader: used only for ``len()`` to compute the step count.
      epochs: number of epochs the schedule should span.
      lr: learning rate of the backbone group.
      lr_head: learning rate of the non-backbone group.
      warmup_ratio: fraction of total steps used for warm-up.
      weight_decay: AdamW weight decay applied to both groups (previously
        hard-coded to 0.01, which remains the default).

    Returns:
      (optimizer, scheduler)
    """
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Match a whole path component so that an unrelated parameter whose
        # name merely contains the substring "backbone" is not misrouted.
        if "backbone" in name.split("."):
            backbone_params.append(param)
        else:
            head_params.append(param)

    param_groups = [
        {"params": backbone_params, "lr": lr},
        {"params": head_params, "lr": lr_head},
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)

    total_steps = len(train_dataloader) * epochs
    warmup_steps = int(total_steps * warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    return optimizer, scheduler


def _batch_to_device(batch, device):
    """Return (input_ids, attention_mask, descriptors, labels) moved to ``device``."""
    return (
        batch["input_ids"].to(device),
        batch["attention_mask"].to(device),
        batch["descriptors"].to(device),
        batch["labels"].to(device),
    )


def train_model(model, train_loader, val_loader, optimizer, scheduler,
                device, epochs=10, grad_clip=1.0, lambda_aux=0.2, lambda_div=0.05,
                save_path="best_model.pt", log_path="training_log.csv"):
    """
    Train ``model`` for ``epochs`` epochs, logging one CSV row per epoch.

    Args:
      model: module whose forward returns (preds, loss, diag) as ChemBERTaFusion does
      train_loader / val_loader: iterables of batches with keys
        "input_ids", "attention_mask", "descriptors", "labels"
      optimizer, scheduler: stepped once per training batch
      device: device the model and batches are moved to
      grad_clip: max gradient norm; a falsy value disables clipping
      lambda_aux, lambda_div: loss weights forwarded to the model
      save_path: where the state dict with the lowest validation loss is written
      log_path: CSV file that is appended to (header written only if absent)

    Returns:
      The model, with the best checkpoint written during THIS call loaded
      back in. If no checkpoint was written (e.g. ``epochs == 0`` or the
      validation loss was never finite), the model is returned as trained and
      any pre-existing file at ``save_path`` is ignored.

    Raises:
      ValueError: if a loader yields no samples.
    """
    model.to(device)
    best_val = float("inf")
    # Tracks whether save_path was written by this call; the path may already
    # hold a checkpoint from an earlier, unrelated run.
    saved_checkpoint = False

    # --- Prepare CSV log ---
    log_file = Path(log_path)
    write_header = not log_file.exists()

    with open(log_file, mode="a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow([
                "epoch", "train_loss", "val_loss",
                "train_loss_cls", "train_loss_desc",
                "val_loss_cls", "val_loss_desc",
                "cos_sim", "gate_mean", "val_mae"
            ])

    for epoch in range(1, epochs+1):
        # ---- TRAIN ----
        model.train()
        train_loss, n_train = 0.0, 0
        train_loss_cls, train_loss_desc = 0.0, 0.0
        diag_accum = {"cos_sim": 0.0, "gate_mean": 0.0}

        pbar = tqdm(train_loader, desc=f"Epoch {epoch} [Train]")
        for batch in pbar:
            optimizer.zero_grad()

            input_ids, attention_mask, descriptors, targets = _batch_to_device(batch, device)

            preds, loss, diag = model(
                input_ids, attention_mask, descriptors,
                targets, lambda_aux=lambda_aux, lambda_div=lambda_div
            )

            loss.backward()
            if grad_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()

            # All running sums are weighted by batch size so that the epoch
            # averages are per-sample even when the last batch is smaller.
            bs = input_ids.size(0)
            train_loss += loss.item() * bs
            n_train += bs

            train_loss_cls += diag["loss_cls"] * bs
            train_loss_desc += diag["loss_desc"] * bs
            diag_accum["cos_sim"] += diag["cos_sim"] * bs
            diag_accum["gate_mean"] += diag["gate_mean"] * bs

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        if n_train == 0:
            raise ValueError("train_loader yielded no samples")

        train_loss /= n_train
        train_loss_cls /= n_train
        train_loss_desc /= n_train
        diag_accum = {k: v / n_train for k, v in diag_accum.items()}

        # ---- VALIDATION ----
        model.eval()
        val_loss, n_val = 0.0, 0
        val_loss_cls, val_loss_desc, val_mae = 0.0, 0.0, 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch} [Val]"):
                input_ids, attention_mask, descriptors, targets = _batch_to_device(batch, device)

                preds, loss, diag = model(
                    input_ids, attention_mask, descriptors,
                    targets, lambda_aux=lambda_aux, lambda_div=lambda_div
                )

                bs = input_ids.size(0)
                val_loss += loss.item() * bs
                val_loss_cls += diag["loss_cls"] * bs
                val_loss_desc += diag["loss_desc"] * bs
                # MAE on predictions (summed here, averaged after the loop)
                val_mae += F.l1_loss(preds, targets, reduction="sum").item()
                n_val += bs

        if n_val == 0:
            raise ValueError("val_loader yielded no samples")

        val_loss /= n_val
        val_loss_cls /= n_val
        val_loss_desc /= n_val
        val_mae /= n_val

        print(
            f"Epoch {epoch}: "
            f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, "
            f"train_cls={train_loss_cls:.4f}, val_cls={val_loss_cls:.4f}, "
            f"train_desc={train_loss_desc:.4f}, val_desc={val_loss_desc:.4f}, "
            f"cos_sim={diag_accum['cos_sim']:.3f}, gate_mean={diag_accum['gate_mean']:.3f}, "
            f"val_mae={val_mae:.4f}"
        )

        # ---- Log to CSV ----
        with open(log_file, mode="a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                epoch,
                train_loss, val_loss,
                train_loss_cls, train_loss_desc,
                val_loss_cls, val_loss_desc,
                diag_accum["cos_sim"], diag_accum["gate_mean"],
                val_mae
            ])

        # ---- Save best ----
        # A NaN val_loss never compares as smaller, so it never triggers a save.
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), save_path)
            saved_checkpoint = True
            print(f"✅ Saved best model (val_loss={val_loss:.4f})")

    print("Training complete.")
    if saved_checkpoint:
        # map_location keeps the reload on the training device.
        model.load_state_dict(torch.load(save_path, map_location=device))
    else:
        print(f"WARNING: no checkpoint was written during this run; "
              f"returning the model as trained and not loading {save_path}.")
    return model



def predict(model, test_loader, device, return_numpy=True):
    """
    Run inference over every batch of a loader.

    Args:
      model: trained ChemBERTaFusion (any module whose forward returns a
        tuple with the predictions first)
      test_loader: iterable of batches with keys "input_ids",
        "attention_mask" and "descriptors"
      device: device the model and batches are moved to
      return_numpy: if True, returns a numpy array

    Returns:
      preds: [N] predictions on CPU (torch.Tensor or np.ndarray); an empty
      array/tensor when the loader yields no batches.
    """
    model.eval()
    model.to(device)
    chunks = []

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            descriptors = batch["descriptors"].to(device)

            batch_preds, _, _ = model(input_ids, attention_mask, descriptors)
            chunks.append(batch_preds.cpu())

    # torch.cat raises on an empty list, so handle the empty loader explicitly.
    preds = torch.cat(chunks, dim=0) if chunks else torch.empty(0)

    return preds.numpy() if return_numpy else preds


def train_gated_for_prop_desc(prop: str,
                          desc: str,
                          cfg: Config) -> Dict[str, Any]:
    """Train and evaluate ChemBERTaFusion for one (property, descriptor) file.

    Reads "Descriptors Data/{prop}_{desc}.parquet", fits a StandardScaler on
    the training rows, trains the model, predicts every row, reports MAE /
    RMSE / R2 on the test rows and writes the scaler arrays, training log,
    checkpoint and predictions into ``cfg.output_dir``.

    Args:
      prop: property name; the target column is "{prop}-Measured".
      desc: descriptor-set name used in the file name.
      cfg: run configuration.

    Returns:
      Dict with keys "sheet", "target_col", "best_path", "pred_path",
      "MAE", "RMSE", "R2".

    Raises:
      ValueError: if the data file cannot be loaded or holds no descriptor
        columns.
    """
    set_seed(cfg.seed)
    ensure_dir(cfg.output_dir)

    # ---- Load data
    target_col = f"{prop}-Measured"
    data_file = f"Descriptors Data/{prop}_{desc}.parquet"
    sheet_name = f"{prop}_{desc}"
    file_stem = sheet_name.replace(' ', '_')
    try:
        df = pd.read_parquet(data_file)
    except Exception as err:
        # Keep ValueError for callers, but chain the real cause instead of
        # hiding it behind a bare `except`.
        raise ValueError(f"{data_file} is not found or could not be read.") from err

    # Everything from the 10th column onwards is treated as a descriptor
    # (assumes the first nine columns are metadata - TODO confirm).
    desc_cols = df.columns[9:].to_list()
    if not desc_cols:
        # The training loop indexes batch["descriptors"] unconditionally, so a
        # file without descriptor columns can never train; fail early.
        raise ValueError(f"{data_file} contains no descriptor columns.")

    # ---- Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)

    # ---- Descriptor scaler: fitted once, on training rows only, and saved
    # as plain arrays (mean / scale / column names).
    split = df["Training/Test"].str.strip().str.lower()
    train_rows = df[split == "training"].reset_index(drop=True)
    scaler = StandardScaler().fit(train_rows[desc_cols].to_numpy(dtype=np.float32))
    scaler_prefix = os.path.join(cfg.output_dir, f"{file_stem}_scaler")
    mean = scaler.mean_.astype(np.float32)
    scale = scaler.scale_.astype(np.float32)
    np.save(scaler_prefix + "_mean.npy", mean)
    np.save(scaler_prefix + "_scale.npy", scale)
    with open(scaler_prefix + "_cols.json", "w") as f:
        json.dump(desc_cols, f)
    print(f"Saved descriptor scaler → {scaler_prefix}_mean.npy / _scale.npy / _cols.json")

    # ---- Data loaders (reuse the fitted scaler instead of fitting a second one)
    train_loader, test_loader, _ = make_loaders(df, target_col, tokenizer, cfg, scaler, desc_cols)

    # ---- Model
    n_desc = len(desc_cols)

    model = ChemBERTaFusion(cfg.model_name, n_desc=n_desc, proj_dim=cfg.proj_dim,
                            hidden_fuse=cfg.hidden_fuse, dropout=cfg.dropout,
                            train_layers=cfg.train_layers, gate_temp=cfg.gate_temp,
                            p_moddrop=cfg.p_moddrop).to(cfg.device)

    # Parameter groups (one LR for the backbone, another for everything else)
    optimizer, scheduler = setup_optimizer_scheduler(model, train_loader, cfg.epochs,
                                                     cfg.lr_backbone, cfg.lr_heads,
                                                     cfg.warmup_ratio)

    save_model = os.path.join(cfg.output_dir, cfg.save_path)
    save_log = os.path.join(cfg.output_dir, f"{prop}_{desc}_{cfg.log_path}")
    # NOTE(review): the test split doubles as the validation set here, so the
    # "best" checkpoint is selected on the same rows the final metrics use.
    model = train_model(model, train_loader, test_loader, optimizer, scheduler, cfg.device,
                        cfg.epochs, cfg.grad_clip, cfg.lambda_aux, cfg.lambda_div,
                        save_model, save_log)

    # ---- Predict on all rows (Training + Test) with the saved float32 scaler arrays.
    # reindex keeps the saved column order; missing columns become NaN, which
    # SmiDescDataset later replaces with 0.
    all_desc = df.reindex(columns=desc_cols).to_numpy(dtype=np.float32)
    all_desc = (all_desc - mean) / np.where(scale == 0, 1.0, scale)

    all_ds = SmiDescDataset(df["SMILES"].tolist(),
                            df[target_col].to_numpy(dtype=np.float32),
                            tokenizer, cfg.max_length, all_desc)
    all_loader = DataLoader(all_ds, batch_size=cfg.batch_size, shuffle=False,
                            collate_fn=collate_stack)

    # np.asarray is a no-op for arrays and converts a CPU tensor when
    # cfg.return_numpy is False.
    all_preds = np.asarray(predict(model, all_loader, cfg.device, cfg.return_numpy))

    # ---- Build results DF
    if "NAME" in df.columns:
        names = df["NAME"]
    else:
        # Align the placeholder with df's index so no extra rows are created.
        names = pd.Series([None] * len(df), index=df.index)
    new_results = pd.DataFrame({
        "Name": names,
        "SMILES": df["SMILES"],
        "Observed": df[target_col],
        "Predicted": all_preds,
        "Training/Test": df["Training/Test"],
    })

    # ---- Final metrics on Test only. Strip whitespace exactly like the
    # train/test split does, otherwise padded labels would be dropped here.
    is_test = new_results["Training/Test"].str.strip().str.lower() == "test"
    obs_test = new_results.loc[is_test, "Observed"].values
    pred_test = new_results.loc[is_test, "Predicted"].values
    mae_v = mean_absolute_error(obs_test, pred_test)
    rmse_v = rmse(obs_test, pred_test)
    r2_v = r2_score(obs_test, pred_test)
    print(f"Final (best) Test metrics for {sheet_name} → MAE: {mae_v:.4f} | RMSE: {rmse_v:.4f} | R²: {r2_v:.4f}")

    # ---- Save predictions parquet
    pred_path = os.path.join(cfg.output_dir, f"chemberta_{file_stem}.parquet")
    new_results.to_parquet(pred_path, index=False)
    print(f"Saved predictions → {pred_path}")

    return {
        "sheet": sheet_name,
        "target_col": target_col,
        "best_path": save_model,
        "pred_path": pred_path,
        "MAE": mae_v, "RMSE": rmse_v, "R2": r2_v,
    }

# ----------------------------
# Predict helper (load & score new data)
# ----------------------------
def load_scaler_arrays(out_dir: str, sheet_name: str, desc_num: Optional[int] = None):
    """Load the scaler mean, scale and column list saved for ``sheet_name``.

    Files are looked up as
    "<out_dir>/<sheet_name with spaces as underscores>_scaler_{mean,scale,cols}<desc_num>.{npy,npy,json}".

    Args:
      out_dir: directory containing the saved files.
      sheet_name: "<prop>_<desc>" name; spaces are replaced by underscores.
      desc_num: optional numeric suffix. Leave as None to read the
        un-numbered files written by train_gated_for_prop_desc.

    Returns:
      (mean, scale, cols): two numpy arrays and the list of column names.
    """
    suffix = "" if desc_num is None else f"{desc_num}"
    prefix = os.path.join(out_dir, f"{sheet_name.replace(' ','_')}_scaler")
    mean = np.load(f"{prefix}_mean{suffix}.npy")
    scale = np.load(f"{prefix}_scale{suffix}.npy")
    with open(f"{prefix}_cols{suffix}.json") as f:
        cols = json.load(f)
    return mean, scale, cols


# ----------------------------
# Multi-property runner
# ----------------------------
def run_all_properties_descriptors(prop_names: List[str], desc_names: List[str], cfg: Config):
    """Train one model per (property, descriptor) pair and save a summary CSV.

    Args:
      prop_names: property names; a single string is treated as one name.
      desc_names: descriptor-set names; a single string is treated as one name.
      cfg: run configuration shared by all runs.

    Returns:
      DataFrame with columns "Property", "MAE", "RMSE", "R2", one row per
      pair; also written to "<cfg.output_dir>/chemberta_FiLM_Fusion_stats.csv".
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(prop_names, str):
        prop_names = [prop_names]
    if isinstance(desc_names, str):
        desc_names = [desc_names]

    ensure_dir(cfg.output_dir)
    perf_rows = []
    for prop in prop_names:
        for desc in desc_names:
            print(f"\n=== Processing the file: {prop}_{desc} ===")
            result = train_gated_for_prop_desc(prop, desc, cfg)
            perf_rows.append([f"{prop}_{desc}", result["MAE"], result["RMSE"], result["R2"]])
    perf_df = pd.DataFrame(perf_rows, columns=["Property", "MAE", "RMSE", "R2"])
    stats_path = os.path.join(cfg.output_dir, "chemberta_FiLM_Fusion_stats.csv")
    perf_df.to_csv(stats_path, index=False)
    print(f"\n📊 All-property stats saved → {stats_path}")
    return perf_df


In [None]:
# Run configuration: overrides of the Config defaults for the full sweep.
cfg = Config(
    output_dir="chemberta_gated_results",
    epochs=30,
    batch_size=8,
    max_length=128,
    proj_dim=128,
    # train_layers=3,      # would request unfreezing the last 3 transformer blocks (currently disabled)
    lr_backbone=1e-5,
    lr_heads=1e-5
)

# Property names; each one is combined with every descriptor set below and
# read from "Descriptors Data/<prop>_<desc>.parquet".
prop_names = ["Log VP", "MP", "BP", "LogBCF", "LogS", "LogP"]
# prop_names = ["LogP"]   # single-property run for quick checks
desc_names = ["MACCS", "Morgan", "pwav"]

# Train/evaluate every (property, descriptor) pair; the trailing bare
# expression displays the summary table in the notebook.
perf_df = run_all_properties_descriptors(prop_names, desc_names, cfg)
perf_df