In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
SCRIPT COMPLET (A à Z) : TRANSFORMATEUR CONDITIONNEL (Catégoriel) - V2 CORRIGÉ
Avec délimiteurs < et > pour les molécules dans le fichier
"""

# --- PARTIE 1 : IMPORTS ET CONFIGURATION ---

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import math
from dataclasses import dataclass
import time
import os
import json
from tqdm import tqdm
import gc # Garbage collector

try:
    from rdkit import Chem
    from rdkit.Chem import Descriptors
    from rdkit import rdBase
    rdBase.DisableLog('rdApp.error')
except ImportError:
    print("Erreur : RDKit n'est pas installé.")
    print("Veuillez l'installer avec : pip install rdkit")
    exit()

# --- Configuration ---
# Module-level hyper-parameters and file locations read by the rest of the script.
BATCH_SIZE = 32  # batch size handed to both DataLoaders

BLOCK_SIZE = 128  # padded sequence length; also the model's context size (GPTConfig.block_size)
N_EMBD = 128  # embedding width (must be divisible by N_HEAD, asserted in CausalSelfAttention)
N_HEAD = 4  # number of attention heads
N_LAYER = 4  # number of stacked transformer blocks

MAX_ITERS = 10000 # Increase for better training
EVAL_INTERVAL = 500  # evaluate (and possibly checkpoint) every this many iterations
LEARNING_RATE = 3e-4  # learning rate passed to AdamW
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is available
EVAL_ITERS = 200  # number of batches averaged per split in estimate_loss

DROPOUT = 0.15  # dropout probability used by every nn.Dropout in the model
CONDITION_DIM = 3 # 3 categories: LogP, MW, HBD

# Files
DATA_FILE = "s_100_str_+1M_fixed.txt"  # TO BE CHANGED: path of the input dataset

VOCAB_FILE = 'vocab_dataset.json'  # where the stoi/itos mappings are dumped as JSON
DATA_CACHE_FILE = 'data_cache_categorical_10K.pt' # Cache for the processed categorical data

# Checkpoints
CHECKPOINT_DIR = 'checkpoints'
CHECKPOINT_FILE = os.path.join(CHECKPOINT_DIR, 'cond_gpt_categorical_10K.pth')
# Create the checkpoint directory up front; exist_ok avoids an error when it already exists.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Configuration : Périphérique={DEVICE}, Batch={BATCH_SIZE}, Contexte={BLOCK_SIZE}")
# Fixed seed for PyTorch's random number generator.
torch.manual_seed(1337)

# --- PART 2: VOCABULARY CONSTRUCTION ---
# Character-level vocabulary: every distinct character found in the molecules
# of DATA_FILE becomes one token, preceded by three special tokens.

print(f"Construction du vocabulaire à partir de '{DATA_FILE}'...")

# Special tokens
PAD_TOKEN = '<pad>'
START_TOKEN = '<start>'
END_TOKEN = '<end>'

# Collect the characters used by the dataset.
char_set = set()
# The encoding is given explicitly so the result does not depend on the
# platform's locale default; load_and_process_data opens this same file as UTF-8.
with open(DATA_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        smiles = line.strip()
        # Keep only the text between the first '<' and the first '>' when both
        # are present and enclose at least one character.
        start_idx = smiles.find('<') + 1  # 0 when '<' is absent
        end_idx = smiles.find('>')        # -1 when '>' is absent
        if 0 < start_idx < end_idx:
            smiles = smiles[start_idx:end_idx]
        char_set.update(smiles)

# Special tokens come first so that their ids are 0, 1 and 2.
special_tokens = [PAD_TOKEN, START_TOKEN, END_TOKEN]
vocabulary = special_tokens + sorted(char_set)

# Lookup tables used to encode / decode
stoi = {ch: i for i, ch in enumerate(vocabulary)}
itos = {i: ch for i, ch in enumerate(vocabulary)}
vocab_size = len(vocabulary)

# Save as JSON.
# NOTE: JSON object keys are always strings, so the integer keys of `itos`
# are written as strings and must be converted back to int when reloading.
with open(VOCAB_FILE, 'w', encoding='utf-8') as f:
    json.dump({'stoi': stoi, 'itos': itos}, f, indent=2)

print(f"Taille vocabulaire : {vocab_size}")
print("Exemple de mapping :", list(stoi.items())[:10])


# Global encode/decode helpers
def encode(s):
    """Map each character of *s* to its token id; unknown characters are dropped."""
    return [stoi[c] for c in s if c in stoi]


def decode(l):
    """Join the tokens for the ids in *l* into a string; unknown ids are dropped."""
    return ''.join([itos[i] for i in l if i in itos])


# Vocabulary round-trip sanity check
print("\n=== TEST DU VOCABULAIRE ===")
test_smiles = "CCO"
print(f"Test encode/decode: {test_smiles}")
encoded = encode(test_smiles)
print(f"Encoded: {encoded}")
decoded = decode(encoded)
print(f"Decoded: {decoded}")
print(f"Match: {test_smiles == decoded}")

# --- PARTIE 3 : PRÉPARATION DES DONNÉES (Version Catégorielle) ---

# Map a numeric value to the index of the interval it falls into.
def get_category(value, bins):
    """Return, as a float, the index of the first entry of *bins* that is >= *value*.

    *bins* holds inclusive upper bounds and is scanned in order. When *value*
    exceeds every bound (or no comparison succeeds), ``float(len(bins))`` is
    returned, i.e. the index of one extra, open-ended category.
    """
    hits = (float(position) for position, limit in enumerate(bins) if value <= limit)
    return next(hits, float(len(bins)))

# Inclusive upper bounds of the intervals passed to get_category; a value above
# the last bound falls into one extra final category.
# (Adjust these if the data analysis showed different distributions.)
LOGP_BINS = [0.0, 3.0, 5.0] # <=0 -> 0, <=3 -> 1, <=5 -> 2, >5 -> 3
MW_BINS = [250.0, 480.0, 650.0] # <=250 -> 0, <=480 -> 1, <=650 -> 2, >650 -> 3
HBD_BINS = [0.0, 1.0, 2.0, 3.0] # <=0 -> 0, <=1 -> 1, <=2 -> 2, <=3 -> 3, >3 -> 4

def load_and_process_data(filepath, stoi, max_len=BLOCK_SIZE, cache_file=DATA_CACHE_FILE):
    """
    Load SMILES strings, bin LogP / MW / HBD into categories and build tensors.

    Each line of *filepath* holds one molecule, optionally wrapped in ``<`` and
    ``>`` delimiters. A line is skipped when it is empty, longer than
    ``max_len - 2`` characters, has fewer than 80% of its characters in *stoi*,
    cannot be parsed by RDKit, or when descriptor computation fails.

    Args:
        filepath: path of the text file to read (opened as UTF-8).
        stoi: character -> token id mapping; must contain the special tokens.
        max_len: padded length of every returned sequence.
        cache_file: if this file exists it is loaded and returned as-is, without
            checking that it matches *filepath*, *stoi* or *max_len*; otherwise
            the processed data is written to it.

    Returns:
        A list of ``(x, y, condition_vector)`` tuples. ``x`` is
        ``<start> + tokens + <end>`` padded to *max_len*, ``y`` is ``x`` shifted
        left by one position, and ``condition_vector`` is a float32 tensor of
        the three category indices ``[logp, mw, hbd]``.

    Raises:
        SystemExit: when *filepath* does not exist or memory runs out.
    """
    if os.path.exists(cache_file):
        print(f"Chargement des données catégorielles depuis le cache '{cache_file}'...")
        data = torch.load(cache_file)
        print("Données chargées.")
        return data

    print("Traitement des données SMILES (calcul des catégories)...")
    print("(Cela peut prendre du temps sur tout le dataset)")
    data_processed = []
    descriptor_failures = 0
    pad_idx = stoi[PAD_TOKEN]
    start_idx = stoi[START_TOKEN]
    end_idx = stoi[END_TOKEN]

    try:
        # Count the lines first so tqdm can display a total. The `with` block
        # guarantees the handle is closed (the bare open() used before leaked it).
        with open(filepath, 'r', encoding='utf-8') as f:
            num_lines = sum(1 for _ in f)

        # Same explicit encoding as the counting pass, so both passes decode
        # the file identically on every platform.
        with open(filepath, 'r', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f, total=num_lines, desc="Calcul des catégories")):
                # Periodic garbage collection; done first so skipped lines do not bypass it.
                if i % 50000 == 0:
                    gc.collect()

                smiles = line.strip()

                # Keep only the text between the first '<' and the first '>'
                # when both are present and enclose at least one character.
                start_delim = smiles.find('<') + 1  # 0 when '<' is absent
                end_delim = smiles.find('>')        # -1 when '>' is absent
                if 0 < start_delim < end_delim:
                    smiles = smiles[start_delim:end_delim]

                # Remove surrounding whitespace and any inner spaces / tabs.
                smiles = smiles.strip().replace(" ", "").replace("\t", "")
                if not smiles:
                    continue
                # Leave room for the <start> and <end> tokens.
                if len(smiles) > max_len - 2:
                    continue
                # Require at least 80% known characters.
                # NOTE(review): the remaining unknown characters are dropped when
                # encoding, so the token sequence may then differ from the
                # molecule whose descriptors are computed below.
                known_chars = sum(1 for c in smiles if c in stoi)
                if known_chars / len(smiles) < 0.8:
                    continue

                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    continue

                # Only the descriptor calls are guarded; a failure there skips
                # the molecule (best effort) but is counted and reported below.
                try:
                    logp = Descriptors.MolLogP(mol)
                    mw = Descriptors.MolWt(mol)
                    hbd = Descriptors.NumHDonors(mol)
                except Exception:
                    descriptor_failures += 1
                    continue

                condition_vector = torch.tensor(
                    [
                        get_category(logp, LOGP_BINS),
                        get_category(mw, MW_BINS),
                        get_category(hbd, HBD_BINS),
                    ],
                    dtype=torch.float32,
                )

                # Encode with the `stoi` argument rather than a module-level
                # helper, so the mapping passed by the caller is the one used.
                token_ids = [start_idx] + [stoi[c] for c in smiles if c in stoi] + [end_idx]
                seq_len = len(token_ids)
                x = torch.full((max_len,), pad_idx, dtype=torch.long)
                y = torch.full((max_len,), pad_idx, dtype=torch.long)
                x[:seq_len] = torch.tensor(token_ids, dtype=torch.long)
                # Targets are the inputs shifted by one position.
                y[:seq_len - 1] = torch.tensor(token_ids[1:], dtype=torch.long)

                data_processed.append((x, y, condition_vector))

    except FileNotFoundError:
        print(f"ERREUR: Le fichier de données '{filepath}' est introuvable.")
        # Non-zero status so that callers / shells can detect the failure.
        raise SystemExit(1)
    except MemoryError:
        print("\nERREUR: Manque de mémoire pour traiter tout le dataset.")
        print("Vous devrez peut-être utiliser un échantillon plus petit ou augmenter la RAM.")
        raise SystemExit(1)

    print(f"\nNombre total de molécules valides chargées : {len(data_processed)}")
    if descriptor_failures:
        print(f"Molécules ignorées (échec du calcul des descripteurs) : {descriptor_failures}")
    print(f"Sauvegarde des données traitées dans le cache '{cache_file}'...")
    torch.save(data_processed, cache_file)
    print("Données sauvegardées.")

    return data_processed

class SMILESDataset(Dataset):
    """Thin ``Dataset`` adapter around an indexable, sized collection of samples."""

    def __init__(self, data):
        # Keep a reference to the wrapped collection; no copy is made.
        self.data = data

    def __len__(self):
        """Return the number of samples in the wrapped collection."""
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        """Return the sample stored at position *idx*, unchanged."""
        sample = self.data[idx]
        return sample

# --- PARTIE 4 : ARCHITECTURE DU MODÈLE (BLOCS DE BASE - Inchangé) ---

# Hyper-parameter bundle consumed by the model classes below.
# Every default is read from the module-level constant of the same meaning when
# the class body is executed, so later changes to those globals do not affect it.
@dataclass
class GPTConfig:
    block_size: int = BLOCK_SIZE  # maximum sequence length (position table and causal mask size)
    vocab_size: int = vocab_size # Uses the global variable
    n_layer: int = N_LAYER  # number of transformer blocks
    n_head: int = N_HEAD  # attention heads per block; must divide n_embd
    n_embd: int = N_EMBD  # embedding / hidden width
    dropout: float = DROPOUT  # dropout probability
    condition_dim: int = CONDITION_DIM  # length of the condition vector

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position attends only to itself and earlier positions."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        context = config.block_size
        # A single projection produces query, key and value, concatenated on the last axis.
        self.c_attn = nn.Linear(width, 3 * width)
        # Output projection applied after the heads are merged back together.
        self.c_proj = nn.Linear(width, width)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = width
        # Lower-triangular matrix of ones: entry (i, j) is 1 when j <= i.
        lower_triangle = torch.tril(torch.ones(context, context))
        self.register_buffer("bias", lower_triangle.view(1, 1, context, context))

    def forward(self, x):
        """Apply masked multi-head attention to *x* of shape (B, T, C); the result has the same shape."""
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(tensor):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return tensor.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        scale = 1.0 / math.sqrt(head_dim)
        scores = (q @ k.transpose(-2, -1)) * scale
        # Positions where the mask is 0 lie in the future and get a score of -inf.
        future = self.bias[:, :, :T, :T] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # (B, n_head, T, head_dim) -> (B, T, C)
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(merged))


class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x the width, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden)
        self.c_proj = nn.Linear(hidden, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the transformed tensor; the last dimension stays ``n_embd``."""
        activated = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(activated))

class Block(nn.Module):
    """One transformer layer: pre-LayerNorm attention and pre-LayerNorm MLP, each with a residual connection."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Add the attention output, then the MLP output, to the running activations."""
        sublayers = ((self.ln_1, self.attn), (self.ln_2, self.mlp))
        for norm, sublayer in sublayers:
            # Normalise first, transform, then add back to the input (residual).
            x = x + sublayer(norm(x))
        return x


# --- PARTIE 5 : LE TRANSFORMATEUR CONDITIONNEL (Inchangé) ---

class ConditionalDrugGPT(nn.Module):
    """GPT-style decoder whose input embeddings are shifted by a projected condition vector."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.n_embd

        # Sub-modules are created in the same order as the ModuleDict keys below.
        token_table = nn.Embedding(config.vocab_size, width)
        position_table = nn.Embedding(config.block_size, width)
        embedding_dropout = nn.Dropout(config.dropout)
        layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        final_norm = nn.LayerNorm(width)
        self.transformer = nn.ModuleDict({
            'wte': token_table,
            'wpe': position_table,
            'drop': embedding_dropout,
            'h': layers,
            'ln_f': final_norm,
        })

        # Maps the condition vector to the embedding width.
        self.condition_projector = nn.Sequential(
            nn.Linear(config.condition_dim, width),
            nn.ReLU(),
        )

        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output head share one parameter.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear / Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None, conditions=None):
        """Run the model on token ids *idx* of shape (B, T).

        With *targets*, returns logits for every position and the cross-entropy
        loss, ignoring positions whose target is the pad token. Without
        *targets*, returns logits for the last position only and ``None``.
        *conditions* is required; its projection is added at every position.
        """
        _, T = idx.shape
        assert T <= self.config.block_size, f"Séquence trop longue: {T}"
        positions = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(positions)

        assert conditions is not None, "Les conditions doivent être fournies !"
        # unsqueeze(1) lets the condition embedding broadcast over the sequence axis.
        cond_emb = self.condition_projector(conditions).unsqueeze(1)

        hidden = self.transformer.drop(tok_emb + pos_emb + cond_emb)
        for layer in self.transformer.h:
            hidden = layer(hidden)
        hidden = self.transformer.ln_f(hidden)

        if targets is None:
            # Inference: only the last position is needed; the list index keeps the time axis.
            return self.lm_head(hidden[:, [-1], :]), None

        logits = self.lm_head(hidden)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=stoi[PAD_TOKEN],
        )
        return logits, loss


# --- PARTIE 6 : FONCTIONS DE CHECKPOINT ET D'ÉVALUATION (Inchangé) ---

def save_checkpoint(model, optimizer, iter_num, best_val_loss, config, filepath):
    """Write a training checkpoint to *filepath*.

    The checkpoint is a dict with the keys ``iter_num``, ``model_state_dict``,
    ``optimizer_state_dict``, ``best_val_loss`` and ``config``.

    When *filepath* is a path, the data is first written to ``<filepath>.tmp``
    and then moved over the destination with ``os.replace``. An interrupted or
    failed save therefore leaves the previous checkpoint untouched instead of
    truncating it. Other objects (e.g. an open file) are passed straight to
    ``torch.save``.
    """
    print(f"Sauvegarde du checkpoint dans {filepath}...")
    payload = {
        'iter_num': iter_num,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
        'config': config,
    }

    if not isinstance(filepath, (str, os.PathLike)):
        torch.save(payload, filepath)
        return

    tmp_path = os.fspath(filepath) + '.tmp'
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, filepath)
    finally:
        # Only still present when the save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def load_checkpoint(filepath, model, optimizer):
    """Restore *model* and *optimizer* from *filepath* if it exists.

    Returns:
        ``(iter_num, best_val_loss)`` read from the checkpoint, or
        ``(0, float('inf'))`` when no checkpoint file is present.
    """
    if not os.path.exists(filepath):
        print("Aucun checkpoint trouvé. Démarrage d'un nouvel entraînement.")
        return 0, float('inf')

    print(f"Chargement du checkpoint depuis {filepath}...")
    # The checkpoint stores the config object next to the tensors, so it is a
    # full pickle. Recent PyTorch releases default to weights_only=True, which
    # refuses to unpickle such objects and makes resuming fail.
    # SECURITY: weights_only=False runs the pickle machinery; only load
    # checkpoint files produced by this script or another trusted source.
    try:
        checkpoint = torch.load(filepath, map_location=DEVICE, weights_only=False)
    except TypeError:
        # PyTorch versions that predate the `weights_only` argument.
        checkpoint = torch.load(filepath, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
    print(f"Reprise de l'entraînement à l'itération {iter_num} (meilleure perte val: {best_val_loss:.4f})")
    return iter_num, best_val_loss

@torch.no_grad()
def estimate_loss(model, train_loader, val_loader, eval_iters=EVAL_ITERS):
    """Average the model loss over *eval_iters* batches of each split.

    The model is put in eval mode for the measurement and switched back to
    train mode before returning. A loader that runs out of batches is restarted
    from the beginning.

    Returns:
        dict with the keys ``'train'`` and ``'val'``, each holding the mean loss
        as a 0-dim tensor.
    """
    model.eval()
    averages = {}
    for split, loader in (('train', train_loader), ('val', val_loader)):
        losses = torch.zeros(eval_iters)
        batches = iter(loader)
        for step in range(eval_iters):
            try:
                batch = next(batches)
            except StopIteration:
                # Loader exhausted: start another pass over it.
                batches = iter(loader)
                batch = next(batches)

            x, y, c = batch
            _, loss = model(x.to(DEVICE), y.to(DEVICE), c.to(DEVICE))
            losses[step] = loss.item()
        averages[split] = losses.mean()
    model.train()
    return averages

# --- FONCTION DE GÉNÉRATION CORRIGÉE ---
@torch.no_grad()
def generate_conditional(model, condition_tensor, stoi, itos, max_new_tokens=100, temperature=1.0, top_k=None):
    """
    Sample one SMILES string from *model* for a CATEGORICAL condition tensor.

    Generation starts from the ``<start>`` token and stops when ``<end>`` is
    sampled or after *max_new_tokens* tokens. Neither ``<start>`` nor ``<end>``
    appears in the returned string.

    Args:
        model: network called as ``model(idx, conditions=...)`` that returns
            ``(logits, loss)``.
        condition_tensor: condition tensor passed unchanged to the model
            (moved to DEVICE first).
        stoi: token -> id mapping; must contain the start and end tokens.
        itos: id -> token mapping used to decode the sampled ids; ids missing
            from it are dropped.
        max_new_tokens: maximum number of tokens to sample.
        temperature: logits are divided by this value; must be > 0.
        top_k: if given, sampling is restricted to the *top_k* most likely tokens.

    Returns:
        The decoded string (possibly empty).

    Raises:
        ValueError: if *temperature* is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature doit être strictement positive (reçu : {temperature})")

    start_idx = stoi[START_TOKEN]
    end_idx = stoi[END_TOKEN]

    # Remember the current mode so the caller's train/eval state is restored,
    # even when sampling raises.
    was_training = model.training
    model.eval()
    try:
        idx = torch.tensor([[start_idx]], dtype=torch.long, device=DEVICE)
        condition_tensor = condition_tensor.to(DEVICE)

        for _ in range(max_new_tokens):
            # Crop the context to the last BLOCK_SIZE tokens.
            idx_cond = idx if idx.size(1) <= BLOCK_SIZE else idx[:, -BLOCK_SIZE:]

            # The category tensor is passed to the model as-is.
            logits, _ = model(idx_cond, conditions=condition_tensor)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Mask every logit below the k-th largest one.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # <end> terminates generation and is never appended to the sequence.
            if idx_next.item() == end_idx:
                break

            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        model.train(was_training)

    # Skip the leading <start> token. The sequence cannot contain <end>, since
    # sampling stops before appending it.
    generated_tokens = idx[0, 1:].tolist()

    # Decode with the mapping supplied by the caller, not a module-level one.
    return ''.join(itos[token] for token in generated_tokens if token in itos)

# --- FONCTION DE VÉRIFICATION DES PROPRIÉTÉS ---
def check_mol_3_props(smiles):
    """Compute the three real properties (LogP, MW, HBD) of a SMILES string.

    Returns a 4-tuple ``(status, logp, mw, hbd)``. ``status`` is ``"Vide"`` for
    a falsy input, ``"Invalide"`` when RDKit cannot parse the string, and
    ``"Valide"`` otherwise. The numeric fields are ``0.0, 0.0, 0`` unless the
    status is ``"Valide"``.
    """
    no_values = (0.0, 0.0, 0)
    if not smiles:  # empty string (or None)
        return ("Vide", *no_values)

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return ("Invalide", *no_values)

    descriptor_fns = (Descriptors.MolLogP, Descriptors.MolWt, Descriptors.NumHDonors)
    return ("Valide", *(compute(mol) for compute in descriptor_fns))

# --- PARTIE 7 : SCRIPT PRINCIPAL D'EXÉCUTION ---

if __name__ == "__main__":

    # 1. Vocabulary: stoi, itos and vocab_size are already module-level globals.

    # 2. Data (loads the cache or processes the categorical data)
    full_data = load_and_process_data(DATA_FILE, stoi, cache_file=DATA_CACHE_FILE)

    # 3. DataLoaders (90% train / 10% validation)
    train_size = int(0.9 * len(full_data))
    val_size = len(full_data) - train_size
    train_data, val_data = torch.utils.data.random_split(full_data, [train_size, val_size])

    train_loader = DataLoader(SMILESDataset(train_data), batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(SMILESDataset(val_data), batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

    # 4. Model and optimizer
    config = GPTConfig(vocab_size=vocab_size)
    model = ConditionalDrugGPT(config)
    model.to(DEVICE)

    print(f"Nombre de paramètres : {sum(p.numel() for p in model.parameters())/1e6:.2f} M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # 5. Resume from the checkpoint when one exists
    start_iter, best_val_loss = load_checkpoint(CHECKPOINT_FILE, model, optimizer)

    # 6. Training loop
    print(f"Début de l'entraînement sur {DEVICE}...")
    start_time = time.time()
    train_iter = iter(train_loader)

    for iter_num in range(start_iter, MAX_ITERS):

        # Periodic evaluation (also on the last iteration); the checkpoint is
        # written only when the validation loss improves.
        if iter_num > 0 and (iter_num % EVAL_INTERVAL == 0 or iter_num == MAX_ITERS - 1):
            losses = estimate_loss(model, train_loader, val_loader)
            elapsed = time.time() - start_time
            print(f"Étape {iter_num}: perte train {losses['train']:.4f}, perte val {losses['val']:.4f}, temps {elapsed:.1f}s")

            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                save_checkpoint(model, optimizer, iter_num, best_val_loss, config, CHECKPOINT_FILE)

        # Fetch the next batch, restarting the loader when an epoch ends.
        try:
            xb, yb, cb = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            xb, yb, cb = next(train_iter)

        xb, yb, cb = xb.to(DEVICE), yb.to(DEVICE), cb.to(DEVICE)

        logits, loss = model(xb, targets=yb, conditions=cb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Periodic garbage collection during training
        if iter_num % 1000 == 0:
            gc.collect()

    print("Entraînement terminé !")

    # 7. Generation (adapted to categorical conditions)
    print("\n--- Génération de Molécules Conditionnelles (Catégorielles) ---")

    # Single-sample smoke test before the sampling loops
    print("\n=== TEST DE GÉNÉRATION SIMPLE ===")
    target_cats_test = [1.0, 1.0, 3.0]
    condition_tensor_test = torch.tensor(target_cats_test, dtype=torch.float32).unsqueeze(0)

    test_smiles = generate_conditional(model, condition_tensor_test, stoi, itos, max_new_tokens=50, top_k=10)
    valid, logp, mw, hbd = check_mol_3_props(test_smiles)
    print(f"Test génération: '{test_smiles}'")
    print(f"  Statut: {valid}, LogP: {logp:.1f}, MW: {mw:.0f}, HBD: {hbd}")

    # Number of molecules sampled per condition. The summary line uses this same
    # value as its denominator (it used to print "/10" while 120 were sampled).
    num_samples = 120

    # (target category indices, headline) for each condition to sample.
    generation_targets = [
        # Condition 1: aim for the "Rule of 3" region
        ([1.0, 1.0, 3.0], "Génération pour Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)"),
        # Condition 2: very lipophilic, large molecules with few H-bond donors
        ([3.0, 3.0, 0.0], "Génération pour Catégories [LogP(3), MW(3), HBD(0)] (Très lipophile/grand)"),
    ]

    for target_cats, headline in generation_targets:
        condition_tensor = torch.tensor(target_cats, dtype=torch.float32).unsqueeze(0)

        print(f"\n{headline}")
        valid_count = 0
        for i in range(num_samples):
            mol_str = generate_conditional(model, condition_tensor, stoi, itos, max_new_tokens=50, top_k=10)
            valid, logp, mw, hbd = check_mol_3_props(mol_str)
            print(f"  {i+1}. -> '{mol_str}'")
            print(f"     (Valide: {valid}, LogP: {logp:.1f}, MW: {mw:.0f}, HBD: {hbd})")

            if valid == "Valide":
                valid_count += 1
                # Categories actually reached by the generated molecule
                logp_c = get_category(logp, LOGP_BINS)
                mw_c = get_category(mw, MW_BINS)
                hbd_c = get_category(hbd, HBD_BINS)
                print(f"     Catégories réelles: [LogP({logp_c:.0f}), MW({mw_c:.0f}), HBD({hbd_c:.0f})]")

        print(f"Molécules valides générées: {valid_count}/{num_samples}")

Configuration : Périphérique=cuda, Batch=32, Contexte=128
Construction du vocabulaire à partir de 's_100_str_+1M_fixed.txt'...
Taille vocabulaire : 57
Exemple de mapping : [('<pad>', 0), ('<start>', 1), ('<end>', 2), ('#', 3), ('%', 4), ('(', 5), (')', 6), ('+', 7), ('-', 8), ('.', 9)]

=== TEST DU VOCABULAIRE ===
Test encode/decode: CCO
Encoded: [25, 25, 33]
Decoded: CCO
Match: True
Traitement des données SMILES (calcul des catégories)...
(Cela peut prendre du temps sur tout le dataset)


Calcul des catégories:  18%|████████▋                                        | 126453/710754 [01:39<08:02, 1209.81it/s][17:16:12] Conflicting single bond directions around double bond at index 6.
[17:16:12]   BondStereo set to STEREONONE and single bond directions set to NONE.
Calcul des catégories: 100%|█████████████████████████████████████████████████| 710754/710754 [08:53<00:00, 1331.32it/s]



Nombre total de molécules valides chargées : 710737
Sauvegarde des données traitées dans le cache 'data_cache_categorical_10K.pt'...
Données sauvegardées.
Nombre de paramètres : 0.82 M
Aucun checkpoint trouvé. Démarrage d'un nouvel entraînement.
Début de l'entraînement sur cuda...
Étape 500: perte train 1.2402, perte val 1.2450, temps 21.6s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical_10K.pth...
Étape 1000: perte train 1.0597, perte val 1.0693, temps 35.3s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical_10K.pth...
Étape 1500: perte train 0.9770, perte val 0.9801, temps 50.9s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical_10K.pth...
Étape 2000: perte train 0.9352, perte val 0.9344, temps 64.6s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical_10K.pth...
Étape 2500: perte train 0.8930, perte val 0.8953, temps 79.9s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical_10K.pth...
Étape 3000: perte train 0.8607, perte 

In [17]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
SCRIPT COMPLET (A à Z) : TRANSFORMATEUR CONDITIONNEL (Catégoriel) - V2 CORRIGÉ
Avec délimiteurs < et > pour les molécules dans le fichier
"""

# --- PARTIE 1 : IMPORTS ET CONFIGURATION ---

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import math
from dataclasses import dataclass
import time
import os
import json
from tqdm import tqdm
import gc # Garbage collector

try:
    from rdkit import Chem
    from rdkit.Chem import Descriptors
    from rdkit import rdBase
    rdBase.DisableLog('rdApp.error')
except ImportError:
    print("Erreur : RDKit n'est pas installé.")
    print("Veuillez l'installer avec : pip install rdkit")
    exit()

# --- Configuration ---
# Module-level hyper-parameters and file locations read by the rest of the script.
BATCH_SIZE = 32  # number of sequences per batch
BLOCK_SIZE = 128  # padded sequence length; also the model's context size (GPTConfig.block_size)
MAX_ITERS = 5000 # Increase for better training
EVAL_INTERVAL = 500  # evaluation period, in iterations
LEARNING_RATE = 3e-4  # optimizer learning rate
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is available
EVAL_ITERS = 200  # number of batches averaged per split when estimating the loss
N_EMBD = 128  # embedding width (must be divisible by N_HEAD, asserted in CausalSelfAttention)
N_HEAD = 4  # number of attention heads
N_LAYER = 4  # number of stacked transformer blocks
DROPOUT = 0.1  # dropout probability used by the model's nn.Dropout layers
CONDITION_DIM = 3 # 3 categories: LogP, MW, HBD

# Files
DATA_FILE = "s_100_str_+1M_fixed.txt"  # TO BE CHANGED: path of the input dataset

VOCAB_FILE = 'vocab_dataset.json'  # where the stoi/itos mappings are dumped as JSON
DATA_CACHE_FILE = 'data_cache_categorical.pt' # Cache for the processed categorical data

# Checkpoints
CHECKPOINT_DIR = 'checkpoints'
CHECKPOINT_FILE = os.path.join(CHECKPOINT_DIR, 'cond_gpt_categorical.pth')
# Create the checkpoint directory up front; exist_ok avoids an error when it already exists.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Configuration : Périphérique={DEVICE}, Batch={BATCH_SIZE}, Contexte={BLOCK_SIZE}")
# Fixed seed for PyTorch's random number generator.
torch.manual_seed(1337)

# --- PART 2: VOCABULARY CONSTRUCTION ---
# Character-level vocabulary: every distinct character found in the molecules
# of DATA_FILE becomes one token, preceded by three special tokens.

print(f"Construction du vocabulaire à partir de '{DATA_FILE}'...")

# Special tokens
PAD_TOKEN = '<pad>'
START_TOKEN = '<start>'
END_TOKEN = '<end>'

# Collect the characters used by the dataset.
char_set = set()
# The encoding is given explicitly so the result does not depend on the
# platform's locale default; load_and_process_data opens this same file as UTF-8.
with open(DATA_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        smiles = line.strip()
        # Keep only the text between the first '<' and the first '>' when both
        # are present and enclose at least one character.
        start_idx = smiles.find('<') + 1  # 0 when '<' is absent
        end_idx = smiles.find('>')        # -1 when '>' is absent
        if 0 < start_idx < end_idx:
            smiles = smiles[start_idx:end_idx]
        char_set.update(smiles)

# Special tokens come first so that their ids are 0, 1 and 2.
special_tokens = [PAD_TOKEN, START_TOKEN, END_TOKEN]
vocabulary = special_tokens + sorted(char_set)

# Lookup tables used to encode / decode
stoi = {ch: i for i, ch in enumerate(vocabulary)}
itos = {i: ch for i, ch in enumerate(vocabulary)}
vocab_size = len(vocabulary)

# Save as JSON.
# NOTE: JSON object keys are always strings, so the integer keys of `itos`
# are written as strings and must be converted back to int when reloading.
with open(VOCAB_FILE, 'w', encoding='utf-8') as f:
    json.dump({'stoi': stoi, 'itos': itos}, f, indent=2)

print(f"Taille vocabulaire : {vocab_size}")
print("Exemple de mapping :", list(stoi.items())[:10])


# Global encode/decode helpers
def encode(s):
    """Map each character of *s* to its token id; unknown characters are dropped."""
    return [stoi[c] for c in s if c in stoi]


def decode(l):
    """Join the tokens for the ids in *l* into a string; unknown ids are dropped."""
    return ''.join([itos[i] for i in l if i in itos])


# Vocabulary round-trip sanity check
print("\n=== TEST DU VOCABULAIRE ===")
test_smiles = "CCO"
print(f"Test encode/decode: {test_smiles}")
encoded = encode(test_smiles)
print(f"Encoded: {encoded}")
decoded = decode(encoded)
print(f"Decoded: {decoded}")
print(f"Match: {test_smiles == decoded}")

# --- PARTIE 3 : PRÉPARATION DES DONNÉES (Version Catégorielle) ---

# Map a numeric value to the index of the interval it falls into.
def get_category(value, bins):
    """Return, as a float, the index of the first entry of *bins* that is >= *value*.

    *bins* holds inclusive upper bounds and is scanned in order. When *value*
    exceeds every bound (or no comparison succeeds), ``float(len(bins))`` is
    returned, i.e. the index of one extra, open-ended category.
    """
    hits = (float(position) for position, limit in enumerate(bins) if value <= limit)
    return next(hits, float(len(bins)))

# Inclusive upper bounds of the intervals passed to get_category; a value above
# the last bound falls into one extra final category.
# (Adjust these if the data analysis showed different distributions.)
LOGP_BINS = [0.0, 3.0, 5.0] # <=0 -> 0, <=3 -> 1, <=5 -> 2, >5 -> 3
MW_BINS = [250.0, 480.0, 650.0] # <=250 -> 0, <=480 -> 1, <=650 -> 2, >650 -> 3
HBD_BINS = [0.0, 1.0, 2.0, 3.0] # <=0 -> 0, <=1 -> 1, <=2 -> 2, <=3 -> 3, >3 -> 4

def load_and_process_data(filepath, stoi, max_len=BLOCK_SIZE, cache_file=DATA_CACHE_FILE):
    """
    Load SMILES strings, bin LogP / MW / HBD into categories and build tensors.

    Each line of *filepath* holds one molecule, optionally wrapped in ``<`` and
    ``>`` delimiters. A line is skipped when it is empty, longer than
    ``max_len - 2`` characters, has fewer than 80% of its characters in *stoi*,
    cannot be parsed by RDKit, or when descriptor computation fails.

    Args:
        filepath: path of the text file to read (opened as UTF-8).
        stoi: character -> token id mapping; must contain the special tokens.
        max_len: padded length of every returned sequence.
        cache_file: if this file exists it is loaded and returned as-is, without
            checking that it matches *filepath*, *stoi* or *max_len*; otherwise
            the processed data is written to it.

    Returns:
        A list of ``(x, y, condition_vector)`` tuples. ``x`` is
        ``<start> + tokens + <end>`` padded to *max_len*, ``y`` is ``x`` shifted
        left by one position, and ``condition_vector`` is a float32 tensor of
        the three category indices ``[logp, mw, hbd]``.

    Raises:
        SystemExit: when *filepath* does not exist or memory runs out.
    """
    if os.path.exists(cache_file):
        print(f"Chargement des données catégorielles depuis le cache '{cache_file}'...")
        data = torch.load(cache_file)
        print("Données chargées.")
        return data

    print("Traitement des données SMILES (calcul des catégories)...")
    print("(Cela peut prendre du temps sur tout le dataset)")
    data_processed = []
    descriptor_failures = 0
    pad_idx = stoi[PAD_TOKEN]
    start_idx = stoi[START_TOKEN]
    end_idx = stoi[END_TOKEN]

    try:
        # Count the lines first so tqdm can display a total. The `with` block
        # guarantees the handle is closed (the bare open() used before leaked it).
        with open(filepath, 'r', encoding='utf-8') as f:
            num_lines = sum(1 for _ in f)

        # Same explicit encoding as the counting pass, so both passes decode
        # the file identically on every platform.
        with open(filepath, 'r', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f, total=num_lines, desc="Calcul des catégories")):
                # Periodic garbage collection; done first so skipped lines do not bypass it.
                if i % 50000 == 0:
                    gc.collect()

                smiles = line.strip()

                # Keep only the text between the first '<' and the first '>'
                # when both are present and enclose at least one character.
                start_delim = smiles.find('<') + 1  # 0 when '<' is absent
                end_delim = smiles.find('>')        # -1 when '>' is absent
                if 0 < start_delim < end_delim:
                    smiles = smiles[start_delim:end_delim]

                # Remove surrounding whitespace and any inner spaces / tabs.
                smiles = smiles.strip().replace(" ", "").replace("\t", "")
                if not smiles:
                    continue
                # Leave room for the <start> and <end> tokens.
                if len(smiles) > max_len - 2:
                    continue
                # Require at least 80% known characters.
                # NOTE(review): the remaining unknown characters are dropped when
                # encoding, so the token sequence may then differ from the
                # molecule whose descriptors are computed below.
                known_chars = sum(1 for c in smiles if c in stoi)
                if known_chars / len(smiles) < 0.8:
                    continue

                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    continue

                # Only the descriptor calls are guarded; a failure there skips
                # the molecule (best effort) but is counted and reported below.
                try:
                    logp = Descriptors.MolLogP(mol)
                    mw = Descriptors.MolWt(mol)
                    hbd = Descriptors.NumHDonors(mol)
                except Exception:
                    descriptor_failures += 1
                    continue

                condition_vector = torch.tensor(
                    [
                        get_category(logp, LOGP_BINS),
                        get_category(mw, MW_BINS),
                        get_category(hbd, HBD_BINS),
                    ],
                    dtype=torch.float32,
                )

                # Encode with the `stoi` argument rather than a module-level
                # helper, so the mapping passed by the caller is the one used.
                token_ids = [start_idx] + [stoi[c] for c in smiles if c in stoi] + [end_idx]
                seq_len = len(token_ids)
                x = torch.full((max_len,), pad_idx, dtype=torch.long)
                y = torch.full((max_len,), pad_idx, dtype=torch.long)
                x[:seq_len] = torch.tensor(token_ids, dtype=torch.long)
                # Targets are the inputs shifted by one position.
                y[:seq_len - 1] = torch.tensor(token_ids[1:], dtype=torch.long)

                data_processed.append((x, y, condition_vector))

    except FileNotFoundError:
        print(f"ERREUR: Le fichier de données '{filepath}' est introuvable.")
        # Non-zero status so that callers / shells can detect the failure.
        raise SystemExit(1)
    except MemoryError:
        print("\nERREUR: Manque de mémoire pour traiter tout le dataset.")
        print("Vous devrez peut-être utiliser un échantillon plus petit ou augmenter la RAM.")
        raise SystemExit(1)

    print(f"\nNombre total de molécules valides chargées : {len(data_processed)}")
    if descriptor_failures:
        print(f"Molécules ignorées (échec du calcul des descripteurs) : {descriptor_failures}")
    print(f"Sauvegarde des données traitées dans le cache '{cache_file}'...")
    torch.save(data_processed, cache_file)
    print("Données sauvegardées.")

    return data_processed

class SMILESDataset(Dataset):
    """Minimal ``Dataset`` adapter around an indexable collection of samples.

    The wrapped object is stored as-is; length and item access are delegated
    to it unchanged.
    """

    def __init__(self, data):
        # Keep a reference to the underlying collection (no copy is made).
        self.data = data

    def __len__(self):
        """Return the number of samples in the wrapped collection."""
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

# --- PARTIE 4 : ARCHITECTURE DU MODÈLE (BLOCS DE BASE - Inchangé) ---

@dataclass
class GPTConfig:
    """Hyper-parameters consumed by the model classes in this file.

    Every default is taken from a module-level constant at the time the class
    body is evaluated.
    """
    block_size: int = BLOCK_SIZE  # max sequence length; sizes the position embedding and the causal mask
    vocab_size: int = vocab_size # Uses the module-level global built with the vocabulary
    n_layer: int = N_LAYER  # number of transformer blocks stacked in the model
    n_head: int = N_HEAD  # attention heads per block; must divide n_embd
    n_embd: int = N_EMBD  # embedding / hidden width
    dropout: float = DROPOUT  # probability passed to every nn.Dropout
    condition_dim: int = CONDITION_DIM  # input width of the condition projector

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    A single linear layer produces queries, keys and values; each position
    may only attend to itself and to earlier positions.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fused projection for q, k and v, followed by the output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular matrix of ones, stored with two leading singleton
        # dimensions so it broadcasts over batch and heads.
        size = config.block_size
        causal = torch.tril(torch.ones(size, size))
        self.register_buffer("bias", causal.view(1, 1, size, size))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, time, channels)."""
        batch, seq_len, channels = x.size()
        head_dim = channels // self.n_head

        def split_heads(t):
            # (batch, time, channels) -> (batch, heads, time, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = split_heads(k)
        q = split_heads(q)
        v = split_heads(v)

        # Scaled dot-product scores; future positions are masked to -inf so
        # they receive zero weight after the softmax.
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        future = self.bias[:, :, :seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # Merge the heads back into a single channel dimension.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.resid_dropout(self.c_proj(merged))


class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.c_fc    = nn.Linear(width, 4 * width)
        self.c_proj  = nn.Linear(4 * width, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the feed-forward transform of ``x`` (last dimension preserved)."""
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))

class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual add."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run both residual sub-layers on ``x`` and return the result."""
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))


# --- PARTIE 5 : LE TRANSFORMATEUR CONDITIONNEL (Inchangé) ---

class ConditionalDrugGPT(nn.Module):
    """GPT-style language model whose input embeddings are shifted by a condition vector.

    The condition is projected to the embedding width and added to every
    position's token + position embedding before the transformer blocks.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.n_embd

        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, width),
            'wpe': nn.Embedding(config.block_size, width),
            'drop': nn.Dropout(config.dropout),
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(width),
        })

        # Maps the raw condition vector to one embedding-sized vector.
        self.condition_projector = nn.Sequential(
            nn.Linear(config.condition_dim, width),
            nn.ReLU(),
        )

        # Output head shares its weight matrix with the token embedding.
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if is_linear and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None, conditions=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: token ids of shape (batch, time).
            targets: optional target ids; positions equal to the pad index are
                ignored in the loss.
            conditions: condition vectors fed to the condition projector
                (required; asserted below).

        Returns:
            ``(logits, loss)``. With targets, logits cover every position.
            Without targets, only the last position is returned and ``loss``
            is ``None``.
        """
        _, seq_len = idx.shape
        assert seq_len <= self.config.block_size, f"Séquence trop longue: {seq_len}"
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device).unsqueeze(0)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(positions)

        assert conditions is not None, "Les conditions doivent être fournies !"
        cond_emb = self.condition_projector(conditions)

        # The condition embedding is broadcast across the time dimension.
        hidden = self.transformer.drop(tok_emb + pos_emb + cond_emb.unsqueeze(1))
        for layer in self.transformer.h:
            hidden = layer(hidden)
        hidden = self.transformer.ln_f(hidden)

        if targets is None:
            # Inference: only the final position is needed to sample the next token.
            return self.lm_head(hidden[:, [-1], :]), None

        logits = self.lm_head(hidden)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=stoi[PAD_TOKEN])
        return logits, loss


# --- PARTIE 6 : FONCTIONS DE CHECKPOINT ET D'ÉVALUATION (Inchangé) ---

def save_checkpoint(model, optimizer, iter_num, best_val_loss, config, filepath):
    """Write a training checkpoint to ``filepath`` with ``torch.save``.

    The saved dictionary holds the iteration number, the model and optimizer
    state dicts, the best validation loss so far and the config object.
    """
    print(f"Sauvegarde du checkpoint dans {filepath}...")
    state = {
        'iter_num': iter_num,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
        'config': config,
    }
    torch.save(state, filepath)

def load_checkpoint(filepath, model, optimizer):
    """Restore model and optimizer state from a checkpoint, if one exists.

    Args:
        filepath: path of the checkpoint written by ``save_checkpoint``.
        model: module whose ``load_state_dict`` receives the saved weights.
        optimizer: optimizer whose ``load_state_dict`` receives the saved state.

    Returns:
        ``(iter_num, best_val_loss)`` from the checkpoint, or
        ``(0, float('inf'))`` when no file exists at ``filepath``.
    """
    if not os.path.exists(filepath):
        print("Aucun checkpoint trouvé. Démarrage d'un nouvel entraînement.")
        return 0, float('inf')

    print(f"Chargement du checkpoint depuis {filepath}...")
    # The checkpoint stores the config object alongside the tensors, so it is
    # not a weights-only file. Recent PyTorch releases default to
    # weights_only=True, which refuses to unpickle such objects; request full
    # unpickling explicitly.
    # SECURITY: this executes pickle code - only load checkpoints produced by
    # this script, never files from an untrusted source.
    checkpoint = torch.load(filepath, map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
    print(f"Reprise de l'entraînement à l'itération {iter_num} (meilleure perte val: {best_val_loss:.4f})")
    return iter_num, best_val_loss

@torch.no_grad()
def estimate_loss(model, train_loader, val_loader, eval_iters=EVAL_ITERS):
    """Average the model loss over ``eval_iters`` batches of each split.

    The model is switched to eval mode for the measurement and back to train
    mode afterwards. A loader that runs out of batches is restarted.

    Returns:
        dict with keys ``'train'`` and ``'val'``, each the mean loss as a
        0-dim tensor.
    """
    model.eval()
    results = {}
    for split, loader in (('train', train_loader), ('val', val_loader)):
        split_losses = torch.zeros(eval_iters)
        batches = iter(loader)
        for step in range(eval_iters):
            try:
                batch = next(batches)
            except StopIteration:
                # Loader exhausted: start a fresh pass over it.
                batches = iter(loader)
                batch = next(batches)

            x, y, c = (part.to(DEVICE) for part in batch)
            _, loss = model(x, y, c)
            split_losses[step] = loss.item()
        results[split] = split_losses.mean()
    model.train()
    return results

# --- FONCTION DE GÉNÉRATION CORRIGÉE ---
@torch.no_grad()
def generate_conditional(model, condition_tensor, stoi, itos, max_new_tokens=100, temperature=1.0, top_k=None):
    """Sample one SMILES string from the model for a categorical condition.

    Generation starts from the ``<start>`` token and stops when ``<end>`` is
    sampled or after ``max_new_tokens`` steps. Neither special token appears
    in the returned string.

    Args:
        model: the conditional language model; called as
            ``model(idx, conditions=...)``.
        condition_tensor: condition vector(s) passed to the model unchanged
            (moved to ``DEVICE`` first).
        stoi: token-to-index mapping, used to look up the special tokens.
        itos: unused here; kept so existing call sites keep working.
        max_new_tokens: maximum number of sampling steps.
        temperature: divisor applied to the logits before the softmax.
        top_k: when set, only the ``top_k`` most likely tokens can be sampled.

    Returns:
        The decoded SMILES string (possibly empty).
    """
    start_idx = stoi[START_TOKEN]
    end_idx = stoi[END_TOKEN]

    # Remember the caller's mode so that it can be restored afterwards. The
    # previous version always left the model in train mode, which silently
    # re-enabled dropout for callers that were evaluating.
    was_training = model.training
    model.eval()
    try:
        idx = torch.tensor([[start_idx]], dtype=torch.long, device=DEVICE)
        condition_tensor = condition_tensor.to(DEVICE)

        for _ in range(max_new_tokens):
            # Crop the context to the last BLOCK_SIZE tokens.
            idx_cond = idx if idx.size(1) <= BLOCK_SIZE else idx[:, -BLOCK_SIZE:]

            # The categorical condition tensor is passed through as-is.
            logits, _ = model(idx_cond, conditions=condition_tensor)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # <end> terminates generation and is never appended to idx.
            if idx_next.item() == end_idx:
                break

            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        # Restore the original mode even if sampling raised.
        model.train(was_training)

    # Position 0 is the <start> token; everything after it is the molecule.
    return decode(idx[0, 1:].tolist())

# --- FONCTION DE VÉRIFICATION DES PROPRIÉTÉS ---
def check_mol_3_props(smiles):
    """Parse ``smiles`` and return ``(status, logp, mw, hbd)``.

    ``status`` is ``"Vide"`` for an empty input, ``"Invalide"`` when RDKit
    cannot parse the string, and ``"Valide"`` otherwise. The three properties
    are zero unless the status is ``"Valide"``.
    """
    # An empty input is never handed to the parser.
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    if mol is None:
        status = "Invalide" if smiles else "Vide"
        return status, 0.0, 0.0, 0

    logp = Descriptors.MolLogP(mol)
    mw = Descriptors.MolWt(mol)
    hbd = Descriptors.NumHDonors(mol)
    return "Valide", logp, mw, hbd

# --- PARTIE 7 : SCRIPT PRINCIPAL D'EXÉCUTION ---

if __name__ == "__main__":

    # 1. Vocabulary: stoi, itos and vocab_size are already module-level globals.

    # 2. Data (loads the cache, or processes the categorical data from scratch)
    full_data = load_and_process_data(DATA_FILE, stoi, cache_file=DATA_CACHE_FILE)

    # 3. DataLoaders over a random 90% / 10% train / validation split
    train_size = int(0.9 * len(full_data))
    val_size = len(full_data) - train_size
    train_data, val_data = torch.utils.data.random_split(full_data, [train_size, val_size])

    train_loader = DataLoader(SMILESDataset(train_data), batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(SMILESDataset(val_data), batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

    # 4. Model and optimizer
    config = GPTConfig(vocab_size=vocab_size)
    model = ConditionalDrugGPT(config)
    model.to(DEVICE)

    print(f"Nombre de paramètres : {sum(p.numel() for p in model.parameters())/1e6:.2f} M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # 5. Resume from a checkpoint when one exists
    start_iter, best_val_loss = load_checkpoint(CHECKPOINT_FILE, model, optimizer)

    # 6. Training loop
    print(f"Début de l'entraînement sur {DEVICE}...")
    start_time = time.time()
    train_iter = iter(train_loader)

    for iter_num in range(start_iter, MAX_ITERS):

        # Periodic evaluation; a checkpoint is written only when the
        # validation loss improves.
        if iter_num > 0 and (iter_num % EVAL_INTERVAL == 0 or iter_num == MAX_ITERS - 1):
            losses = estimate_loss(model, train_loader, val_loader)
            elapsed = time.time() - start_time
            print(f"Étape {iter_num}: perte train {losses['train']:.4f}, perte val {losses['val']:.4f}, temps {elapsed:.1f}s")

            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                save_checkpoint(model, optimizer, iter_num, best_val_loss, config, CHECKPOINT_FILE)

        # Fetch the next batch, restarting the loader at the end of an epoch.
        try:
            xb, yb, cb = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            xb, yb, cb = next(train_iter)

        xb, yb, cb = xb.to(DEVICE), yb.to(DEVICE), cb.to(DEVICE)

        logits, loss = model(xb, targets=yb, conditions=cb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Periodic garbage collection during training
        if iter_num % 1000 == 0:
            gc.collect()

    print("Entraînement terminé !")

    # 7. Generation (conditions are category indices)
    print("\n--- Génération de Molécules Conditionnelles (Catégorielles) ---")

    # Single-sample smoke test before the larger loops
    print("\n=== TEST DE GÉNÉRATION SIMPLE ===")
    target_cats_test = [1.0, 1.0, 3.0]
    condition_tensor_test = torch.tensor(target_cats_test, dtype=torch.float32).unsqueeze(0)

    test_smiles = generate_conditional(model, condition_tensor_test, stoi, itos, max_new_tokens=50, top_k=10)
    valid, logp, mw, hbd = check_mol_3_props(test_smiles)
    print(f"Test génération: '{test_smiles}'")
    print(f"  Statut: {valid}, LogP: {logp:.1f}, MW: {mw:.0f}, HBD: {hbd}")

    # Condition 1: aim at the "Rule of 3" region (values are category indices)
    target_cats_1 = [1.0, 1.0, 3.0]
    condition_tensor_1 = torch.tensor(target_cats_1, dtype=torch.float32).unsqueeze(0)

    # Condition 2: very lipophilic, large molecules with few H-bond donors
    target_cats_2 = [3.0, 3.0, 0.0]
    condition_tensor_2 = torch.tensor(target_cats_2, dtype=torch.float32).unsqueeze(0)

    # Number of molecules sampled per condition. The summary line below uses
    # this value; it used to print a hard-coded "/10" that did not match the
    # 120 molecules actually generated.
    num_demo_samples = 120
    demo_runs = [
        (condition_tensor_1, "\nGénération pour Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)"),
        (condition_tensor_2, "\nGénération pour Catégories [LogP(3), MW(3), HBD(0)] (Très lipophile/grand)"),
    ]

    for demo_condition, demo_header in demo_runs:
        print(demo_header)
        valid_count = 0
        for i in range(num_demo_samples):
            mol_str = generate_conditional(model, demo_condition, stoi, itos, max_new_tokens=50, top_k=10)
            valid, logp, mw, hbd = check_mol_3_props(mol_str)
            print(f"  {i+1}. -> '{mol_str}'")
            print(f"     (Valide: {valid}, LogP: {logp:.1f}, MW: {mw:.0f}, HBD: {hbd})")

            if valid == "Valide":
                valid_count += 1
                # Report the categories the generated molecule actually falls in.
                logp_c = get_category(logp, LOGP_BINS)
                mw_c = get_category(mw, MW_BINS)
                hbd_c = get_category(hbd, HBD_BINS)
                print(f"     Catégories réelles: [LogP({logp_c:.0f}), MW({mw_c:.0f}), HBD({hbd_c:.0f})]")

        print(f"Molécules valides générées: {valid_count}/{num_demo_samples}")

Configuration : Périphérique=cuda, Batch=32, Contexte=128
Construction du vocabulaire à partir de 's_100_str_+1M_fixed.txt'...
Taille vocabulaire : 57
Exemple de mapping : [('<pad>', 0), ('<start>', 1), ('<end>', 2), ('#', 3), ('%', 4), ('(', 5), (')', 6), ('+', 7), ('-', 8), ('.', 9)]

=== TEST DU VOCABULAIRE ===
Test encode/decode: CCO
Encoded: [25, 25, 33]
Decoded: CCO
Match: True
Traitement des données SMILES (calcul des catégories)...
(Cela peut prendre du temps sur tout le dataset)


Calcul des catégories:  18%|████████▋                                        | 126404/710754 [01:34<07:15, 1340.64it/s][15:49:45] Conflicting single bond directions around double bond at index 6.
[15:49:45]   BondStereo set to STEREONONE and single bond directions set to NONE.
Calcul des catégories: 100%|█████████████████████████████████████████████████| 710754/710754 [08:55<00:00, 1326.68it/s]



Nombre total de molécules valides chargées : 710737
Sauvegarde des données traitées dans le cache 'data_cache_categorical.pt'...
Données sauvegardées.
Nombre de paramètres : 0.82 M
Aucun checkpoint trouvé. Démarrage d'un nouvel entraînement.
Début de l'entraînement sur cuda...
Étape 500: perte train 1.1882, perte val 1.1921, temps 16.4s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical.pth...
Étape 1000: perte train 1.0067, perte val 1.0165, temps 30.3s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical.pth...
Étape 1500: perte train 0.9521, perte val 0.9553, temps 45.5s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical.pth...
Étape 2000: perte train 0.9092, perte val 0.9082, temps 59.7s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical.pth...
Étape 2500: perte train 0.8735, perte val 0.8757, temps 74.8s
Sauvegarde du checkpoint dans checkpoints\cond_gpt_categorical.pth...
Étape 3000: perte train 0.8374, perte val 0.8419, temps 88.7s


In [18]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
SCRIPT : GÉNÉRATION CONDITIONNELLE avec Métriques Standardisées
Métriques: Validity, Novelty, Uniqueness, Internal Diversity, Desirability
"""

# --- IMPORTS SUPPLÉMENTAIRES ---
import numpy as np
from collections import defaultdict
from rdkit.Chem import AllChem
from rdkit import DataStructs

# --- FONCTIONS DE MÉTRIQUES STANDARDISÉES ---

def calculate_validity(smiles_list):
    """(1) Validity: percentage of generated SMILES that RDKit can parse.

    Returns 0.0 for an empty input.
    """
    if not smiles_list:
        return 0.0
    n_valid = sum(1 for candidate in smiles_list if Chem.MolFromSmiles(candidate) is not None)
    return (n_valid / len(smiles_list)) * 100

def calculate_novelty(valid_smiles_list, training_smiles_set):
    """(2) Novelty: percentage of valid SMILES absent from the training set.

    Membership is an exact string comparison. Returns 0.0 for an empty input.
    """
    if not valid_smiles_list:
        return 0.0
    n_novel = sum(1 for candidate in valid_smiles_list if candidate not in training_smiles_set)
    return (n_novel / len(valid_smiles_list)) * 100

def calculate_uniqueness(novel_smiles_list):
    """(3) Uniqueness: percentage of distinct strings among the novel SMILES.

    Returns 0.0 for an empty input.
    """
    if not novel_smiles_list:
        return 0.0
    n_distinct = len(set(novel_smiles_list))
    return (n_distinct / len(novel_smiles_list)) * 100

def calculate_internal_diversity(unique_smiles_list):
    """(4) Internal diversity: 1 minus the mean pairwise Tanimoto similarity.

    Similarities are computed on Morgan bit-vector fingerprints (radius 2,
    1024 bits). SMILES that RDKit cannot parse are skipped.

    Returns:
        0.0 when fewer than two molecules (or fewer than two parseable ones)
        are available; otherwise a value between 0 and 1, where higher means
        more diverse.
    """
    if len(unique_smiles_list) < 2:
        return 0.0

    # Generate Morgan fingerprints for every parseable molecule.
    fingerprints = []
    for smiles in unique_smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            fingerprints.append(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024))

    if len(fingerprints) < 2:
        return 0.0

    # One bulk call per fingerprint against all later ones. This yields the
    # same upper-triangle similarities, in the same order, as a per-pair
    # loop, without a Python-level call for every pair.
    similarities = []
    for i in range(len(fingerprints) - 1):
        similarities.extend(
            DataStructs.BulkTanimotoSimilarity(fingerprints[i], fingerprints[i + 1:])
        )

    # Internal diversity = 1 - average similarity
    return 1.0 - np.mean(similarities)

def calculate_desirability(unique_smiles_list, target_categories):
    """(5) Desirability: share of molecules whose three categories hit the target.

    For each parseable SMILES the LogP, molecular weight and H-bond-donor
    count are binned with ``get_category`` and compared, in that order, with
    ``target_categories[0..2]``.

    Returns:
        ``(percentage, count)``; ``(0.0, 0)`` for an empty input. The
        percentage is relative to the full input length.
    """
    if not unique_smiles_list:
        return 0.0, 0

    hits = 0
    for candidate in unique_smiles_list:
        mol = Chem.MolFromSmiles(candidate)
        if not mol:
            continue

        properties = (
            Descriptors.MolLogP(mol),
            Descriptors.MolWt(mol),
            Descriptors.NumHDonors(mol),
        )
        categories = [
            get_category(value, bins)
            for value, bins in zip(properties, (LOGP_BINS, MW_BINS, HBD_BINS))
        ]

        # Every category must equal the corresponding target entry.
        if all(category == target_categories[pos] for pos, category in enumerate(categories)):
            hits += 1

    return (hits / len(unique_smiles_list)) * 100, hits

def compute_all_metrics(generated_smiles_list, training_smiles_set, target_categories=None):
    """Run the five benchmark metrics as a funnel and return them in one dict.

    Each stage filters the list passed to the next one: generated -> valid ->
    novel -> unique. Desirability is computed only when ``target_categories``
    is given; otherwise it is reported as 0.0 with a count of 0.

    Returns:
        dict with the five metric values, the size of every funnel stage, and
        the unique molecules themselves under ``'unique_molecules'``.
    """
    # Stage 1: validity
    validity = calculate_validity(generated_smiles_list)
    valid = [smi for smi in generated_smiles_list if Chem.MolFromSmiles(smi) is not None]

    # Stage 2: novelty among the valid molecules
    novelty = calculate_novelty(valid, training_smiles_set)
    novel = [smi for smi in valid if smi not in training_smiles_set]

    # Stage 3: uniqueness among the novel molecules
    uniqueness = calculate_uniqueness(novel)
    unique = list(set(novel))

    # Stage 4: internal diversity of the unique molecules
    diversity = calculate_internal_diversity(unique)

    # Stage 5: desirability (only meaningful for a conditioned model)
    if target_categories is None:
        desirability, n_desirable = 0.0, 0
    else:
        desirability, n_desirable = calculate_desirability(unique, target_categories)

    return {
        'validity': validity,
        'novelty': novelty,
        'uniqueness': uniqueness,
        'internal_diversity': diversity,
        'desirability': desirability,
        'desirable_count': n_desirable,
        'total_generated': len(generated_smiles_list),
        'valid_count': len(valid),
        'novel_count': len(novel),
        'unique_count': len(unique),
        'unique_molecules': unique,  # the actual unique molecules are kept too
    }

def print_metrics_table(metrics, condition_name):
    """Print the benchmark metrics for one condition as a fixed-width table.

    The Desirability row is printed only when its value is greater than zero.
    """
    rule = '=' * 80

    def emit(label, value, spec, details):
        # One table row: 15-wide label, formatted value, free-text details.
        print(f"{label:<15} {value:{spec}} {details}")

    print(f"\n{rule}")
    print(f"📊 BENCHMARK RESULTS - {condition_name}")
    print(rule)

    print(f"{'Metric':<15} {'Value':<10} {'Details':<30}")
    print('-' * 55)
    emit('Validity', metrics['validity'], '<10.1f',
         f"{metrics['valid_count']}/{metrics['total_generated']} valid")
    emit('Novelty', metrics['novelty'], '<10.1f',
         f"{metrics['novel_count']}/{metrics['valid_count']} novel")
    emit('Uniqueness', metrics['uniqueness'], '<10.1f',
         f"{metrics['unique_count']}/{metrics['novel_count']} unique")
    emit('Int Diversity', metrics['internal_diversity'], '<10.3f', "Tanimoto-based (0-1)")

    if metrics['desirability'] > 0:
        emit('Desirability', metrics['desirability'], '<10.1f',
             f"{metrics['desirable_count']}/{metrics['unique_count']} desirable")

def generate_with_metrics(model, condition_tensor, stoi, itos, training_smiles_set,
                         num_molecules=1000, target_categories=None, **kwargs):
    """Generate ``num_molecules`` SMILES and compute the benchmark metrics.

    Extra keyword arguments are forwarded to ``generate_conditional``.
    Progress is printed every 100 molecules.

    Returns:
        The metrics dict produced by ``compute_all_metrics``.
    """
    print(f"Génération de {num_molecules} molécules...")

    samples = []
    for done in range(1, num_molecules + 1):
        samples.append(generate_conditional(model, condition_tensor, stoi, itos, **kwargs))

        # Progress report
        if done % 100 == 0:
            print(f"  Progression: {done}/{num_molecules} générées")

    # Standardised metrics over everything that was generated
    return compute_all_metrics(samples, training_smiles_set, target_categories)

def analyze_property_distribution(unique_molecules, target_categories):
    """Print summary statistics of LogP, MW and HBD for the given molecules.

    Unparseable SMILES are skipped. Nothing is printed when the input is
    empty or when no molecule could be parsed. The last printed line counts
    the molecules whose three categories equal ``target_categories``.
    """
    if not unique_molecules:
        return

    logp_values, mw_values, hbd_values = [], [], []
    category_matches = []

    for candidate in unique_molecules:
        mol = Chem.MolFromSmiles(candidate)
        if not mol:
            continue

        logp = Descriptors.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
        hbd = Descriptors.NumHDonors(mol)
        logp_values.append(logp)
        mw_values.append(mw)
        hbd_values.append(hbd)

        logp_cat = get_category(logp, LOGP_BINS)
        mw_cat = get_category(mw, MW_BINS)
        hbd_cat = get_category(hbd, HBD_BINS)
        category_matches.append(
            logp_cat == target_categories[0]
            and mw_cat == target_categories[1]
            and hbd_cat == target_categories[2]
        )

    if not logp_values:
        return

    print(f"\n📋 Distribution des Propriétés (molécules uniques):")
    print(f"   • LogP: {np.mean(logp_values):.1f} ± {np.std(logp_values):.1f} (min: {np.min(logp_values):.1f}, max: {np.max(logp_values):.1f})")
    print(f"   • MW: {np.mean(mw_values):.0f} ± {np.std(mw_values):.0f} (min: {np.min(mw_values):.0f}, max: {np.max(mw_values):.0f})")
    print(f"   • HBD: {np.mean(hbd_values):.1f} ± {np.std(hbd_values):.1f} (min: {np.min(hbd_values):.0f}, max: {np.max(hbd_values):.0f})")
    print(f"   • Molécules dans la cible: {sum(category_matches)}/{len(category_matches)}")

# --- CHARGEMENT DES DONNÉES D'ENTRAÎNEMENT ---

def load_training_smiles(filepath, sample_size=100000):
    """Load training SMILES for the novelty check.

    Only the first ``sample_size`` lines of the file are read. A molecule
    wrapped in ``<`` ... ``>`` delimiters is unwrapped so that it can match a
    generated SMILES string, which never carries delimiters. Blank lines are
    ignored.

    Args:
        filepath: text file with one molecule per line.
        sample_size: maximum number of lines to read.

    Returns:
        A set of SMILES strings.
    """
    print("Chargement des SMILES d'entraînement pour le calcul de nouveauté...")
    training_smiles = set()

    with open(filepath, 'r') as f:
        for i, line in enumerate(f):
            if i >= sample_size:
                break
            smiles = line.strip()

            # Unwrap "<...>" when both delimiters are present; a raw
            # "<CCO>" entry could never equal a generated "CCO".
            start = smiles.find('<')
            end = smiles.rfind('>')
            if start != -1 and end > start:
                smiles = smiles[start + 1:end].strip()

            if smiles:
                training_smiles.add(smiles)

    print(f"SMILES d'entraînement chargés: {len(training_smiles)}")
    return training_smiles

# --- MODIFICATION DE LA PARTIE GÉNÉRATION PRINCIPALE ---

# Benchmark driver: generates molecules for two target conditions, prints the
# metric table and property distribution for each, then a side-by-side summary.
# It relies on `model`, `stoi` and `itos` already existing at module level.
print("\n" + "="*80)
print("🚀 GÉNÉRATION AVEC MÉTRIQUES STANDARDISÉES")
print("="*80)

# Load training SMILES for the novelty computation (capped by sample_size,
# so novelty is measured against that subset of the file).
training_smiles_set = load_training_smiles('s_100_str_+1M_fixed.txt', sample_size=100000)

# Condition 1: aim at the "Rule of 3" region
target_cats_1 = [1.0, 1.0, 3.0]
condition_tensor_1 = torch.tensor(target_cats_1, dtype=torch.float32).unsqueeze(0)

print(f"\n🎯 Condition 1: Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)")

# Generation with standardised metrics
metrics_1 = generate_with_metrics(
    model, condition_tensor_1, stoi, itos, training_smiles_set,
    num_molecules=1000,  # generate 1000 molecules for more robust statistics
    max_new_tokens=50,
    top_k=10,
    temperature=0.8,
    target_categories=target_cats_1
)

# Display the results
print_metrics_table(metrics_1, "Rule of 3")
analyze_property_distribution(metrics_1['unique_molecules'], target_cats_1)

# Condition 2: lipophilic / large molecules
# NOTE: the sampling settings below (max_new_tokens, top_k, temperature)
# differ from condition 1.
target_cats_2 = [3.0, 3.0, 0.0]
condition_tensor_2 = torch.tensor(target_cats_2, dtype=torch.float32).unsqueeze(0)

print(f"\n🎯 Condition 2: Catégories [LogP(3), MW(3), HBD(0)] (Très lipophile/grand)")

metrics_2 = generate_with_metrics(
    model, condition_tensor_2, stoi, itos, training_smiles_set,
    num_molecules=1000,
    max_new_tokens=60,
    top_k=5,
    temperature=0.7,
    target_categories=target_cats_2
)

print_metrics_table(metrics_2, "Lipophile/Grand")
analyze_property_distribution(metrics_2['unique_molecules'], target_cats_2)

# --- FINAL COMPARISON ---
print(f"\n{'='*80}")
print(f"📈 COMPARAISON FINALE DES CONDITIONS")
print(f"{'='*80}")

# (label, metrics dict) pairs, one summary row each
conditions = [
    ("Rule of 3", metrics_1),
    ("Lipophile/Grand", metrics_2)
]

print(f"\n{'Condition':<20} {'Validity':<10} {'Novelty':<10} {'Uniqueness':<12} {'IntDiv':<10} {'Desirability':<12}")
print(f"{'-'*80}")

for name, metrics in conditions:
    print(f"{name:<20} {metrics['validity']:<10.1f} {metrics['novelty']:<10.1f} {metrics['uniqueness']:<12.1f} {metrics['internal_diversity']:<10.3f} {metrics['desirability']:<12.1f}")

# --- SAUVEGARDE DES RÉSULTATS ---
def save_benchmark_results(metrics, filename):
    """Write the benchmark metrics to ``filename`` as indented UTF-8 JSON.

    The output holds the full ``metrics`` dict plus, under
    ``'unique_molecules_sample'``, the first 20 entries of
    ``metrics['unique_molecules']`` (an empty list when that key is absent).
    """
    payload = {'metrics': metrics}
    # A short sample of the molecules is stored next to the full metrics.
    payload['unique_molecules_sample'] = metrics.get('unique_molecules', [])[:20]

    with open(filename, 'w', encoding='utf-8') as out_file:
        json.dump(payload, out_file, indent=2, ensure_ascii=False)

    print(f"\n💾 Résultats sauvegardés dans: {filename}")

# Save the results of both conditions to JSON, then print a short recap
# (unique-molecule count and desirability percentage per condition).
save_benchmark_results(metrics_1, 'benchmark_results_rule_of_3.json')
save_benchmark_results(metrics_2, 'benchmark_results_lipophilic.json')

print(f"\n✅ Benchmark terminé avec analyses complètes!")
print(f"📊 Résumé:")
print(f"   • Rule of 3: {metrics_1['unique_count']} molécules uniques, {metrics_1['desirability']:.1f}% désirables")
print(f"   • Lipophile/Grand: {metrics_2['unique_count']} molécules uniques, {metrics_2['desirability']:.1f}% désirables")


🚀 GÉNÉRATION AVEC MÉTRIQUES STANDARDISÉES
Chargement des SMILES d'entraînement pour le calcul de nouveauté...
SMILES d'entraînement chargés: 99989

🎯 Condition 1: Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)
Génération de 1000 molécules...
  Progression: 100/1000 générées
  Progression: 200/1000 générées
  Progression: 300/1000 générées
  Progression: 400/1000 générées
  Progression: 500/1000 générées
  Progression: 600/1000 générées
  Progression: 700/1000 générées
  Progression: 800/1000 générées
  Progression: 900/1000 générées
  Progression: 1000/1000 générées





📊 BENCHMARK RESULTS - Rule of 3
Metric          Value      Details                       
-------------------------------------------------------
Validity        57.6       576/1000 valid
Novelty         98.1       565/576 novel
Uniqueness      99.8       564/565 unique
Int Diversity   0.849      Tanimoto-based (0-1)
Desirability    24.3       137/564 desirable

📋 Distribution des Propriétés (molécules uniques):
   • LogP: 2.3 ± 1.3 (min: -0.9, max: 7.3)
   • MW: 319 ± 50 (min: 179, max: 448)
   • HBD: 2.6 ± 1.1 (min: 0, max: 6)
   • Molécules dans la cible: 137/564

🎯 Condition 2: Catégories [LogP(3), MW(3), HBD(0)] (Très lipophile/grand)
Génération de 1000 molécules...
  Progression: 100/1000 générées
  Progression: 200/1000 générées
  Progression: 300/1000 générées
  Progression: 400/1000 générées
  Progression: 500/1000 générées
  Progression: 600/1000 générées
  Progression: 700/1000 générées
  Progression: 800/1000 générées
  Progression: 900/1000 générées
  Progression: 1000/10




📊 BENCHMARK RESULTS - Lipophile/Grand
Metric          Value      Details                       
-------------------------------------------------------
Validity        12.5       125/1000 valid
Novelty         100.0      125/125 novel
Uniqueness      100.0      125/125 unique
Int Diversity   0.833      Tanimoto-based (0-1)
Desirability    20.8       26/125 desirable

📋 Distribution des Propriétés (molécules uniques):
   • LogP: 10.0 ± 4.9 (min: 2.1, max: 19.1)
   • MW: 596 ± 109 (min: 403, max: 825)
   • HBD: 0.3 ± 0.5 (min: 0, max: 2)
   • Molécules dans la cible: 26/125

📈 COMPARAISON FINALE DES CONDITIONS

Condition            Validity   Novelty    Uniqueness   IntDiv     Desirability
--------------------------------------------------------------------------------
Rule of 3            57.6       98.1       99.8         0.849      24.3        
Lipophile/Grand      12.5       100.0      100.0        0.833      20.8        

💾 Résultats sauvegardés dans: benchmark_results_rule_of_3.js

In [16]:
import os
# Delete the processed-data cache file if it exists; the message printed below
# states that it will be regenerated on the next run.
if os.path.exists(DATA_CACHE_FILE):
    os.remove(DATA_CACHE_FILE)
    print("🧹 Cache supprimé — il sera régénéré au prochain lancement.")

🧹 Cache supprimé — il sera régénéré au prochain lancement.


In [19]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
SCRIPT : GÉNÉRATION CONDITIONNELLE avec Métriques Standardisées
Métriques: Validity, Novelty, Uniqueness, Internal Diversity, Desirability
"""

# --- IMPORTS SUPPLÉMENTAIRES ---
import numpy as np
from collections import defaultdict
from rdkit.Chem import AllChem
from rdkit import DataStructs

# --- FONCTIONS DE MÉTRIQUES STANDARDISÉES ---

def calculate_validity(smiles_list):
    """(1) Validity: percentage of generated SMILES that RDKit can parse.

    Returns 0.0 for an empty input.
    """
    if not smiles_list:
        return 0.0
    n_valid = sum(1 for candidate in smiles_list if Chem.MolFromSmiles(candidate) is not None)
    return (n_valid / len(smiles_list)) * 100

def calculate_novelty(valid_smiles_list, training_smiles_set):
    """(2) Novelty: percentage of valid SMILES absent from the training set.

    Membership is an exact string comparison. Returns 0.0 for an empty input.
    """
    if not valid_smiles_list:
        return 0.0
    n_novel = sum(1 for candidate in valid_smiles_list if candidate not in training_smiles_set)
    return (n_novel / len(valid_smiles_list)) * 100

def calculate_uniqueness(novel_smiles_list):
    """(3) Uniqueness: percentage (0-100) of distinct strings among the novel molecules.

    Duplicates are detected by exact string equality. Returns 0.0 for an
    empty input.
    """
    if not novel_smiles_list:
        return 0.0
    distinct_total = len(frozenset(novel_smiles_list))
    return (distinct_total / len(novel_smiles_list)) * 100

def calculate_internal_diversity(unique_smiles_list):
    """(4) Internal Diversity: 1 - mean pairwise Tanimoto similarity (higher = more diverse).

    Every SMILES that RDKit can parse is converted to a Morgan bit-vector
    fingerprint (radius 2, 1024 bits). The result is one minus the average
    Tanimoto similarity over all unordered pairs of fingerprints.

    Returns:
        0.0 when fewer than two molecules can be fingerprinted, otherwise the
        diversity score.
    """
    if len(unique_smiles_list) < 2:
        return 0.0

    # Generate Morgan fingerprints, skipping strings RDKit cannot parse.
    fingerprints = []
    for smiles in unique_smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            fingerprints.append(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024))

    if len(fingerprints) < 2:
        return 0.0

    # Compare each fingerprint with all later ones in one bulk call per row:
    # same pairs and same order as a nested (i, j > i) loop, but without one
    # Python-level call per pair. With >= 2 fingerprints the list is never
    # empty, so no extra emptiness check is needed.
    similarities = []
    for i in range(len(fingerprints) - 1):
        similarities.extend(
            DataStructs.BulkTanimotoSimilarity(fingerprints[i], fingerprints[i + 1:])
        )

    # Internal diversity = 1 - average similarity
    return 1.0 - np.mean(similarities)

def calculate_desirability(unique_smiles_list, target_categories):
    """(5) Desirability: share of the given molecules that fall in the target bins.

    For every SMILES that RDKit can parse, LogP, molecular weight and H-bond
    donor count are binned with ``get_category`` and compared, in that order,
    with ``target_categories[0]``, ``[1]`` and ``[2]``.

    Returns:
        (percentage, count): the percentage is relative to the full input
        length (unparsable entries stay in the denominator); ``(0.0, 0)``
        for an empty input.
    """
    if not unique_smiles_list:
        return 0.0, 0

    hits = 0
    for candidate in unique_smiles_list:
        mol = Chem.MolFromSmiles(candidate)
        if not mol:
            continue

        logp = Descriptors.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
        hbd = Descriptors.NumHDonors(mol)
        observed = (
            get_category(logp, LOGP_BINS),
            get_category(mw, MW_BINS),
            get_category(hbd, HBD_BINS),
        )

        # Desirable only when all three bins equal the requested ones.
        if all(observed[k] == target_categories[k] for k in range(3)):
            hits += 1

    return (hits / len(unique_smiles_list)) * 100, hits

def compute_all_metrics(generated_smiles_list, training_smiles_set, target_categories=None):
    """Compute the five benchmark metrics as a cascade of filters.

    Each stage works on the survivors of the previous one:
    generated -> valid -> novel -> unique, then internal diversity and
    (optionally) desirability are computed on the unique molecules.

    Args:
        generated_smiles_list: All generated SMILES strings.
        training_smiles_set: Training SMILES used for the novelty check.
        target_categories: Target [LogP, MW, HBD] categories; when ``None``
            desirability is reported as 0.0 / 0.

    Returns:
        dict with the percentages ('validity', 'novelty', 'uniqueness',
        'desirability'), 'internal_diversity', the stage counts
        ('total_generated', 'valid_count', 'novel_count', 'unique_count',
        'desirable_count') and 'unique_molecules' (list of unique SMILES in
        first-generated order).
    """
    # 1. Validity
    validity_percent = calculate_validity(generated_smiles_list)
    valid_smiles = [s for s in generated_smiles_list if Chem.MolFromSmiles(s) is not None]

    # 2. Novelty
    novelty_percent = calculate_novelty(valid_smiles, training_smiles_set)
    novel_smiles = [s for s in valid_smiles if s not in training_smiles_set]

    # 3. Uniqueness
    uniqueness_percent = calculate_uniqueness(novel_smiles)
    # De-duplicate while keeping first-seen order. list(set(...)) would give an
    # order that changes between interpreter sessions (string hashing is
    # randomized), making the saved sample and printed examples irreproducible.
    unique_smiles = list(dict.fromkeys(novel_smiles))

    # 4. Internal Diversity
    internal_diversity = calculate_internal_diversity(unique_smiles)

    # 5. Desirability (only for biased models)
    desirability_percent = 0.0
    desirable_count = 0
    if target_categories is not None:
        desirability_percent, desirable_count = calculate_desirability(unique_smiles, target_categories)

    metrics = {
        'validity': validity_percent,
        'novelty': novelty_percent,
        'uniqueness': uniqueness_percent,
        'internal_diversity': internal_diversity,
        'desirability': desirability_percent,
        'desirable_count': desirable_count,
        'total_generated': len(generated_smiles_list),
        'valid_count': len(valid_smiles),
        'novel_count': len(novel_smiles),
        'unique_count': len(unique_smiles),
        'unique_molecules': unique_smiles  # Keep the actual unique molecules
    }

    return metrics

def print_metrics_table(metrics, condition_name):
    """Print one condition's benchmark metrics as a fixed-width text table.

    Expects the keys produced by ``compute_all_metrics``. The Desirability
    row is printed only when ``metrics['desirability']`` is greater than 0.
    """
    banner = '=' * 80

    def emit(label, value_text, detail):
        # One table row: label padded to 15 columns, then value, then details.
        print(f"{label:<15} {value_text} {detail}")

    print(f"\n{banner}")
    print(f"📊 BENCHMARK RESULTS - {condition_name}")
    print(banner)

    print(f"{'Metric':<15} {'Value':<10} {'Details':<30}")
    print('-' * 55)

    emit('Validity', f"{metrics['validity']:<10.1f}",
         f"{metrics['valid_count']}/{metrics['total_generated']} valid")
    emit('Novelty', f"{metrics['novelty']:<10.1f}",
         f"{metrics['novel_count']}/{metrics['valid_count']} novel")
    emit('Uniqueness', f"{metrics['uniqueness']:<10.1f}",
         f"{metrics['unique_count']}/{metrics['novel_count']} unique")
    emit('Int Diversity', f"{metrics['internal_diversity']:<10.3f}",
         "Tanimoto-based (0-1)")

    if metrics['desirability'] > 0:
        emit('Desirability', f"{metrics['desirability']:<10.1f}",
             f"{metrics['desirable_count']}/{metrics['unique_count']} desirable")

def generate_with_metrics(model, condition_tensor, stoi, itos, training_smiles_set,
                         num_molecules=1000, target_categories=None, **kwargs):
    """Generate ``num_molecules`` SMILES under one condition and score them.

    Each molecule comes from one ``generate_conditional`` call; any extra
    keyword arguments are forwarded to it unchanged. Progress is printed
    every 100 molecules.

    Returns:
        The metrics dict produced by ``compute_all_metrics``.
    """
    print(f"Génération de {num_molecules} molécules...")

    samples = []
    for done in range(1, num_molecules + 1):
        samples.append(generate_conditional(model, condition_tensor, stoi, itos, **kwargs))

        # Report progress in steps of 100.
        if done % 100 == 0:
            print(f"  Progression: {done}/{num_molecules} générées")

    # Standardized metrics over everything that was generated.
    return compute_all_metrics(samples, training_smiles_set, target_categories)

def analyze_property_distribution(unique_molecules, target_categories):
    """Print LogP / MW / HBD summary statistics for the unique molecules.

    Only SMILES that RDKit can parse contribute. For each of them the three
    descriptors are binned with ``get_category`` and compared with
    ``target_categories`` to count how many fall in the target. Nothing is
    printed when the input is empty or no molecule could be parsed.
    """
    if not unique_molecules:
        return

    # One (logp, mw, hbd, on_target) record per parsable molecule.
    records = []
    for candidate in unique_molecules:
        mol = Chem.MolFromSmiles(candidate)
        if not mol:
            continue

        logp = Descriptors.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
        hbd = Descriptors.NumHDonors(mol)

        logp_cat = get_category(logp, LOGP_BINS)
        mw_cat = get_category(mw, MW_BINS)
        hbd_cat = get_category(hbd, HBD_BINS)

        on_target = (logp_cat == target_categories[0] and
                     mw_cat == target_categories[1] and
                     hbd_cat == target_categories[2])
        records.append((logp, mw, hbd, on_target))

    if not records:
        return

    # Transpose the records into one list per column.
    logp_values, mw_values, hbd_values, category_matches = (
        list(column) for column in zip(*records)
    )

    print(f"\n📋 Distribution des Propriétés (molécules uniques):")
    print(f"   • LogP: {np.mean(logp_values):.1f} ± {np.std(logp_values):.1f} (min: {np.min(logp_values):.1f}, max: {np.max(logp_values):.1f})")
    print(f"   • MW: {np.mean(mw_values):.0f} ± {np.std(mw_values):.0f} (min: {np.min(mw_values):.0f}, max: {np.max(mw_values):.0f})")
    print(f"   • HBD: {np.mean(hbd_values):.1f} ± {np.std(hbd_values):.1f} (min: {np.min(hbd_values):.0f}, max: {np.max(hbd_values):.0f})")
    print(f"   • Molécules dans la cible: {sum(category_matches)}/{len(category_matches)}")

# --- CHARGEMENT DES DONNÉES D'ENTRAÎNEMENT ---

def load_training_smiles(filepath, sample_size=100000):
    """Load training SMILES into a set for the novelty calculation.

    Args:
        filepath: Text file with one molecule per line.
        sample_size: Maximum number of lines to read from the start of the
            file. ``None`` reads the whole file; zero or a negative value
            reads nothing.

    Returns:
        set of str: the stripped line contents. When a line wraps the
        molecule in ``<`` ... ``>`` delimiters, only the text between them is
        kept so that entries are comparable with generated SMILES.
    """
    from itertools import islice

    print("Chargement des SMILES d'entraînement pour le calcul de nouveauté...")
    training_smiles = set()

    # islice rejects negative limits; treat them as "read nothing".
    limit = None if sample_size is None else max(sample_size, 0)

    with open(filepath, 'r') as f:
        for line in islice(f, limit):
            smiles = line.strip()
            # The file header notes that molecules in the data file may be
            # delimited by '<' and '>'. Without removing them a delimited
            # line could never equal a generated SMILES, which would inflate
            # novelty. NOTE(review): confirm this matches the extraction used
            # when building the training data.
            start = smiles.find('<')
            end = smiles.rfind('>')
            if start != -1 and end > start:
                smiles = smiles[start + 1:end].strip()
            training_smiles.add(smiles)

    print(f"SMILES d'entraînement chargés: {len(training_smiles)}")
    return training_smiles

# --- MODIFICATION DE LA PARTIE GÉNÉRATION PRINCIPALE ---

# Banner for this benchmark run.
print("\n" + "="*80)
print("🚀 GÉNÉRATION AVEC MÉTRIQUES STANDARDISÉES")
print("="*80)

# Load the training SMILES used for the novelty calculation.
# NOTE(review): the path is a literal here rather than DATA_FILE -- confirm
# both are meant to name the same file.
training_smiles_set = load_training_smiles('s_100_str_+1M_fixed.txt', sample_size=100000)

# Condition 1: aim for the "Rule of 3" zone.
# Categories are listed as [LogP, MW, HBD], matching the label printed below.
target_cats_1 = [1.0, 1.0, 3.0]
# unsqueeze(0) adds a leading dimension of size 1 in front of the 3 values.
condition_tensor_1 = torch.tensor(target_cats_1, dtype=torch.float32).unsqueeze(0)

print(f"\n🎯 Condition 1: Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)")

# Generation with standardized metrics. max_new_tokens, top_k and temperature
# are not parameters of generate_with_metrics; they are passed on to
# generate_conditional through **kwargs.
metrics_1 = generate_with_metrics(
    model, condition_tensor_1, stoi, itos, training_smiles_set,
    num_molecules=1000,  # Generate 1000 molecules for robust statistics
    max_new_tokens=50,
    top_k=10,
    temperature=0.6,
    target_categories=target_cats_1
)

# Display the results
print_metrics_table(metrics_1, "Rule of 3")
analyze_property_distribution(metrics_1['unique_molecules'], target_cats_1)

# Condition 2: lipophilic / large molecules
target_cats_2 = [3.0, 3.0, 0.0]
condition_tensor_2 = torch.tensor(target_cats_2, dtype=torch.float32).unsqueeze(0)

print(f"\n🎯 Condition 2: Catégories [LogP(3), MW(3), HBD(0)] (Très lipophile/grand)")

# Same procedure as condition 1, with a longer token budget and a smaller top_k.
metrics_2 = generate_with_metrics(
    model, condition_tensor_2, stoi, itos, training_smiles_set,
    num_molecules=1000,
    max_new_tokens=60,
    top_k=5,
    temperature=0.6,
    target_categories=target_cats_2
)

print_metrics_table(metrics_2, "Lipophile/Grand")
analyze_property_distribution(metrics_2['unique_molecules'], target_cats_2)

# --- FINAL COMPARISON ---
print(f"\n{'='*80}")
print(f"📈 COMPARAISON FINALE DES CONDITIONS")
print(f"{'='*80}")

# (label, metrics dict) pairs, one row each in the comparison table below.
conditions = [
    ("Rule of 3", metrics_1),
    ("Lipophile/Grand", metrics_2)
]

print(f"\n{'Condition':<20} {'Validity':<10} {'Novelty':<10} {'Uniqueness':<12} {'IntDiv':<10} {'Desirability':<12}")
print(f"{'-'*80}")

# Note: this loop rebinds the module-level names `name` and `metrics`.
for name, metrics in conditions:
    print(f"{name:<20} {metrics['validity']:<10.1f} {metrics['novelty']:<10.1f} {metrics['uniqueness']:<12.1f} {metrics['internal_diversity']:<10.3f} {metrics['desirability']:<12.1f}")

# --- SAUVEGARDE DES RÉSULTATS ---
def save_benchmark_results(metrics, filename):
    """Write the benchmark metrics to ``filename`` as indented UTF-8 JSON.

    The file holds the full ``metrics`` dict under 'metrics' plus the first
    20 entries of ``metrics['unique_molecules']`` (an empty list when that
    key is missing) under 'unique_molecules_sample'.

    NOTE(review): 'metrics' already contains the complete 'unique_molecules'
    list, so the sample duplicates part of it -- confirm that is intended.
    """
    molecule_sample = metrics.get('unique_molecules', [])[:20]  # Keep only a sample
    payload = {}
    payload['metrics'] = metrics
    payload['unique_molecules_sample'] = molecule_sample

    with open(filename, 'w', encoding='utf-8') as out_file:
        json.dump(payload, out_file, indent=2, ensure_ascii=False)

    print(f"\n💾 Résultats sauvegardés dans: {filename}")

# Save the results of both conditions, one JSON file each.
save_benchmark_results(metrics_1, 'benchmark_results_rule_of_3.json')
save_benchmark_results(metrics_2, 'benchmark_results_lipophilic.json')

# Closing summary: unique-molecule count and desirability per condition.
print(f"\n✅ Benchmark terminé avec analyses complètes!")
print(f"📊 Résumé:")
print(f"   • Rule of 3: {metrics_1['unique_count']} molécules uniques, {metrics_1['desirability']:.1f}% désirables")
print(f"   • Lipophile/Grand: {metrics_2['unique_count']} molécules uniques, {metrics_2['desirability']:.1f}% désirables")


🚀 GÉNÉRATION AVEC MÉTRIQUES STANDARDISÉES
Chargement des SMILES d'entraînement pour le calcul de nouveauté...
SMILES d'entraînement chargés: 99989

🎯 Condition 1: Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)
Génération de 1000 molécules...
  Progression: 100/1000 générées
  Progression: 200/1000 générées
  Progression: 300/1000 générées
  Progression: 400/1000 générées
  Progression: 500/1000 générées
  Progression: 600/1000 générées
  Progression: 700/1000 générées
  Progression: 800/1000 générées
  Progression: 900/1000 générées
  Progression: 1000/1000 générées





📊 BENCHMARK RESULTS - Rule of 3
Metric          Value      Details                       
-------------------------------------------------------
Validity        73.3       733/1000 valid
Novelty         96.2       705/733 novel
Uniqueness      97.4       687/705 unique
Int Diversity   0.834      Tanimoto-based (0-1)
Desirability    25.2       173/687 desirable

📋 Distribution des Propriétés (molécules uniques):
   • LogP: 2.8 ± 1.9 (min: -1.1, max: 10.0)
   • MW: 327 ± 49 (min: 173, max: 458)
   • HBD: 2.6 ± 0.9 (min: 0, max: 5)
   • Molécules dans la cible: 173/687

🎯 Condition 2: Catégories [LogP(3), MW(3), HBD(0)] (Très lipophile/grand)
Génération de 1000 molécules...
  Progression: 100/1000 générées
  Progression: 200/1000 générées
  Progression: 300/1000 générées
  Progression: 400/1000 générées
  Progression: 500/1000 générées
  Progression: 600/1000 générées
  Progression: 700/1000 générées
  Progression: 800/1000 générées
  Progression: 900/1000 générées
  Progression: 1000/1

[16:19:46] Conflicting single bond directions around double bond at index 6.
[16:19:46]   BondStereo set to STEREONONE and single bond directions set to NONE.
[16:19:46] Conflicting single bond directions around double bond at index 6.
[16:19:46]   BondStereo set to STEREONONE and single bond directions set to NONE.
[16:19:46] Conflicting single bond directions around double bond at index 6.
[16:19:46]   BondStereo set to STEREONONE and single bond directions set to NONE.
[16:19:46] Conflicting single bond directions around double bond at index 6.
[16:19:46]   BondStereo set to STEREONONE and single bond directions set to NONE.
[16:19:47] Conflicting single bond directions around double bond at index 6.
[16:19:47]   BondStereo set to STEREONONE and single bond directions set to NONE.



📊 BENCHMARK RESULTS - Lipophile/Grand
Metric          Value      Details                       
-------------------------------------------------------
Validity        17.4       174/1000 valid
Novelty         100.0      174/174 novel
Uniqueness      99.4       173/174 unique
Int Diversity   0.800      Tanimoto-based (0-1)
Desirability    31.8       55/173 desirable

📋 Distribution des Propriétés (molécules uniques):
   • LogP: 12.1 ± 5.0 (min: 0.6, max: 20.2)
   • MW: 643 ± 121 (min: 387, max: 850)
   • HBD: 0.4 ± 0.6 (min: 0, max: 3)
   • Molécules dans la cible: 55/173

📈 COMPARAISON FINALE DES CONDITIONS

Condition            Validity   Novelty    Uniqueness   IntDiv     Desirability
--------------------------------------------------------------------------------
Rule of 3            73.3       96.2       97.4         0.834      25.2        
Lipophile/Grand      17.4       100.0      99.4         0.800      31.8        

💾 Résultats sauvegardés dans: benchmark_results_rule_of_3.js

In [20]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
SCRIPT : GÉNÉRATION CONDITIONNELLE avec Métriques Standardisées
Métriques: Validity, Novelty, Uniqueness, Internal Diversity, Desirability
"""

# --- IMPORTS SUPPLÉMENTAIRES ---
import numpy as np
from collections import defaultdict
from rdkit.Chem import AllChem
from rdkit import DataStructs

# --- FONCTIONS DE MÉTRIQUES STANDARDISÉES ---

def calculate_validity(smiles_list):
    """(1) Validity: percentage (0-100) of generated strings that RDKit parses.

    A string counts as valid when ``Chem.MolFromSmiles`` returns something
    other than ``None``. An empty (or otherwise falsy) input yields 0.0.

    NOTE(review): RDKit appears to parse the empty string into an atom-less
    molecule rather than ``None``, so empty generations may be counted as
    valid here -- confirm whether that is intended.
    """
    if not smiles_list:
        return 0.0
    parsed_ok = sum(
        1 for candidate in smiles_list
        if Chem.MolFromSmiles(candidate) is not None
    )
    return (parsed_ok / len(smiles_list)) * 100

def calculate_novelty(valid_smiles_list, training_smiles_set):
    """(2) Novelty: percentage (0-100) of valid molecules absent from the training set.

    Membership is an exact string comparison against ``training_smiles_set``;
    no canonicalization is performed here. Returns 0.0 for an empty input.
    """
    if not valid_smiles_list:
        return 0.0
    unseen = [s for s in valid_smiles_list if s not in training_smiles_set]
    return (len(unseen) / len(valid_smiles_list)) * 100

def calculate_uniqueness(novel_smiles_list):
    """(3) Uniqueness: percentage (0-100) of distinct strings among the novel molecules.

    Duplicates are detected by exact string equality. Returns 0.0 for an
    empty input.
    """
    if not novel_smiles_list:
        return 0.0
    distinct_total = len(frozenset(novel_smiles_list))
    return (distinct_total / len(novel_smiles_list)) * 100

def calculate_internal_diversity(unique_smiles_list):
    """(4) Internal Diversity: 1 - mean pairwise Tanimoto similarity (higher = more diverse).

    Every SMILES that RDKit can parse is converted to a Morgan bit-vector
    fingerprint (radius 2, 1024 bits). The result is one minus the average
    Tanimoto similarity over all unordered pairs of fingerprints.

    Returns:
        0.0 when fewer than two molecules can be fingerprinted, otherwise the
        diversity score.
    """
    if len(unique_smiles_list) < 2:
        return 0.0

    # Generate Morgan fingerprints, skipping strings RDKit cannot parse.
    fingerprints = []
    for smiles in unique_smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            fingerprints.append(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024))

    if len(fingerprints) < 2:
        return 0.0

    # Compare each fingerprint with all later ones in one bulk call per row:
    # same pairs and same order as a nested (i, j > i) loop, but without one
    # Python-level call per pair. With >= 2 fingerprints the list is never
    # empty, so no extra emptiness check is needed.
    similarities = []
    for i in range(len(fingerprints) - 1):
        similarities.extend(
            DataStructs.BulkTanimotoSimilarity(fingerprints[i], fingerprints[i + 1:])
        )

    # Internal diversity = 1 - average similarity
    return 1.0 - np.mean(similarities)

def calculate_desirability(unique_smiles_list, target_categories):
    """(5) Desirability: share of the given molecules that fall in the target bins.

    For every SMILES that RDKit can parse, LogP, molecular weight and H-bond
    donor count are binned with ``get_category`` and compared, in that order,
    with ``target_categories[0]``, ``[1]`` and ``[2]``.

    Returns:
        (percentage, count): the percentage is relative to the full input
        length (unparsable entries stay in the denominator); ``(0.0, 0)``
        for an empty input.
    """
    if not unique_smiles_list:
        return 0.0, 0

    hits = 0
    for candidate in unique_smiles_list:
        mol = Chem.MolFromSmiles(candidate)
        if not mol:
            continue

        logp = Descriptors.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
        hbd = Descriptors.NumHDonors(mol)
        observed = (
            get_category(logp, LOGP_BINS),
            get_category(mw, MW_BINS),
            get_category(hbd, HBD_BINS),
        )

        # Desirable only when all three bins equal the requested ones.
        if all(observed[k] == target_categories[k] for k in range(3)):
            hits += 1

    return (hits / len(unique_smiles_list)) * 100, hits

def compute_all_metrics(generated_smiles_list, training_smiles_set, target_categories=None):
    """Compute the five benchmark metrics as a cascade of filters.

    Each stage works on the survivors of the previous one:
    generated -> valid -> novel -> unique, then internal diversity and
    (optionally) desirability are computed on the unique molecules.

    Args:
        generated_smiles_list: All generated SMILES strings.
        training_smiles_set: Training SMILES used for the novelty check.
        target_categories: Target [LogP, MW, HBD] categories; when ``None``
            desirability is reported as 0.0 / 0.

    Returns:
        dict with the percentages ('validity', 'novelty', 'uniqueness',
        'desirability'), 'internal_diversity', the stage counts
        ('total_generated', 'valid_count', 'novel_count', 'unique_count',
        'desirable_count') and 'unique_molecules' (list of unique SMILES in
        first-generated order).
    """
    # 1. Validity
    validity_percent = calculate_validity(generated_smiles_list)
    valid_smiles = [s for s in generated_smiles_list if Chem.MolFromSmiles(s) is not None]

    # 2. Novelty
    novelty_percent = calculate_novelty(valid_smiles, training_smiles_set)
    novel_smiles = [s for s in valid_smiles if s not in training_smiles_set]

    # 3. Uniqueness
    uniqueness_percent = calculate_uniqueness(novel_smiles)
    # De-duplicate while keeping first-seen order. list(set(...)) would give an
    # order that changes between interpreter sessions (string hashing is
    # randomized), making the saved sample and printed examples irreproducible.
    unique_smiles = list(dict.fromkeys(novel_smiles))

    # 4. Internal Diversity
    internal_diversity = calculate_internal_diversity(unique_smiles)

    # 5. Desirability (only for biased models)
    desirability_percent = 0.0
    desirable_count = 0
    if target_categories is not None:
        desirability_percent, desirable_count = calculate_desirability(unique_smiles, target_categories)

    metrics = {
        'validity': validity_percent,
        'novelty': novelty_percent,
        'uniqueness': uniqueness_percent,
        'internal_diversity': internal_diversity,
        'desirability': desirability_percent,
        'desirable_count': desirable_count,
        'total_generated': len(generated_smiles_list),
        'valid_count': len(valid_smiles),
        'novel_count': len(novel_smiles),
        'unique_count': len(unique_smiles),
        'unique_molecules': unique_smiles  # Keep the actual unique molecules
    }

    return metrics

def print_metrics_table(metrics, condition_name):
    """Print one condition's benchmark metrics as a fixed-width text table.

    Expects the keys produced by ``compute_all_metrics``. The Desirability
    row is printed only when ``metrics['desirability']`` is greater than 0.
    """
    banner = '=' * 80

    def emit(label, value_text, detail):
        # One table row: label padded to 15 columns, then value, then details.
        print(f"{label:<15} {value_text} {detail}")

    print(f"\n{banner}")
    print(f"📊 BENCHMARK RESULTS - {condition_name}")
    print(banner)

    print(f"{'Metric':<15} {'Value':<10} {'Details':<30}")
    print('-' * 55)

    emit('Validity', f"{metrics['validity']:<10.1f}",
         f"{metrics['valid_count']}/{metrics['total_generated']} valid")
    emit('Novelty', f"{metrics['novelty']:<10.1f}",
         f"{metrics['novel_count']}/{metrics['valid_count']} novel")
    emit('Uniqueness', f"{metrics['uniqueness']:<10.1f}",
         f"{metrics['unique_count']}/{metrics['novel_count']} unique")
    emit('Int Diversity', f"{metrics['internal_diversity']:<10.3f}",
         "Tanimoto-based (0-1)")

    if metrics['desirability'] > 0:
        emit('Desirability', f"{metrics['desirability']:<10.1f}",
             f"{metrics['desirable_count']}/{metrics['unique_count']} desirable")

def generate_with_metrics(model, condition_tensor, stoi, itos, training_smiles_set,
                         num_molecules=1000, target_categories=None, **kwargs):
    """Generate ``num_molecules`` SMILES under one condition and score them.

    Each molecule comes from one ``generate_conditional`` call; any extra
    keyword arguments are forwarded to it unchanged. Progress is printed
    every 100 molecules.

    Returns:
        The metrics dict produced by ``compute_all_metrics``.
    """
    print(f"Génération de {num_molecules} molécules...")

    samples = []
    for done in range(1, num_molecules + 1):
        samples.append(generate_conditional(model, condition_tensor, stoi, itos, **kwargs))

        # Report progress in steps of 100.
        if done % 100 == 0:
            print(f"  Progression: {done}/{num_molecules} générées")

    # Standardized metrics over everything that was generated.
    return compute_all_metrics(samples, training_smiles_set, target_categories)

def analyze_property_distribution(unique_molecules, target_categories):
    """Print LogP / MW / HBD summary statistics for the unique molecules.

    Only SMILES that RDKit can parse contribute. For each of them the three
    descriptors are binned with ``get_category`` and compared with
    ``target_categories`` to count how many fall in the target. Nothing is
    printed when the input is empty or no molecule could be parsed.
    """
    if not unique_molecules:
        return

    # One (logp, mw, hbd, on_target) record per parsable molecule.
    records = []
    for candidate in unique_molecules:
        mol = Chem.MolFromSmiles(candidate)
        if not mol:
            continue

        logp = Descriptors.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
        hbd = Descriptors.NumHDonors(mol)

        logp_cat = get_category(logp, LOGP_BINS)
        mw_cat = get_category(mw, MW_BINS)
        hbd_cat = get_category(hbd, HBD_BINS)

        on_target = (logp_cat == target_categories[0] and
                     mw_cat == target_categories[1] and
                     hbd_cat == target_categories[2])
        records.append((logp, mw, hbd, on_target))

    if not records:
        return

    # Transpose the records into one list per column.
    logp_values, mw_values, hbd_values, category_matches = (
        list(column) for column in zip(*records)
    )

    print(f"\n📋 Distribution des Propriétés (molécules uniques):")
    print(f"   • LogP: {np.mean(logp_values):.1f} ± {np.std(logp_values):.1f} (min: {np.min(logp_values):.1f}, max: {np.max(logp_values):.1f})")
    print(f"   • MW: {np.mean(mw_values):.0f} ± {np.std(mw_values):.0f} (min: {np.min(mw_values):.0f}, max: {np.max(mw_values):.0f})")
    print(f"   • HBD: {np.mean(hbd_values):.1f} ± {np.std(hbd_values):.1f} (min: {np.min(hbd_values):.0f}, max: {np.max(hbd_values):.0f})")
    print(f"   • Molécules dans la cible: {sum(category_matches)}/{len(category_matches)}")

# --- CHARGEMENT DES DONNÉES D'ENTRAÎNEMENT ---

def load_training_smiles(filepath, sample_size=100000):
    """Load training SMILES into a set for the novelty calculation.

    Args:
        filepath: Text file with one molecule per line.
        sample_size: Maximum number of lines to read from the start of the
            file. ``None`` reads the whole file; zero or a negative value
            reads nothing.

    Returns:
        set of str: the stripped line contents. When a line wraps the
        molecule in ``<`` ... ``>`` delimiters, only the text between them is
        kept so that entries are comparable with generated SMILES.
    """
    from itertools import islice

    print("Chargement des SMILES d'entraînement pour le calcul de nouveauté...")
    training_smiles = set()

    # islice rejects negative limits; treat them as "read nothing".
    limit = None if sample_size is None else max(sample_size, 0)

    with open(filepath, 'r') as f:
        for line in islice(f, limit):
            smiles = line.strip()
            # The file header notes that molecules in the data file may be
            # delimited by '<' and '>'. Without removing them a delimited
            # line could never equal a generated SMILES, which would inflate
            # novelty. NOTE(review): confirm this matches the extraction used
            # when building the training data.
            start = smiles.find('<')
            end = smiles.rfind('>')
            if start != -1 and end > start:
                smiles = smiles[start + 1:end].strip()
            training_smiles.add(smiles)

    print(f"SMILES d'entraînement chargés: {len(training_smiles)}")
    return training_smiles

# --- MODIFICATION DE LA PARTIE GÉNÉRATION PRINCIPALE ---

# Banner for this benchmark run.
print("\n" + "="*80)
print("🚀 GÉNÉRATION AVEC MÉTRIQUES STANDARDISÉES")
print("="*80)

# Load the training SMILES used for the novelty calculation.
# NOTE(review): the path is a literal here rather than DATA_FILE -- confirm
# both are meant to name the same file.
training_smiles_set = load_training_smiles('s_100_str_+1M_fixed.txt', sample_size=100000)

# Condition 1: aim for the "Rule of 3" zone.
# Categories are listed as [LogP, MW, HBD], matching the label printed below.
target_cats_1 = [1.0, 1.0, 3.0]
# unsqueeze(0) adds a leading dimension of size 1 in front of the 3 values.
condition_tensor_1 = torch.tensor(target_cats_1, dtype=torch.float32).unsqueeze(0)

print(f"\n🎯 Condition 1: Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)")

# Generation with standardized metrics. max_new_tokens, top_k and temperature
# are not parameters of generate_with_metrics; they are passed on to
# generate_conditional through **kwargs.
metrics_1 = generate_with_metrics(
    model, condition_tensor_1, stoi, itos, training_smiles_set,
    num_molecules=1000,  # Generate 1000 molecules for robust statistics
    max_new_tokens=50,
    top_k=10,
    temperature=0.5,
    target_categories=target_cats_1
)

# Display the results
print_metrics_table(metrics_1, "Rule of 3")
analyze_property_distribution(metrics_1['unique_molecules'], target_cats_1)

# Condition 2: lipophilic / large molecules
target_cats_2 = [3.0, 3.0, 0.0]
condition_tensor_2 = torch.tensor(target_cats_2, dtype=torch.float32).unsqueeze(0)

print(f"\n🎯 Condition 2: Catégories [LogP(3), MW(3), HBD(0)] (Très lipophile/grand)")

# Same procedure as condition 1, with a longer token budget and a smaller top_k.
metrics_2 = generate_with_metrics(
    model, condition_tensor_2, stoi, itos, training_smiles_set,
    num_molecules=1000,
    max_new_tokens=60,
    top_k=5,
    temperature=0.5,
    target_categories=target_cats_2
)

print_metrics_table(metrics_2, "Lipophile/Grand")
analyze_property_distribution(metrics_2['unique_molecules'], target_cats_2)

# --- FINAL COMPARISON ---
print(f"\n{'='*80}")
print(f"📈 COMPARAISON FINALE DES CONDITIONS")
print(f"{'='*80}")

# (label, metrics dict) pairs, one row each in the comparison table below.
conditions = [
    ("Rule of 3", metrics_1),
    ("Lipophile/Grand", metrics_2)
]

print(f"\n{'Condition':<20} {'Validity':<10} {'Novelty':<10} {'Uniqueness':<12} {'IntDiv':<10} {'Desirability':<12}")
print(f"{'-'*80}")

# Note: this loop rebinds the module-level names `name` and `metrics`.
for name, metrics in conditions:
    print(f"{name:<20} {metrics['validity']:<10.1f} {metrics['novelty']:<10.1f} {metrics['uniqueness']:<12.1f} {metrics['internal_diversity']:<10.3f} {metrics['desirability']:<12.1f}")

# --- SAUVEGARDE DES RÉSULTATS ---
def save_benchmark_results(metrics, filename):
    """Write the benchmark metrics to ``filename`` as indented UTF-8 JSON.

    The file holds the full ``metrics`` dict under 'metrics' plus the first
    20 entries of ``metrics['unique_molecules']`` (an empty list when that
    key is missing) under 'unique_molecules_sample'.

    NOTE(review): 'metrics' already contains the complete 'unique_molecules'
    list, so the sample duplicates part of it -- confirm that is intended.
    """
    molecule_sample = metrics.get('unique_molecules', [])[:20]  # Keep only a sample
    payload = {}
    payload['metrics'] = metrics
    payload['unique_molecules_sample'] = molecule_sample

    with open(filename, 'w', encoding='utf-8') as out_file:
        json.dump(payload, out_file, indent=2, ensure_ascii=False)

    print(f"\n💾 Résultats sauvegardés dans: {filename}")

# Save the results of both conditions, one JSON file each.
save_benchmark_results(metrics_1, 'benchmark_results_rule_of_3.json')
save_benchmark_results(metrics_2, 'benchmark_results_lipophilic.json')

# Closing summary: unique-molecule count and desirability per condition.
print(f"\n✅ Benchmark terminé avec analyses complètes!")
print(f"📊 Résumé:")
print(f"   • Rule of 3: {metrics_1['unique_count']} molécules uniques, {metrics_1['desirability']:.1f}% désirables")
print(f"   • Lipophile/Grand: {metrics_2['unique_count']} molécules uniques, {metrics_2['desirability']:.1f}% désirables")


🚀 GÉNÉRATION AVEC MÉTRIQUES STANDARDISÉES
Chargement des SMILES d'entraînement pour le calcul de nouveauté...
SMILES d'entraînement chargés: 99989

🎯 Condition 1: Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)
Génération de 1000 molécules...
  Progression: 100/1000 générées
  Progression: 200/1000 générées
  Progression: 300/1000 générées
  Progression: 400/1000 générées
  Progression: 500/1000 générées
  Progression: 600/1000 générées
  Progression: 700/1000 générées
  Progression: 800/1000 générées
  Progression: 900/1000 générées
  Progression: 1000/1000 générées





📊 BENCHMARK RESULTS - Rule of 3
Metric          Value      Details                       
-------------------------------------------------------
Validity        81.1       811/1000 valid
Novelty         90.9       737/811 novel
Uniqueness      91.7       676/737 unique
Int Diversity   0.824      Tanimoto-based (0-1)
Desirability    24.9       168/676 desirable

📋 Distribution des Propriétés (molécules uniques):
   • LogP: 3.1 ± 2.3 (min: -1.4, max: 11.6)
   • MW: 327 ± 53 (min: 202, max: 540)
   • HBD: 2.5 ± 0.9 (min: 0, max: 5)
   • Molécules dans la cible: 168/676

🎯 Condition 2: Catégories [LogP(3), MW(3), HBD(0)] (Très lipophile/grand)
Génération de 1000 molécules...
  Progression: 100/1000 générées
  Progression: 200/1000 générées
  Progression: 300/1000 générées
  Progression: 400/1000 générées
  Progression: 500/1000 générées
  Progression: 600/1000 générées
  Progression: 700/1000 générées
  Progression: 800/1000 générées
  Progression: 900/1000 générées
  Progression: 1000/1




📊 BENCHMARK RESULTS - Lipophile/Grand
Metric          Value      Details                       
-------------------------------------------------------
Validity        25.2       252/1000 valid
Novelty         100.0      252/252 novel
Uniqueness      98.4       248/252 unique
Int Diversity   0.784      Tanimoto-based (0-1)
Desirability    42.7       106/248 desirable

📋 Distribution des Propriétés (molécules uniques):
   • LogP: 14.0 ± 5.7 (min: 1.0, max: 23.7)
   • MW: 684 ± 128 (min: 362, max: 857)
   • HBD: 0.4 ± 0.6 (min: 0, max: 2)
   • Molécules dans la cible: 106/248

📈 COMPARAISON FINALE DES CONDITIONS

Condition            Validity   Novelty    Uniqueness   IntDiv     Desirability
--------------------------------------------------------------------------------
Rule of 3            81.1       90.9       91.7         0.824      24.9        
Lipophile/Grand      25.2       100.0      98.4         0.784      42.7        

💾 Résultats sauvegardés dans: benchmark_results_rule_of_3.

In [2]:
pip install tqdm

Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1
Note: you may need to restart the kernel to use updated packages.


In [22]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
SCRIPT : GÉNÉRATION CONDITIONNELLE avec Métriques Standardisées
Métriques: Validity, Novelty, Uniqueness, Internal Diversity, Desirability
"""

import torch
import json
import numpy as np
from collections import defaultdict
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem
from rdkit import DataStructs

# --- FONCTIONS DE MÉTRIQUES STANDARDISÉES ---

def calculate_validity(smiles_list):
    """(1) Validity: percentage (0-100) of generated strings that RDKit parses.

    A string counts as valid when ``Chem.MolFromSmiles`` returns something
    other than ``None``. An empty (or otherwise falsy) input yields 0.0.

    NOTE(review): RDKit appears to parse the empty string into an atom-less
    molecule rather than ``None``, so empty generations may be counted as
    valid here -- confirm whether that is intended.
    """
    if not smiles_list:
        return 0.0
    parsed_ok = sum(
        1 for candidate in smiles_list
        if Chem.MolFromSmiles(candidate) is not None
    )
    return (parsed_ok / len(smiles_list)) * 100

def calculate_novelty(valid_smiles_list, training_smiles_set):
    """(2) Novelty: percentage (0-100) of valid molecules absent from the training set.

    Membership is an exact string comparison against ``training_smiles_set``;
    no canonicalization is performed here. Returns 0.0 for an empty input.
    """
    if not valid_smiles_list:
        return 0.0
    unseen = [s for s in valid_smiles_list if s not in training_smiles_set]
    return (len(unseen) / len(valid_smiles_list)) * 100

def calculate_uniqueness(novel_smiles_list):
    """(3) Uniqueness: percentage (0-100) of distinct strings among the novel molecules.

    Duplicates are detected by exact string equality. Returns 0.0 for an
    empty input.
    """
    if not novel_smiles_list:
        return 0.0
    distinct_total = len(frozenset(novel_smiles_list))
    return (distinct_total / len(novel_smiles_list)) * 100

def calculate_internal_diversity(unique_smiles_list):
    """(4) Internal Diversity: 1 - mean pairwise Tanimoto similarity (higher = more diverse).

    Every SMILES that RDKit can parse is converted to a Morgan bit-vector
    fingerprint (radius 2, 1024 bits). The result is one minus the average
    Tanimoto similarity over all unordered pairs of fingerprints.

    Returns:
        0.0 when fewer than two molecules can be fingerprinted, otherwise the
        diversity score.
    """
    if len(unique_smiles_list) < 2:
        return 0.0

    # Generate Morgan fingerprints, skipping strings RDKit cannot parse.
    fingerprints = []
    for smiles in unique_smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            fingerprints.append(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024))

    if len(fingerprints) < 2:
        return 0.0

    # Compare each fingerprint with all later ones in one bulk call per row:
    # same pairs and same order as a nested (i, j > i) loop, but without one
    # Python-level call per pair. With >= 2 fingerprints the list is never
    # empty, so no extra emptiness check is needed.
    similarities = []
    for i in range(len(fingerprints) - 1):
        similarities.extend(
            DataStructs.BulkTanimotoSimilarity(fingerprints[i], fingerprints[i + 1:])
        )

    # Internal diversity = 1 - average similarity
    return 1.0 - np.mean(similarities)

def calculate_desirability(unique_smiles_list, target_categories):
    """(5) Desirability: share of the given molecules that fall in the target bins.

    For every SMILES that RDKit can parse, LogP, molecular weight and H-bond
    donor count are binned with ``get_category`` and compared, in that order,
    with ``target_categories[0]``, ``[1]`` and ``[2]``.

    Returns:
        (percentage, count): the percentage is relative to the full input
        length (unparsable entries stay in the denominator); ``(0.0, 0)``
        for an empty input.
    """
    if not unique_smiles_list:
        return 0.0, 0

    hits = 0
    for candidate in unique_smiles_list:
        mol = Chem.MolFromSmiles(candidate)
        if not mol:
            continue

        logp = Descriptors.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
        hbd = Descriptors.NumHDonors(mol)
        observed = (
            get_category(logp, LOGP_BINS),
            get_category(mw, MW_BINS),
            get_category(hbd, HBD_BINS),
        )

        # Desirable only when all three bins equal the requested ones.
        if all(observed[k] == target_categories[k] for k in range(3)):
            hits += 1

    return (hits / len(unique_smiles_list)) * 100, hits

def count_molecules_in_condition1(unique_smiles_list):
    """Count the molecules satisfying logP <= 3, MW <= 480 and HBD <= 3.

    Args:
        unique_smiles_list: SMILES strings to screen.

    Returns:
        ``(count, molecules)``: the number of passing molecules and the list
        of their SMILES, in input order; ``(0, [])`` for empty input.
        SMILES that RDKit cannot parse never pass.
    """
    if not unique_smiles_list:
        return 0, []

    def _passes(smi):
        mol = Chem.MolFromSmiles(smi)
        if not mol:
            return False
        return (Descriptors.MolLogP(mol) <= 3
                and Descriptors.MolWt(mol) <= 480
                and Descriptors.NumHDonors(mol) <= 3)

    selected = [smi for smi in unique_smiles_list if _passes(smi)]
    return len(selected), selected

def analyze_condition1_compliance(metrics, condition_name):
    """Report how many unique molecules satisfy the "condition 1" thresholds.

    Reads ``metrics['unique_molecules']`` (empty list when the key is absent),
    screens it with ``count_molecules_in_condition1`` and prints the count,
    the percentage (or N/A when there are no molecules) and up to five examples.

    Args:
        metrics: metrics dict, optionally holding a 'unique_molecules' list.
        condition_name: label shown in the printed heading.

    Returns:
        ``(count, molecules)`` exactly as returned by ``count_molecules_in_condition1``.
    """
    pool = metrics.get('unique_molecules', [])
    n_ok, ok_molecules = count_molecules_in_condition1(pool)

    print(f"\n🔬 ANALYSE CONDITION 1 - {condition_name}")
    print(f"   • Molécules vérifiant logP <= 3, MW <= 480, HBD <= 3: {n_ok}/{len(pool)}")
    if not pool:
        # No molecules at all: a percentage would divide by zero.
        print(f"   • Pourcentage: N/A")
    else:
        print(f"   • Pourcentage: {(n_ok/len(pool))*100:.1f}%")

    # Show a few examples when any molecule passed.
    if ok_molecules:
        print(f"   • Exemples (premiers 5): {ok_molecules[:5]}")

    return n_ok, ok_molecules

def compute_all_metrics(generated_smiles_list, training_smiles_set, target_categories=None):
    """Compute the five benchmark metrics as a filtering cascade.

    Each stage works on the survivors of the previous one:
    generated -> valid (RDKit-parseable) -> novel (not in the training set)
    -> unique (duplicates removed) -> internal diversity / desirability.

    Args:
        generated_smiles_list: all generated SMILES strings.
        training_smiles_set: container of training SMILES used for the
            novelty membership test (exact string comparison).
        target_categories: wanted (LogP, MW, HBD) categories; when ``None``
            desirability is left at 0.

    Returns:
        dict with the percentages ('validity', 'novelty', 'uniqueness',
        'desirability'), 'internal_diversity', the stage counts
        ('total_generated', 'valid_count', 'novel_count', 'unique_count',
        'desirable_count') and 'unique_molecules', the de-duplicated list in
        first-seen generation order.
    """
    # 1. Validity
    validity_percent = calculate_validity(generated_smiles_list)
    valid_smiles = [s for s in generated_smiles_list if Chem.MolFromSmiles(s) is not None]

    # 2. Novelty
    novelty_percent = calculate_novelty(valid_smiles, training_smiles_set)
    novel_smiles = [s for s in valid_smiles if s not in training_smiles_set]

    # 3. Uniqueness
    uniqueness_percent = calculate_uniqueness(novel_smiles)
    # dict.fromkeys de-duplicates while keeping first-seen order. A plain
    # set() would give an order that changes between interpreter runs
    # (string hash randomisation), making the printed examples and the saved
    # molecule sample non-reproducible.
    unique_smiles = list(dict.fromkeys(novel_smiles))

    # 4. Internal Diversity
    internal_diversity = calculate_internal_diversity(unique_smiles)

    # 5. Desirability (only for biased models)
    desirability_percent = 0.0
    desirable_count = 0
    if target_categories is not None:
        desirability_percent, desirable_count = calculate_desirability(unique_smiles, target_categories)

    metrics = {
        'validity': validity_percent,
        'novelty': novelty_percent,
        'uniqueness': uniqueness_percent,
        'internal_diversity': internal_diversity,
        'desirability': desirability_percent,
        'desirable_count': desirable_count,
        'total_generated': len(generated_smiles_list),
        'valid_count': len(valid_smiles),
        'novel_count': len(novel_smiles),
        'unique_count': len(unique_smiles),
        'unique_molecules': unique_smiles  # Keep the actual unique molecules
    }

    return metrics

def print_metrics_table(metrics, condition_name):
    """Print the benchmark metrics as a fixed-width table.

    Args:
        metrics: dict holding the percentage, diversity and count keys read
            below ('validity', 'novelty', 'uniqueness', 'internal_diversity',
            'desirability' and the matching '*_count' / 'total_generated').
        condition_name: label shown in the table heading.

    The desirability row is printed only when its value is strictly positive.
    """
    def _row(label, key, decimals, detail):
        # One table line: padded label, padded fixed-precision value, free text.
        print(f"{label:<15} {metrics[key]:<10.{decimals}f} {detail}")

    print(f"\n{'='*80}")
    print(f"📊 BENCHMARK RESULTS - {condition_name}")
    print(f"{'='*80}")

    print(f"{'Metric':<15} {'Value':<10} {'Details':<30}")
    print(f"{'-'*55}")
    _row('Validity', 'validity', 1, f"{metrics['valid_count']}/{metrics['total_generated']} valid")
    _row('Novelty', 'novelty', 1, f"{metrics['novel_count']}/{metrics['valid_count']} novel")
    _row('Uniqueness', 'uniqueness', 1, f"{metrics['unique_count']}/{metrics['novel_count']} unique")
    _row('Int Diversity', 'internal_diversity', 3, "Tanimoto-based (0-1)")

    if metrics['desirability'] > 0:
        _row('Desirability', 'desirability', 1,
             f"{metrics['desirable_count']}/{metrics['unique_count']} desirable")

def generate_with_metrics(model, condition_tensor, stoi, itos, training_smiles_set,
                         num_molecules=1000, target_categories=None, **kwargs):
    """Generate molecules one by one, then compute the standardized metrics.

    Args:
        model, condition_tensor, stoi, itos: forwarded unchanged to
            ``generate_conditional`` for every sample.
        training_smiles_set: training SMILES used for the novelty metric.
        num_molecules: how many molecules to generate.
        target_categories: forwarded to ``compute_all_metrics`` (desirability).
        **kwargs: extra keyword arguments forwarded to ``generate_conditional``.

    Returns:
        The metrics dict produced by ``compute_all_metrics``.
    """
    print(f"Génération de {num_molecules} molécules...")

    samples = []
    for done in range(1, num_molecules + 1):
        samples.append(generate_conditional(model, condition_tensor, stoi, itos, **kwargs))

        # Progress report every 100 generated molecules.
        if done % 100 == 0:
            print(f"  Progression: {done}/{num_molecules} générées")

    return compute_all_metrics(samples, training_smiles_set, target_categories)

def analyze_property_distribution(unique_molecules, target_categories):
    """Print summary statistics of LogP, MW and HBD over the given molecules.

    For every parseable SMILES the three descriptors are computed and binned
    with ``get_category``; the printed report gives mean, standard deviation,
    min and max of each descriptor plus how many molecules fall in all three
    target categories.

    Args:
        unique_molecules: SMILES strings to analyse.
        target_categories: indexable holding the wanted (LogP, MW, HBD) categories.

    Returns:
        None. Nothing is printed when the input is empty or no SMILES parses.
    """
    if not unique_molecules:
        return

    # One record per parseable molecule: (logp, mw, hbd, in_target).
    records = []
    for smi in unique_molecules:
        mol = Chem.MolFromSmiles(smi)
        if not mol:
            continue

        logp = Descriptors.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
        hbd = Descriptors.NumHDonors(mol)

        logp_cat = get_category(logp, LOGP_BINS)
        mw_cat = get_category(mw, MW_BINS)
        hbd_cat = get_category(hbd, HBD_BINS)

        in_target = (logp_cat == target_categories[0] and
                     mw_cat == target_categories[1] and
                     hbd_cat == target_categories[2])
        records.append((logp, mw, hbd, in_target))

    if not records:
        return

    # Transpose the records back into one list per column.
    logp_values, mw_values, hbd_values, category_matches = (list(col) for col in zip(*records))

    print(f"\n📋 Distribution des Propriétés (molécules uniques):")
    print(f"   • LogP: {np.mean(logp_values):.1f} ± {np.std(logp_values):.1f} (min: {np.min(logp_values):.1f}, max: {np.max(logp_values):.1f})")
    print(f"   • MW: {np.mean(mw_values):.0f} ± {np.std(mw_values):.0f} (min: {np.min(mw_values):.0f}, max: {np.max(mw_values):.0f})")
    print(f"   • HBD: {np.mean(hbd_values):.1f} ± {np.std(hbd_values):.1f} (min: {np.min(hbd_values):.0f}, max: {np.max(hbd_values):.0f})")
    print(f"   • Molécules dans la cible: {sum(category_matches)}/{len(category_matches)}")

# --- CHARGEMENT DES DONNÉES D'ENTRAÎNEMENT ---

def load_training_smiles(filepath, sample_size=100000):
    """Load training SMILES from the first `sample_size` lines of a file.

    The result is used for the novelty metric, which is an exact string
    membership test against generated SMILES. To make that comparison
    meaningful, each line is normalised:

    * surrounding whitespace is stripped;
    * when the molecule is wrapped in ``<`` ... ``>`` delimiters (the data
      file format announced in this script's header), only the text between
      the delimiters is kept;
    * blank entries are skipped instead of being stored as ``""``.

    Args:
        filepath: path of the text file, one molecule per line.
        sample_size: maximum number of lines read from the start of the file.

    Returns:
        set of SMILES strings (duplicates collapse, so it can hold fewer
        entries than lines read).
    """
    print("Chargement des SMILES d'entraînement pour le calcul de nouveauté...")
    training_smiles = set()

    with open(filepath, 'r') as f:
        for i, line in enumerate(f):
            if i >= sample_size:
                break
            smiles = line.strip()

            # Drop the optional <...> wrapper so entries compare equal to
            # the bare SMILES produced by the generator.
            start = smiles.find('<')
            end = smiles.find('>', start + 1)
            if start != -1 and end != -1:
                smiles = smiles[start + 1:end].strip()

            if smiles:
                training_smiles.add(smiles)

    print(f"SMILES d'entraînement chargés: {len(training_smiles)}")
    return training_smiles

# --- MAIN SECTION ---
# Runs the benchmark for two conditioning targets, prints the per-condition
# reports, then prints a side-by-side comparison table. The globals defined
# here (metrics_1, metrics_2, count_cond1, count_cond2) are reused by the
# result-saving statements further down.

print("\n" + "="*80)
print("🚀 GÉNÉRATION AVEC MÉTRIQUES STANDARDISÉES")
print("="*80)

# Load the training data used for the novelty computation.
# Only the first 100000 lines are read, so novelty is measured against that
# sample of the file.
# NOTE(review): the path is hard-coded here; DATA_FILE at the top of the file
# holds the same value — consider reusing it so the two cannot drift apart.
training_smiles_set = load_training_smiles('s_100_str_+1M_fixed.txt', sample_size=100000)

# Condition 1: aim for the "Rule of 3" zone.
# The same list is used both as the conditioning tensor (with a leading batch
# dimension added by unsqueeze(0)) and as the desirability target.
target_cats_1 = [1.0, 1.0, 3.0]
condition_tensor_1 = torch.tensor(target_cats_1, dtype=torch.float32).unsqueeze(0)

print(f"\n🎯 Condition 1: Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)")

# Generation with standardized metrics.
# max_new_tokens / top_k / temperature are not parameters of
# generate_with_metrics; they travel through **kwargs to generate_conditional.
metrics_1 = generate_with_metrics(
    model, condition_tensor_1, stoi, itos, training_smiles_set,
    num_molecules=1000,
    max_new_tokens=50,
    top_k=10,
    temperature=0.5,
    target_categories=target_cats_1
)

# Display the results
print_metrics_table(metrics_1, "Rule of 3")
analyze_property_distribution(metrics_1['unique_molecules'], target_cats_1)

# Added: condition-1 specific analysis (logP <= 3, MW <= 480, HBD <= 3)
count_cond1, molecules_cond1 = analyze_condition1_compliance(metrics_1, "Rule of 3")

# Condition 2: lipophilic / large molecules
target_cats_2 = [3.0, 3.0, 0.0]
condition_tensor_2 = torch.tensor(target_cats_2, dtype=torch.float32).unsqueeze(0)

print(f"\n🎯 Condition 2: Catégories [LogP(3), MW(3), HBD(0)] (Très lipophile/grand)")

metrics_2 = generate_with_metrics(
    model, condition_tensor_2, stoi, itos, training_smiles_set,
    num_molecules=1000,
    max_new_tokens=60,
    top_k=5,
    temperature=0.5,
    target_categories=target_cats_2
)

print_metrics_table(metrics_2, "Lipophile/Grand")
analyze_property_distribution(metrics_2['unique_molecules'], target_cats_2)

# Added: the same condition-1 thresholds are also checked on the condition-2
# molecules, so the two runs can be compared on a common criterion.
count_cond2, molecules_cond2 = analyze_condition1_compliance(metrics_2, "Lipophile/Grand")

# --- FINAL COMPARISON ---
print(f"\n{'='*80}")
print(f"📈 COMPARAISON FINALE DES CONDITIONS")
print(f"{'='*80}")

# (label, metrics dict, number of unique molecules passing condition 1)
conditions = [
    ("Rule of 3", metrics_1, count_cond1),
    ("Lipophile/Grand", metrics_2, count_cond2)
]

print(f"\n{'Condition':<20} {'Validity':<10} {'Novelty':<10} {'Uniqueness':<12} {'IntDiv':<10} {'Desirability':<12} {'Cond1':<10}")
print(f"{'-'*95}")

for name, metrics, cond1_count in conditions:
    unique_count = metrics['unique_count']
    # Guard against division by zero when no unique molecule was produced.
    cond1_percent = (cond1_count/unique_count)*100 if unique_count > 0 else 0
    print(f"{name:<20} {metrics['validity']:<10.1f} {metrics['novelty']:<10.1f} {metrics['uniqueness']:<12.1f} {metrics['internal_diversity']:<10.3f} {metrics['desirability']:<12.1f} {cond1_count}/{unique_count} ({cond1_percent:.1f}%)")

# --- SAUVEGARDE DES RÉSULTATS ---
def save_benchmark_results(metrics, filename):
    """Write the benchmark results to a JSON file and print a confirmation.

    The file holds two keys: 'metrics' (the dict exactly as received, so any
    'unique_molecules' list it carries is written in full) and
    'unique_molecules_sample' (the first 20 entries of that list, or an
    empty list when the key is absent).

    Args:
        metrics: JSON-serialisable metrics dict.
        filename: destination path; the file is overwritten, UTF-8 encoded.
    """
    # Keep a short sample of the molecules next to the raw metrics.
    sample = metrics.get('unique_molecules', [])[:20]
    payload = {'metrics': metrics, 'unique_molecules_sample': sample}

    with open(filename, 'w', encoding='utf-8') as out:
        json.dump(payload, out, ensure_ascii=False, indent=2)

    print(f"\n💾 Résultats sauvegardés dans: {filename}")

# Save the results: one JSON file per condition.
save_benchmark_results(metrics_1, 'benchmark_results_rule_of_3.json')
save_benchmark_results(metrics_2, 'benchmark_results_lipophilic.json')

# Final recap: unique-molecule count, desirability percentage and number of
# molecules passing the condition-1 thresholds, for each of the two runs.
print(f"\n✅ Benchmark terminé avec analyses complètes!")
print(f"📊 Résumé:")
print(f"   • Rule of 3: {metrics_1['unique_count']} molécules uniques, {metrics_1['desirability']:.1f}% désirables, {count_cond1} dans condition 1")
print(f"   • Lipophile/Grand: {metrics_2['unique_count']} molécules uniques, {metrics_2['desirability']:.1f}% désirables, {count_cond2} dans condition 1")


🚀 GÉNÉRATION AVEC MÉTRIQUES STANDARDISÉES
Chargement des SMILES d'entraînement pour le calcul de nouveauté...
SMILES d'entraînement chargés: 99989

🎯 Condition 1: Catégories [LogP(1), MW(1), HBD(3)] (Rule of 3)
Génération de 1000 molécules...
  Progression: 100/1000 générées
  Progression: 200/1000 générées
  Progression: 300/1000 générées
  Progression: 400/1000 générées
  Progression: 500/1000 générées
  Progression: 600/1000 générées
  Progression: 700/1000 générées
  Progression: 800/1000 générées
  Progression: 900/1000 générées
  Progression: 1000/1000 générées





📊 BENCHMARK RESULTS - Rule of 3
Metric          Value      Details                       
-------------------------------------------------------
Validity        72.6       726/1000 valid
Novelty         99.7       724/726 novel
Uniqueness      99.7       722/724 unique
Int Diversity   0.810      Tanimoto-based (0-1)
Desirability    24.4       176/722 desirable

📋 Distribution des Propriétés (molécules uniques):
   • LogP: 2.4 ± 1.2 (min: -1.1, max: 7.5)
   • MW: 339 ± 47 (min: 208, max: 455)
   • HBD: 2.3 ± 0.7 (min: 0, max: 5)
   • Molécules dans la cible: 176/722

🔬 ANALYSE CONDITION 1 - Rule of 3
   • Molécules vérifiant logP <= 3, MW <= 480, HBD <= 3: 479/722
   • Pourcentage: 66.3%
   • Exemples (premiers 5): ['CNC(=O)C1=C(CCN1)C(=O)c2cccc(c2)C(=O)NCCCN3CCCCC3', 'CCOC(=O)C1=C(C)NC(=O)c2ccc(OC)cc2C1=O', 'CC1=C(NC(=O)N2CCCC2)C(=O)NC1=O', 'Cc1ccc(cc1)c2nc(NC(=O)c3ccc(F)cc3)nc(N)n2', 'CC(C)CNC(=O)C1CCC(C1)NC(=O)NCCCN2CCCCCC2']

🎯 Condition 2: Catégories [LogP(3), MW(3), HBD(0)] (Trè




📋 Distribution des Propriétés (molécules uniques):
   • LogP: 9.4 ± 6.6 (min: 0.7, max: 21.2)
   • MW: 599 ± 133 (min: 354, max: 850)
   • HBD: 0.2 ± 0.4 (min: 0, max: 2)
   • Molécules dans la cible: 28/76

🔬 ANALYSE CONDITION 1 - Lipophile/Grand
   • Molécules vérifiant logP <= 3, MW <= 480, HBD <= 3: 9/76
   • Pourcentage: 11.8%
   • Exemples (premiers 5): ['COc1cccc(c1)C(=O)C2CC(CC2)N3CCN(CC3)C(=O)c4cccc(c4)S(=O)(=O)', 'COC(=O)C1=C(C)N(C(=O)c2ccc(C)cc2)N(C1=O)C(=O)N(CC(=O)N(C)C)C', 'CC(C)COC(=O)C(=O)c1ccc(C)c(c1)c2ccc(cc2)S(=O)(=O)N3CCN(CC3)C', 'CC(=O)N1CCC(C1)N2CCN(CC2)C(=O)c3ccc(cc3)c4ccc(cc4)S(=O)(=O)N', 'Cc1ccc(cc1)S(=O)(=O)N2CCN(CC2)c3ccc(cc3)S(=O)(=O)N4CCN(CC4)C']

📈 COMPARAISON FINALE DES CONDITIONS

Condition            Validity   Novelty    Uniqueness   IntDiv     Desirability Cond1     
-----------------------------------------------------------------------------------------------
Rule of 3            72.6       99.7       99.7         0.810      24.4         479/722 (66