DrugGPT model for molecule generation

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
SCRIPT COMPLET (A à Z) : TRANSFORMATEUR CONDITIONNEL (Catégoriel) - V2 CORRIGÉ AVEC DRIVE ET PLOTS
ET TOUS LES OBJECTIFS AJOUTÉS

Objectif : Entraîner un modèle de type GPT à générer des SMILES
          en fonction de catégories de propriétés (LogP, MW, HBD, etc.)
          avec tous les objectifs spécifiés.
"""

# --- PARTIE 1 : IMPORTS ET CONFIGURATION ---

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import math
from dataclasses import dataclass
import time
import os
import json
from tqdm import tqdm
import gc # Garbage collector
import matplotlib.pyplot as plt
import numpy as np

# Google Colab detection: the google.colab package is only importable inside a
# Colab runtime, so a failed import means we are running locally.
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
    print("‚ö† Environnement local d√©tect√©")
else:
    IN_COLAB = True
    print("‚úì Environnement Google Colab d√©tect√©")

# RDKit is mandatory: every molecular descriptor used below comes from it.
try:
    from rdkit import Chem
    from rdkit.Chem import Descriptors
    from rdkit.Chem import Lipinski
    from rdkit.Chem import Crippen
    from rdkit import rdBase
    # Silence RDKit's parse-error log; invalid SMILES are detected later by
    # Chem.MolFromSmiles returning None.
    rdBase.DisableLog('rdApp.error')
except ImportError:
    print("Erreur : RDKit n'est pas install√©.")
    print("Veuillez l'installer avec : pip install rdkit")
    # Non-zero status so shells / CI see the failure; a bare exit() exits with
    # status 0 (success) and relies on the optional `site` builtin.
    raise SystemExit(1)

# --- Configuration ---
# Training loop
BATCH_SIZE = 32
BLOCK_SIZE = 128  # maximum token-sequence length (context window)

MAX_ITERS = 5_000
EVAL_INTERVAL = 500
EVAL_ITERS = 200  # default number of batches averaged per split by estimate_loss
LEARNING_RATE = 3e-4

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Transformer size
N_EMBD = 128
N_HEAD = 4
N_LAYER = 4
DROPOUT = 0.1
# Length of the condition vector, in this order: LogP_cat, MW_cat, HBD_cat, HBA_cat,
# RotBonds_cat, AromaticRings, NonAromaticRings, HasFunctionalGroup, R_Value,
# LipinskiCompliant.
CONDITION_DIM = 10

# Persist every artefact on Google Drive under Colab, in ./local_save otherwise.
if IN_COLAB:
    drive.mount('/content/drive')
    DRIVE_PATH = '/content/drive/MyDrive/cond_gpt_model1'
    _save_dir_message = f"‚úì Dossier Drive cr√©√© : {DRIVE_PATH}"
else:
    DRIVE_PATH = './local_save'
    _save_dir_message = f"‚úì Dossier local cr√©√© : {DRIVE_PATH}"
os.makedirs(DRIVE_PATH, exist_ok=True)
print(_save_dir_message)

# Input dataset (read line by line, one SMILES per line) and derived artefacts
# stored under DRIVE_PATH.
DATA_FILE = 's_100_str_+1M_fixed.txt'
VOCAB_FILE, DATA_CACHE_FILE = (
    os.path.join(DRIVE_PATH, file_name)
    for file_name in ('vocab_dataset.json', 'data_cache_categorical_extended.pt')
)

# Training checkpoints and generated figures each get their own sub-folder.
CHECKPOINT_DIR, PLOTS_DIR = (
    os.path.join(DRIVE_PATH, sub_dir) for sub_dir in ('checkpoints', 'plots')
)
CHECKPOINT_FILE = os.path.join(CHECKPOINT_DIR, 'cond_gpt_categorical_extended.pth')
for _directory in (CHECKPOINT_DIR, PLOTS_DIR):
    os.makedirs(_directory, exist_ok=True)

print(f"Configuration : P√©riph√©rique={DEVICE}, Batch={BATCH_SIZE}, Contexte={BLOCK_SIZE}")
torch.manual_seed(1337)

# --- PART 2: VOCABULARY CONSTRUCTION ---

print(f"Construction du vocabulaire √† partir de '{DATA_FILE}'...")

# Special tokens; they come first in the vocabulary, so they get ids 0, 1 and 2.
PAD_TOKEN = '<pad>'
START_TOKEN = '<start>'
END_TOKEN = '<end>'

# Character-level vocabulary: every distinct character appearing in the dataset.
# utf-8 matches the encoding used when load_and_process_data reads the same file.
char_set = set()
with open(DATA_FILE, 'r', encoding='utf-8') as f:
    for line in f:
        smiles = line.strip()
        char_set.update(smiles)

special_tokens = [PAD_TOKEN, START_TOKEN, END_TOKEN]
vocabulary = special_tokens + sorted(char_set)

# Encoding / decoding lookup tables.
stoi = {ch: i for i, ch in enumerate(vocabulary)}
itos = {i: ch for i, ch in enumerate(vocabulary)}
vocab_size = len(vocabulary)

# JSON object keys are always strings: the integer keys of itos are written as
# "0", "1", ... and must be converted back with int() when this file is reloaded.
with open(VOCAB_FILE, 'w', encoding='utf-8') as f:
    json.dump({'stoi': stoi, 'itos': itos}, f, indent=2)

print(f"Taille vocabulaire : {vocab_size}")
print("Exemple de mapping :", list(stoi.items())[:10])


def encode(s):
    """Map a string to token ids; characters missing from ``stoi`` are silently dropped."""
    return [stoi[c] for c in s if c in stoi]


def decode(l):
    """Map token ids back to a string; ids missing from ``itos`` are silently dropped."""
    return ''.join(itos[i] for i in l if i in itos)

# Sanity check: encode/decode must round-trip a simple SMILES string.
print("\n=== TEST DU VOCABULAIRE ===")
test_smiles = "CCO"
encoded = encode(test_smiles)
decoded = decode(encoded)
roundtrip_ok = test_smiles == decoded
print(f"Test encode/decode: {test_smiles}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")
print(f"Match: {roundtrip_ok}")

# --- PARTIE 3 : PRÉPARATION DES DONNÉES ---

# Assign a category index from a list of interval upper bounds.
def get_category(value, bins):
    """Return, as a float, the index of the first bin whose upper bound is >= ``value``.

    ``bins`` lists inclusive upper bounds. A value above every bound (or one that
    compares False against all of them, such as NaN) falls into the extra overflow
    category ``len(bins)``.

    >>> get_category(2.0, [0.0, 3.0, 5.0])
    1.0
    """
    first_fit = next(
        (position for position, bound in enumerate(bins) if value <= bound),
        len(bins),
    )
    return float(first_fit)

# Detection of specific functional groups.
# SMARTS queries that set the "HasFunctionalGroup" condition flag.
_FUNCTIONAL_GROUP_SMARTS = (
    '[OH]',             # -OH (also matches the OH of a carboxylic acid)
    '[#6]C(=O)[O;H0]',  # -COOR (ester). NOTE(review): [O;H0] also matches a carboxylate O-; confirm intent
    'C(=O)[OH]',        # -COOH
    '[NH2]',            # -NH2. NOTE(review): also matches primary amide nitrogens; confirm intent
)
# Compiled on first use, then reused, instead of re-parsing every SMARTS pattern
# for each molecule processed.
_functional_group_queries = None


def has_functional_group(mol):
    """Return 1.0 if ``mol`` contains at least one of -OH, -COOR, -COOH or -NH2, else 0.0.

    Args:
        mol: an RDKit molecule.
    """
    global _functional_group_queries
    if _functional_group_queries is None:
        _functional_group_queries = [Chem.MolFromSmarts(p) for p in _FUNCTIONAL_GROUP_SMARTS]
    return 1.0 if any(mol.HasSubstructMatch(q) for q in _functional_group_queries) else 0.0

# R-value descriptor.
def calculate_r_value(mol):
    """Return the "R-value" of ``mol``: LogP divided by (molecular weight / 100).

    The original author attributes the metric to reference [13] (not included
    here); this implementation is only the simple ratio below, and the definition
    may differ between references.

    Returns 0.0 when the molecular weight is not positive, or when an RDKit
    descriptor computation raises (best-effort).
    """
    try:
        mol_wt = Descriptors.MolWt(mol)
        logp = Crippen.MolLogP(mol)
        if mol_wt > 0:
            return logp / (mol_wt / 100)
        return 0.0
    except Exception:
        # Catch ordinary errors only; a bare `except:` would also swallow
        # KeyboardInterrupt / SystemExit during long dataset passes.
        return 0.0

# Drug-likeness filter used for the "LipinskiCompliant" condition flag.
def check_lipinski_compliance(mol):
    """Return 1.0 if ``mol`` satisfies every threshold of the project's filter, else 0.0.

    Despite the name, this is not Lipinski's original rule of five, which is looser
    (LogP <= 5, MW <= 500, HBD <= 5, HBA <= 10). This filter requires all of:
    LogP <= 3, MW <= 480, HBD <= 3, HBA <= 3 and at most 3 rotatable bonds.
    All five descriptors are computed before any comparison is made.
    """
    measured = (
        Crippen.MolLogP(mol),
        Descriptors.MolWt(mol),
        Lipinski.NumHDonors(mol),
        Lipinski.NumHAcceptors(mol),
        Lipinski.NumRotatableBonds(mol),
    )
    limits = (3, 480, 3, 3, 3)
    passes_all = all(value <= limit for value, limit in zip(measured, limits))
    return 1.0 if passes_all else 0.0

# Inclusive upper bounds consumed by get_category(); a value above the last bound
# falls into the extra overflow category len(bins).
LOGP_BINS = [0.0, 3.0, 5.0]      # <=0 -> 0, <=3 -> 1, <=5 -> 2, >5 -> 3
MW_BINS = [250.0, 480.0, 650.0]  # <=250 -> 0, <=480 -> 1, <=650 -> 2, >650 -> 3
# Integer counts: exact counts 0..3 map to categories 0..3, anything above 3 to 4.
HBD_BINS = [float(n) for n in range(4)]
HBA_BINS = [float(n) for n in range(4)]
ROT_BONDS_BINS = [float(n) for n in range(4)]
# NOTE(review): load_and_process_data categorises R-values with explicit `<` / `<=`
# comparisons and does not read this list.
R_VALUE_BINS = [0.05, 0.50]  # <0.05 -> 0, 0.05-0.50 -> 1, >0.50 -> 2

def load_and_process_data(filepath, stoi, max_len=BLOCK_SIZE, cache_file=DATA_CACHE_FILE):
    """
    Load SMILES strings, compute every property category and build the training tensors.

    If ``cache_file`` already exists, its content is returned as-is. It is not
    checked against ``filepath`` or ``stoi``.

    Otherwise every line of ``filepath`` is read as one SMILES string. A line is
    skipped when any of these holds:
      * it contains a character missing from ``stoi``;
      * it is longer than ``max_len - 2`` characters (room is left for <start> and <end>);
      * RDKit cannot parse it;
      * a descriptor computation raises.
    The property distributions are then plotted and the result is saved to ``cache_file``.

    Args:
        filepath: text file with one SMILES string per line.
        stoi: character -> token id mapping; must also contain the special tokens.
        max_len: length of every padded token sequence.
        cache_file: path of the ``torch.save`` cache.

    Returns:
        list of ``(x, y, condition)`` tuples:
          * ``x`` holds ``[<start>, *smiles, <end>]`` padded with <pad> to ``max_len``.
          * ``y`` is ``x`` shifted left by one position (next-token targets, <pad>
            where there is none).
          * ``condition`` is a float32 tensor of 10 values: logp_cat, mw_cat,
            hbd_cat, hba_cat, rot_bonds_cat, aromatic_rings, non_aromatic_rings,
            functional_group, r_value_cat, lipinski_compliant.

    Raises:
        SystemExit(1): if ``filepath`` is missing or memory runs out while reading it.
    """
    if os.path.exists(cache_file):
        print(f"Chargement des donn√©es cat√©gorielles √©tendues depuis le cache '{cache_file}'...")
        data = torch.load(cache_file)
        print("Donn√©es charg√©es.")
        return data

    print("Traitement des donn√©es SMILES (calcul de TOUTES les cat√©gories)...")
    print("(Cela peut prendre du temps sur tout le dataset)")
    data_processed = []
    pad_idx = stoi[PAD_TOKEN]
    start_idx = stoi[START_TOKEN]
    end_idx = stoi[END_TOKEN]

    # Raw (uncategorised) property values, collected only for the distribution plots.
    all_logp = []
    all_mw = []
    all_hbd = []
    all_hba = []
    all_rot_bonds = []
    all_aromatic_rings = []
    all_non_aromatic_rings = []
    all_functional_groups = []
    all_r_values = []
    all_lipinski = []

    try:
        # First pass only counts lines so that tqdm can display a total.
        with open(filepath, 'r', encoding='utf-8') as f:
            num_lines = sum(1 for _ in f)

        with open(filepath, 'r', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f, total=num_lines, desc="Calcul des cat√©gories √©tendues")):
                # Periodic garbage collection on a fixed line schedule, whether or
                # not the current line turns out to be usable.
                if i % 50000 == 0:
                    gc.collect()

                smiles = line.strip()
                # Cheap filters before the expensive RDKit parse.
                if len(smiles) > max_len - 2 or not all(c in stoi for c in smiles):
                    continue
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    continue

                try:
                    logp = Descriptors.MolLogP(mol)
                    mw = Descriptors.MolWt(mol)
                    hbd = Descriptors.NumHDonors(mol)
                    hba = Descriptors.NumHAcceptors(mol)
                    rot_bonds = Descriptors.NumRotatableBonds(mol)
                    aromatic_rings = Lipinski.NumAromaticRings(mol)
                    non_aromatic_rings = Lipinski.NumAliphaticRings(mol)
                    functional_group = has_functional_group(mol)
                    r_value = calculate_r_value(mol)
                    lipinski_compliant = check_lipinski_compliance(mol)

                    logp_cat = get_category(logp, LOGP_BINS)
                    mw_cat = get_category(mw, MW_BINS)
                    hbd_cat = get_category(hbd, HBD_BINS)
                    hba_cat = get_category(hba, HBA_BINS)
                    rot_bonds_cat = get_category(rot_bonds, ROT_BONDS_BINS)

                    # Explicit thresholds: strict '<' at 0.05, whereas
                    # get_category(r_value, R_VALUE_BINS) would use '<='.
                    if r_value < 0.05:
                        r_value_cat = 0.0
                    elif r_value <= 0.50:
                        r_value_cat = 1.0
                    else:
                        r_value_cat = 2.0

                    condition_vector = torch.tensor([
                        logp_cat, mw_cat, hbd_cat, hba_cat, rot_bonds_cat,
                        float(aromatic_rings), float(non_aromatic_rings),
                        functional_group, r_value_cat, lipinski_compliant
                    ], dtype=torch.float32)
                except Exception:
                    # Best-effort: skip molecules on which an RDKit computation fails.
                    continue

                # Tokenise with the `stoi` argument (already validated above) rather
                # than the global encoder, so a custom vocabulary is honoured.
                token_ids = [start_idx] + [stoi[c] for c in smiles] + [end_idx]
                seq_len = len(token_ids)
                x = torch.full((max_len,), pad_idx, dtype=torch.long)
                y = torch.full((max_len,), pad_idx, dtype=torch.long)
                x[:seq_len] = torch.tensor(token_ids, dtype=torch.long)
                y[:seq_len - 1] = torch.tensor(token_ids[1:], dtype=torch.long)
                data_processed.append((x, y, condition_vector))

                # Recorded only for molecules that made it into the dataset.
                all_logp.append(logp)
                all_mw.append(mw)
                all_hbd.append(hbd)
                all_hba.append(hba)
                all_rot_bonds.append(rot_bonds)
                all_aromatic_rings.append(aromatic_rings)
                all_non_aromatic_rings.append(non_aromatic_rings)
                all_functional_groups.append(functional_group)
                all_r_values.append(r_value)
                all_lipinski.append(lipinski_compliant)

    except FileNotFoundError:
        print(f"ERREUR: Le fichier de donn√©es '{filepath}' est introuvable.")
        raise SystemExit(1)
    except MemoryError:
        print("\nERREUR: Manque de m√©moire pour traiter tout le dataset.")
        print("Vous devrez peut-√™tre utiliser un √©chantillon plus petit ou augmenter la RAM.")
        raise SystemExit(1)

    # Plotting is outside the try so that its own errors are not misreported as a
    # missing data file.
    create_extended_data_distribution_plots(
        all_logp, all_mw, all_hbd, all_hba, all_rot_bonds,
        all_aromatic_rings, all_non_aromatic_rings,
        all_functional_groups, all_r_values, all_lipinski
    )

    print(f"\nNombre total de mol√©cules valides charg√©es : {len(data_processed)}")
    print(f"Sauvegarde des donn√©es trait√©es dans le cache '{cache_file}'...")
    torch.save(data_processed, cache_file)
    print("Donn√©es sauvegard√©es.")

    return data_processed

def create_extended_data_distribution_plots(logp_data, mw_data, hbd_data, hba_data, rot_bonds_data,
                                          aromatic_rings_data, non_aromatic_rings_data,
                                          functional_groups_data, r_values_data, lipinski_data):
    """Draw a 3x3 grid of property distributions and save it under PLOTS_DIR.

    Eight panels are histograms. The last panel is a bar chart of compliant vs.
    non-compliant molecules, built from the 0/1 values in ``lipinski_data``.
    NOTE(review): ``functional_groups_data`` is accepted but not plotted.
    The figure is written to ``extended_data_distributions.png`` at 300 dpi.
    """
    print("Cr√©ation des plots de distribution des donn√©es √©tendues...")

    _figure, axes = plt.subplots(3, 3, figsize=(18, 15))

    # (row, col, values, number of bins, colour, x-axis label, title)
    histogram_panels = (
        (0, 0, logp_data, 50, 'skyblue', 'LogP', 'Distribution de LogP'),
        (0, 1, mw_data, 50, 'lightgreen', 'Poids Mol√©culaire (MW)', 'Distribution du Poids Mol√©culaire'),
        (0, 2, hbd_data, 20, 'lightcoral', 'Nombre de Donneurs H (HBD)', 'Distribution des Donneurs H'),
        (1, 0, hba_data, 20, 'gold', 'Nombre d\'Accepteurs H (HBA)', 'Distribution des Accepteurs H'),
        (1, 1, rot_bonds_data, 20, 'violet', 'Liaisons Rotatives', 'Distribution des Liaisons Rotatives'),
        (1, 2, aromatic_rings_data, 10, 'orange', 'Cycles Aromatiques', 'Distribution des Cycles Aromatiques'),
        (2, 0, non_aromatic_rings_data, 10, 'cyan', 'Cycles Non-Aromatiques', 'Distribution des Cycles Non-Aromatiques'),
        (2, 1, r_values_data, 50, 'brown', 'R-value', 'Distribution des R-values'),
    )
    for row, col, values, n_bins, colour, x_label, title in histogram_panels:
        ax = axes[row, col]
        ax.hist(values, bins=n_bins, alpha=0.7, color=colour, edgecolor='black')
        ax.set_xlabel(x_label)
        ax.set_ylabel('Fr√©quence')
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    n_compliant = sum(lipinski_data)
    lipinski_ax = axes[2, 2]
    lipinski_ax.bar(['Compliant', 'Non-Compliant'], [n_compliant, len(lipinski_data) - n_compliant],
                    color=['green', 'red'], alpha=0.7)
    lipinski_ax.set_ylabel('Nombre de Mol√©cules')
    lipinski_ax.set_title('Compliance Lipinski')
    lipinski_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plot_path = os.path.join(PLOTS_DIR, 'extended_data_distributions.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"‚úì Plots de distribution √©tendus sauvegard√©s dans : {plot_path}")

class SMILESDataset(Dataset):
    """Map-style dataset over a pre-built sequence of samples.

    The sequence is stored as ``self.data`` without copying; length and indexing
    delegate straight to it.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` unchanged."""
        return self.data[idx]

# --- PARTIE 4 : ARCHITECTURE DU MODÈLE (avec dimension de condition étendue) ---

@dataclass
class GPTConfig:
    """Hyperparameters of ConditionalDrugGPT.

    Defaults are read from the module-level constants when the class is defined.
    """
    block_size: int = BLOCK_SIZE  # maximum sequence length (size of the position-embedding table)
    vocab_size: int = vocab_size # Uses the global vocabulary size
    n_layer: int = N_LAYER  # number of transformer blocks
    n_head: int = N_HEAD  # attention heads; must divide n_embd
    n_embd: int = N_EMBD  # embedding / hidden width
    dropout: float = DROPOUT
    condition_dim: int = CONDITION_DIM # Length of the condition vector (now 10 dimensions)

class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention: each position attends only to itself and earlier ones."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # A single fused projection produces queries, keys and values.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular mask kept as a buffer named "bias" (that name appears in
        # the state_dict, so it must not change).
        size = config.block_size
        lower_triangle = torch.tril(torch.ones(size, size))
        self.register_buffer("bias", lower_triangle.view(1, 1, size, size))

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, n_embd) and return the same shape."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head
        query, key, value = (
            part.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
            for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        scores = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        causal = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(causal == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))


class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    The hidden width is 4 * n_embd.
    """

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width)
        self.c_proj = nn.Linear(hidden_width, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))

class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Causal self-attention is followed by an MLP, each wrapped in a residual
    connection.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))


# --- PARTIE 5 : LE TRANSFORMATEUR CONDITIONNEL (avec condition étendue) ---

class ConditionalDrugGPT(nn.Module):
    """GPT-style character model conditioned on a property vector.

    The condition vector is projected to ``n_embd`` (Linear + ReLU). It is added at
    every position to the token + position embeddings, before the transformer
    blocks. The token-embedding matrix is shared with the output projection
    (weight tying).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Registration order is kept stable: it fixes parameter order (optimizer
        # checkpoints) and the order in which random init consumes the RNG.
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))

        self.condition_projector = nn.Sequential(
            nn.Linear(config.condition_dim, config.n_embd),
            nn.ReLU()
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: input embeddings and output logits share one matrix.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, conditions=None):
        """Compute next-token logits (and the loss when ``targets`` is given).

        Args:
            idx: (B, T) tensor of token ids, with T <= config.block_size.
            targets: optional (B, T) next-token ids; <pad> positions are ignored
                in the loss.
            conditions: (B, condition_dim) or (condition_dim,) float tensor. A 1-D
                vector is applied to every sequence of the batch.

        Returns:
            ``(logits, loss)``:
              * with targets, logits has shape (B, T, vocab_size) and loss is the
                mean cross-entropy;
              * without targets, only the last position is projected, so logits
                has shape (B, 1, vocab_size) and loss is None.

        Raises:
            ValueError: if T exceeds block_size or ``conditions`` is missing.
        """
        device = idx.device
        _, T = idx.shape
        # Explicit exceptions rather than asserts, which are stripped under `python -O`.
        if T > self.config.block_size:
            raise ValueError(f"S√©quence trop longue: {T}")
        if conditions is None:
            raise ValueError("Les conditions doivent √™tre fournies !")
        if conditions.dim() == 1:
            # An unbatched (condition_dim,) vector would otherwise become
            # (n_embd, 1) after unsqueeze(1) and broadcast along the wrong axes.
            conditions = conditions.unsqueeze(0)
        pos = torch.arange(0, T, dtype=torch.long, device=device).unsqueeze(0)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        cond_emb = self.condition_projector(conditions)

        # The same condition embedding is added at every time step.
        x = self.transformer.drop(tok_emb + pos_emb + cond_emb.unsqueeze(1))

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=stoi[PAD_TOKEN])
        else:
            # Inference: only the last position is needed to sample the next token.
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss


# --- PARTIE 6 : FONCTIONS DE CHECKPOINT ET D'ÉVALUATION (avec plots) ---

def save_checkpoint(model, optimizer, iter_num, best_val_loss, config, filepath):
    """Save model/optimizer state and training progress to ``filepath``.

    The dict written with ``torch.save`` has the keys 'iter_num',
    'model_state_dict', 'optimizer_state_dict', 'best_val_loss' and 'config'.

    The data is first written to ``<filepath>.tmp`` and then moved into place
    with ``os.replace``. An interrupted save therefore never truncates an existing
    checkpoint. (``os.replace`` is atomic on local POSIX filesystems; network/FUSE
    mounts may give weaker guarantees.) The temporary file is removed if saving
    fails.
    """
    print(f"Sauvegarde du checkpoint dans {filepath}...")
    state = {
        'iter_num': iter_num,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
        'config': config,
    }
    tmp_path = f"{filepath}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, filepath)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def load_checkpoint(filepath, model, optimizer):
    """Restore ``model`` and ``optimizer`` state from ``filepath`` if it exists.

    Returns:
        ``(iter_num, best_val_loss)`` read from the checkpoint, or ``(0, inf)`` when
        there is no file (fresh training run).
    """
    if not os.path.exists(filepath):
        print("Aucun checkpoint trouv√©. D√©marrage d'un nouvel entra√Ænement.")
        return 0, float('inf')

    print(f"Chargement du checkpoint depuis {filepath}...")
    # The checkpoint also pickles a config object. The weights-only unpickler
    # (torch.load's default since PyTorch 2.6) refuses arbitrary classes, so full
    # unpickling is requested explicitly.
    # SECURITY: this executes pickled code; only load checkpoints you produced.
    try:
        checkpoint = torch.load(filepath, map_location=DEVICE, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no `weights_only` argument (and always unpickles fully).
        checkpoint = torch.load(filepath, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
    print(f"Reprise de l'entra√Ænement √† l'it√©ration {iter_num} (meilleure perte val: {best_val_loss:.4f})")
    return iter_num, best_val_loss

@torch.no_grad()
def estimate_loss(model, train_loader, val_loader, eval_iters=EVAL_ITERS):
    """Average ``model``'s loss over ``eval_iters`` batches of each split.

    The model is put in eval mode for the measurement and switched back to train
    mode afterwards. A loader that runs out of batches is restarted with a fresh
    iterator, so ``eval_iters`` may exceed its length.

    Returns:
        dict mapping 'train' and 'val' to the mean loss (a 0-d tensor).
    """
    model.eval()
    averages = {}
    for split, loader in (('train', train_loader), ('val', val_loader)):
        split_losses = torch.zeros(eval_iters)
        batches = iter(loader)
        for step in range(eval_iters):
            try:
                batch = next(batches)
            except StopIteration:
                batches = iter(loader)
                batch = next(batches)
            inputs, targets, conditions = batch
            _, loss = model(inputs.to(DEVICE), targets.to(DEVICE), conditions.to(DEVICE))
            split_losses[step] = loss.item()
        averages[split] = split_losses.mean()
    model.train()
    return averages

# Loss history shared by successive plot_training_progress calls (held in memory only).
train_losses, val_losses, iterations = [], [], []


def plot_training_progress(iter_num, train_loss, val_loss):
    """Record one evaluation point and save the whole loss curve so far.

    Appends to the module-level ``train_losses`` / ``val_losses`` / ``iterations``
    lists. It then writes ``training_progress_iter_<iter_num>.png`` into PLOTS_DIR,
    so each call produces a new file.
    """
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    iterations.append(iter_num)

    curves = (
        (train_losses, 'b-', 'Perte d\'entra√Ænement'),
        (val_losses, 'r-', 'Perte de validation'),
    )
    plt.figure(figsize=(10, 6))
    for history, line_format, curve_label in curves:
        plt.plot(iterations, history, line_format, label=curve_label, alpha=0.7)
    plt.xlabel('It√©rations')
    plt.ylabel('Perte')
    plt.title('Progression de l\'Entra√Ænement')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plot_path = os.path.join(PLOTS_DIR, f'training_progress_iter_{iter_num}.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"‚úì Plot d'entra√Ænement sauvegard√© : {plot_path}")

# --- Conditional generation (extended 10-dimensional conditions) ---
@torch.no_grad()
def generate_conditional(model, condition_tensor, stoi, itos, max_new_tokens=100, temperature=1.0, top_k=None):
    """
    Sample one SMILES string from ``model`` for a categorical condition vector.

    Sampling starts from <start> and runs autoregressively until <end> is drawn or
    ``max_new_tokens`` tokens have been sampled. <start> is excluded from the
    result and <end> is never appended. Other special tokens sampled mid-sequence
    (e.g. <pad>) are decoded verbatim.

    Args:
        model: a ConditionalDrugGPT (any module with the same forward signature).
            Its ``config.block_size``, when present, bounds the context.
        condition_tensor: (1, condition_dim) or (condition_dim,) float tensor.
            The batch size must be 1.
        stoi: token -> id mapping (used for <start> / <end>).
        itos: id -> token mapping used to decode the sampled ids.
        max_new_tokens: maximum number of sampled tokens.
        temperature: logits are divided by this value before sampling (must be > 0).
        top_k: if set, sampling is restricted to the ``top_k`` most likely tokens.

    Returns:
        The decoded string. It may be empty and is not guaranteed to be valid SMILES.

    The model's previous train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    start_idx = stoi[START_TOKEN]
    end_idx = stoi[END_TOKEN]

    config = getattr(model, 'config', None)
    block_size = config.block_size if config is not None else BLOCK_SIZE

    idx = torch.tensor([[start_idx]], dtype=torch.long, device=DEVICE)
    condition_tensor = condition_tensor.to(DEVICE)
    if condition_tensor.dim() == 1:
        # Add the batch axis so the condition broadcasts over time steps only.
        condition_tensor = condition_tensor.unsqueeze(0)

    try:
        for _ in range(max_new_tokens):
            # Crop to the model's context window (a no-op while idx is short enough).
            idx_cond = idx[:, -block_size:]

            logits, _ = model(idx_cond, conditions=condition_tensor)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            if idx_next.item() == end_idx:
                break

            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        model.train(was_training)

    # idx[0, 0] is <start>; <end> is never appended because the loop breaks first.
    generated_tokens = idx[0, 1:].tolist()
    return ''.join(itos[i] for i in generated_tokens if i in itos)

# --- Verification of all extended properties of a generated molecule ---
def check_mol_all_props(smiles):
    """Compute every real property of ``smiles``.

    Returns an 11-tuple:
    (status, logp, mw, hbd, hba, rot_bonds, aromatic_rings, non_aromatic_rings,
     functional_group, r_value, lipinski_compliant)

    ``status`` takes one of three values:
      * "Vide": the input is empty or falsy;
      * "Invalide": RDKit cannot parse the input;
      * "Valide": the properties were computed.
    For "Vide" and "Invalide", every numeric field is a zero placeholder.
    """
    def _placeholder(status):
        return (status, 0.0, 0.0, 0, 0, 0, 0, 0, 0, 0.0, 0)

    if not smiles:
        return _placeholder("Vide")

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return _placeholder("Invalide")

    return (
        "Valide",
        Descriptors.MolLogP(mol),
        Descriptors.MolWt(mol),
        Descriptors.NumHDonors(mol),
        Descriptors.NumHAcceptors(mol),
        Descriptors.NumRotatableBonds(mol),
        Lipinski.NumAromaticRings(mol),
        Lipinski.NumAliphaticRings(mol),
        has_functional_group(mol),
        calculate_r_value(mol),
        check_lipinski_compliance(mol),
    )

def plot_extended_generation_results(condition_results, condition_name):
    """Save a 3x3 figure summarising the valid molecules of one generation run.

    Eight panels are histograms of per-molecule properties (with red dashed
    reference lines where a threshold applies). The last panel is a bar chart
    of Lipinski-compliant vs non-compliant counts. The PNG is written to
    PLOTS_DIR, with spaces in `condition_name` replaced by underscores in the
    filename. If no result has status 'Valide', only a message is printed.
    """
    valid_molecules = [entry for entry in condition_results if entry['status'] == 'Valide']

    if not valid_molecules:
        print(f"Aucune mol√©cule valide g√©n√©r√©e pour {condition_name}")
        return

    y_label = 'Nombre de Mol√©cules'
    # (grid position, result key, bins, colour, x label, title prefix, [(x, legend label), ...])
    histogram_panels = [
        ((0, 0), 'logp', 20, 'skyblue', 'LogP', 'Distribution LogP',
         [(3, 'LogP ‚â§ 3')]),
        ((0, 1), 'mw', 20, 'lightgreen', 'Poids Mol√©culaire', 'Distribution MW',
         [(480, 'MW ‚â§ 480')]),
        ((0, 2), 'hbd', 10, 'lightcoral', 'HBD', 'Distribution HBD',
         [(3, 'HBD ‚â§ 3')]),
        ((1, 0), 'hba', 10, 'gold', 'HBA', 'Distribution HBA',
         [(3, 'HBA ‚â§ 3')]),
        ((1, 1), 'rot_bonds', 10, 'violet', 'Liaisons Rotatives', 'Distribution Liaisons Rotatives',
         [(3, 'RotBonds ‚â§ 3')]),
        ((1, 2), 'aromatic_rings', 10, 'orange', 'Cycles Aromatiques', 'Cycles Aromatiques',
         []),
        ((2, 0), 'non_aromatic_rings', 10, 'cyan', 'Cycles Non-Aromatiques', 'Cycles Non-Aromatiques',
         []),
        ((2, 1), 'r_value', 20, 'brown', 'R-value', 'Distribution R-value',
         [(0.05, 'R ‚â• 0.05'), (0.50, 'R ‚â§ 0.50')]),
    ]

    _, axes = plt.subplots(3, 3, figsize=(18, 15))

    for grid_pos, key, n_bins, colour, x_label, title_prefix, reference_lines in histogram_panels:
        panel = axes[grid_pos]
        panel.hist([entry[key] for entry in valid_molecules], bins=n_bins,
                   alpha=0.7, color=colour, edgecolor='black')
        panel.set_xlabel(x_label)
        panel.set_ylabel(y_label)
        panel.set_title(f'{title_prefix} - {condition_name}')
        for x_pos, line_label in reference_lines:
            panel.axvline(x=x_pos, color='red', linestyle='--', label=line_label)
        # A legend only makes sense when at least one labelled line was drawn.
        if reference_lines:
            panel.legend()
        panel.grid(True, alpha=0.3)

    # Final panel: Lipinski compliance counts.
    n_compliant = sum(1 for entry in valid_molecules if entry['lipinski_compliant'] == 1)
    n_non_compliant = sum(1 for entry in valid_molecules if entry['lipinski_compliant'] == 0)
    lipinski_panel = axes[2, 2]
    lipinski_panel.bar(['Compliant', 'Non-Compliant'], [n_compliant, n_non_compliant],
                       color=['green', 'red'], alpha=0.7)
    lipinski_panel.set_ylabel(y_label)
    lipinski_panel.set_title(f'Compliance Lipinski - {condition_name}')
    lipinski_panel.grid(True, alpha=0.3)

    plt.tight_layout()
    file_stem = condition_name.replace(" ", "_")
    plot_path = os.path.join(PLOTS_DIR, f'extended_generation_{file_stem}.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"‚úì Plot de g√©n√©ration √©tendu sauvegard√© : {plot_path}")

# --- PARTIE 7 : SCRIPT PRINCIPAL D'EX√âCUTION ---

if __name__ == "__main__":

    # Number of molecules sampled for each generation objective.
    N_SAMPLES = 120

    def _condition_tensor(target_cats):
        """Return the category list `target_cats` as a (1, len) float32 tensor on DEVICE.

        Creating it on DEVICE keeps it on the same device as the model;
        `.to(DEVICE)` is a no-op when it is already there.
        """
        return torch.tensor(target_cats, dtype=torch.float32).unsqueeze(0).to(DEVICE)

    def _score_smiles(smiles):
        """Run check_mol_all_props on `smiles` and return the values as a result dict."""
        (status, logp, mw, hbd, hba, rot_bonds, aromatic_rings,
         non_aromatic_rings, functional_group, r_value, lipinski) = check_mol_all_props(smiles)
        return {
            'smiles': smiles,
            'status': status,
            'logp': logp,
            'mw': mw,
            'hbd': hbd,
            'hba': hba,
            'rot_bonds': rot_bonds,
            'aromatic_rings': aromatic_rings,
            'non_aromatic_rings': non_aromatic_rings,
            'functional_group': functional_group,
            'r_value': r_value,
            'lipinski_compliant': lipinski,
        }

    def _print_objective_header(title, details=()):
        """Print an objective title (and optional detail lines) between '=' rules."""
        print(f"\n{'='*60}")
        print(title)
        for detail in details:
            print(detail)
        print(f"{'='*60}")

    def _sample_objective(condition_tensor, describe, n_samples=N_SAMPLES, n_shown=5):
        """Generate `n_samples` SMILES under one condition and score each of them.

        The first `n_shown` samples are printed. `describe(result)` supplies the
        text shown in parentheses on the second line.
        Returns (results, valid_count).
        """
        results = []
        for i in range(n_samples):
            mol_str = generate_conditional(model, condition_tensor, stoi, itos, max_new_tokens=50, top_k=10)
            result = _score_smiles(mol_str)
            results.append(result)

            if i < n_shown:
                print(f"  {i+1}. -> '{mol_str}'")
                print(f"     ({describe(result)})")

        valid_count = sum(1 for r in results if r['status'] == 'Valide')
        return results, valid_count

    # 1. Vocabulary: stoi, itos and vocab_size are module-level globals.

    # 2. Data (loads the cached categorical data or processes DATA_FILE).
    full_data = load_and_process_data(DATA_FILE, stoi, cache_file=DATA_CACHE_FILE)

    # 3. DataLoaders: random 90/10 train/validation split.
    train_size = int(0.9 * len(full_data))
    val_size = len(full_data) - train_size
    train_data, val_data = torch.utils.data.random_split(full_data, [train_size, val_size])

    train_loader = DataLoader(SMILESDataset(train_data), batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(SMILESDataset(val_data), batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

    # 4. Model and optimiser.
    config = GPTConfig(vocab_size=vocab_size, condition_dim=CONDITION_DIM)
    model = ConditionalDrugGPT(config)
    model.to(DEVICE)

    print(f"Nombre de param√®tres : {sum(p.numel() for p in model.parameters())/1e6:.2f} M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # 5. Resume from checkpoint (yields the starting iteration and best validation loss).
    start_iter, best_val_loss = load_checkpoint(CHECKPOINT_FILE, model, optimizer)

    # 6. Training loop.
    print(f"D√©but de l'entra√Ænement sur {DEVICE}...")
    start_time = time.time()
    train_iter = iter(train_loader)

    for iter_num in range(start_iter, MAX_ITERS):

        # Periodic evaluation (and on the final iteration). A checkpoint is only
        # saved when the validation loss improves.
        if iter_num > 0 and (iter_num % EVAL_INTERVAL == 0 or iter_num == MAX_ITERS - 1):
            losses = estimate_loss(model, train_loader, val_loader)
            elapsed = time.time() - start_time
            print(f"√âtape {iter_num}: perte train {losses['train']:.4f}, perte val {losses['val']:.4f}, temps {elapsed:.1f}s")

            # Update the training-progress plot.
            plot_training_progress(iter_num, losses['train'], losses['val'])

            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                save_checkpoint(model, optimizer, iter_num, best_val_loss, config, CHECKPOINT_FILE)

        # Cycle through the training set: restart the iterator once it is exhausted.
        try:
            xb, yb, cb = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            xb, yb, cb = next(train_iter)

        xb, yb, cb = xb.to(DEVICE), yb.to(DEVICE), cb.to(DEVICE)

        _, loss = model(xb, targets=yb, conditions=cb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Periodic garbage collection during training.
        if iter_num % 1000 == 0:
            gc.collect()

    print("Entra√Ænement termin√© !")

    # 7. Generation for every objective.
    print("\n--- G√©n√©ration de Mol√©cules avec TOUS les Objectifs ---")

    # Single-sample sanity check before the sampling loops.
    print("\n=== TEST DE G√âN√âRATION SIMPLE ===")

    # Category vector order: LogP, MW, HBD, HBA, RotBonds, aromatic rings,
    # non-aromatic rings, functional group, R-value, Lipinski.
    target_cats_test = [
        1.0,  # LogP <= 3 (category 1)
        1.0,  # MW <= 480 (category 1)
        3.0,  # HBD <= 3 (category 3)
        3.0,  # HBA <= 3 (category 3)
        3.0,  # RotBonds <= 3 (category 3)
        2.0,  # 2 aromatic rings
        1.0,  # 1 non-aromatic ring
        1.0,  # has a functional group
        1.0,  # R-value in [0.05, 0.50]
        1.0   # Lipinski compliant
    ]
    condition_tensor_test = _condition_tensor(target_cats_test)

    test_smiles = generate_conditional(model, condition_tensor_test, stoi, itos, max_new_tokens=50, top_k=10)
    test_result = _score_smiles(test_smiles)
    print(f"Test g√©n√©ration: '{test_smiles}'")
    print(f"  Statut: {test_result['status']}")
    print(f"  LogP: {test_result['logp']:.1f}, MW: {test_result['mw']:.0f}, HBD: {test_result['hbd']}, HBA: {test_result['hba']}")
    print(f"  RotBonds: {test_result['rot_bonds']}, AromaticRings: {test_result['aromatic_rings']}, NonAromaticRings: {test_result['non_aromatic_rings']}")
    print(f"  FunctionalGroup: {test_result['functional_group']}, R-value: {test_result['r_value']:.3f}, Lipinski: {test_result['lipinski_compliant']}")

    # Objective 1: LogP <= 3 only.
    _print_objective_header("OBJECTIF 1: LogP ‚â§ 3 (Rule of Three - composante unique)")

    target_cats_1 = [
        1.0,  # LogP <= 3 (category 1)
        0.0,  # MW: any
        0.0,  # HBD: any
        0.0,  # HBA: any
        0.0,  # RotBonds: any
        0.0,  # aromatic rings: any
        0.0,  # non-aromatic rings: any
        0.0,  # functional group: any
        0.0,  # R-value: any
        0.0   # Lipinski: any
    ]
    condition_tensor_1 = _condition_tensor(target_cats_1)

    print("G√©n√©ration pour LogP ‚â§ 3 uniquement")
    condition_1_results, valid_count = _sample_objective(
        condition_tensor_1,
        lambda r: f"Valide: {r['status']}, LogP: {r['logp']:.1f}",
    )

    print(f"Mol√©cules valides g√©n√©r√©es: {valid_count}/{N_SAMPLES}")
    print(f"Parmi les valides, LogP ‚â§ 3: {sum(1 for r in condition_1_results if r['status'] == 'Valide' and r['logp'] <= 3)}")

    plot_extended_generation_results(condition_1_results, "Objectif 1 - LogP ‚â§ 3")

    # Objective 2: structural objectives [13].
    _print_objective_header(
        "OBJECTIF 2: Multiples objectifs structuraux [13]",
        [
            "  (i) 2 cycles aromatiques, 1 cycle non-aromatique",
            "  (ii) Au moins un groupe fonctionnel: -OH, -COOR, -COOH, ou -NH2",
            "  (iii) R-value dans [0.05‚Äì0.50]",
        ],
    )

    target_cats_2 = [
        0.0,  # LogP: any
        0.0,  # MW: any
        0.0,  # HBD: any
        0.0,  # HBA: any
        0.0,  # RotBonds: any
        2.0,  # 2 aromatic rings
        1.0,  # 1 non-aromatic ring
        1.0,  # has a functional group
        1.0,  # R-value in [0.05, 0.50]
        0.0   # Lipinski: any
    ]
    condition_tensor_2 = _condition_tensor(target_cats_2)

    print("G√©n√©ration pour objectifs structuraux [13]")
    condition_2_results, valid_count = _sample_objective(
        condition_tensor_2,
        lambda r: (f"Valide: {r['status']}, Aromatic: {r['aromatic_rings']}, NonAromatic: {r['non_aromatic_rings']}, "
                   f"Functional: {r['functional_group']}, R-value: {r['r_value']:.3f}"),
    )

    # Always report the valid count, even when it is zero.
    print(f"Mol√©cules valides g√©n√©r√©es: {valid_count}/{N_SAMPLES}")
    valid_mols_2 = [r for r in condition_2_results if r['status'] == 'Valide']
    if valid_mols_2:
        target_aromatic = sum(1 for r in valid_mols_2 if r['aromatic_rings'] == 2)
        target_non_aromatic = sum(1 for r in valid_mols_2 if r['non_aromatic_rings'] == 1)
        target_functional = sum(1 for r in valid_mols_2 if r['functional_group'] == 1)
        target_r_value = sum(1 for r in valid_mols_2 if 0.05 <= r['r_value'] <= 0.50)

        print(f"Cycles aromatiques = 2: {target_aromatic}/{valid_count}")
        print(f"Cycles non-aromatiques = 1: {target_non_aromatic}/{valid_count}")
        print(f"Avec groupe fonctionnel: {target_functional}/{valid_count}")
        print(f"R-value dans [0.05-0.50]: {target_r_value}/{valid_count}")

    plot_extended_generation_results(condition_2_results, "Objectif 2 - Structuraux [13]")

    # Objective 3: full rule set (LogP, MW, HBD, HBA, rotatable bonds).
    _print_objective_header(
        "OBJECTIF 3: R√®gle compl√®te de Lipinski (Rule of Three)",
        [
            "  - LogP ‚â§ 3",
            "  - MW ‚â§ 480 g/mol",
            "  - HBD ‚â§ 3",
            "  - HBA ‚â§ 3",
            "  - Liaisons rotatives ‚â§ 3",
        ],
    )

    target_cats_3 = [
        1.0,  # LogP <= 3
        1.0,  # MW <= 480
        3.0,  # HBD <= 3
        3.0,  # HBA <= 3
        3.0,  # RotBonds <= 3
        0.0,  # aromatic rings: any
        0.0,  # non-aromatic rings: any
        0.0,  # functional group: any
        0.0,  # R-value: any
        1.0   # Lipinski compliant
    ]
    condition_tensor_3 = _condition_tensor(target_cats_3)

    print("G√©n√©ration pour r√®gle compl√®te de Lipinski")
    condition_3_results, valid_count = _sample_objective(
        condition_tensor_3,
        lambda r: (f"Valide: {r['status']}, LogP: {r['logp']:.1f}, MW: {r['mw']:.0f}, "
                   f"HBD: {r['hbd']}, HBA: {r['hba']}, RotBonds: {r['rot_bonds']}"),
    )

    print(f"Mol√©cules valides g√©n√©r√©es: {valid_count}/{N_SAMPLES}")
    valid_mols_3 = [r for r in condition_3_results if r['status'] == 'Valide']
    if valid_mols_3:
        n_lipinski_compliant = sum(1 for r in valid_mols_3 if r['lipinski_compliant'] == 1)
        print(f"Compliantes Lipinski: {n_lipinski_compliant}/{valid_count}")

    plot_extended_generation_results(condition_3_results, "Objectif 3 - Lipinski Complet")

    # Objective 4: every objective at once.
    _print_objective_header(
        "OBJECTIF COMBIN√â: Tous les objectifs simultan√©ment",
        [
            "  - LogP ‚â§ 3",
            "  - MW ‚â§ 480 g/mol",
            "  - HBD ‚â§ 3, HBA ‚â§ 3, RotBonds ‚â§ 3",
            "  - 2 cycles aromatiques, 1 cycle non-aromatique",
            "  - Au moins un groupe fonctionnel",
            "  - R-value dans [0.05‚Äì0.50]",
        ],
    )

    target_cats_4 = [
        1.0,  # LogP <= 3
        1.0,  # MW <= 480
        3.0,  # HBD <= 3
        3.0,  # HBA <= 3
        3.0,  # RotBonds <= 3
        2.0,  # 2 aromatic rings
        1.0,  # 1 non-aromatic ring
        1.0,  # has a functional group
        1.0,  # R-value in [0.05, 0.50]
        1.0   # Lipinski compliant
    ]
    condition_tensor_4 = _condition_tensor(target_cats_4)

    print("G√©n√©ration pour TOUS les objectifs combin√©s")
    condition_4_results, valid_count = _sample_objective(
        condition_tensor_4,
        lambda r: f"Valide: {r['status']}",
    )

    print(f"Mol√©cules valides g√©n√©r√©es: {valid_count}/{N_SAMPLES}")
    valid_mols_4 = [r for r in condition_4_results if r['status'] == 'Valide']
    if valid_mols_4:
        all_targets_met = sum(1 for r in valid_mols_4 if (
            r['logp'] <= 3 and
            r['mw'] <= 480 and
            r['hbd'] <= 3 and
            r['hba'] <= 3 and
            r['rot_bonds'] <= 3 and
            r['aromatic_rings'] == 2 and
            r['non_aromatic_rings'] == 1 and
            r['functional_group'] == 1 and
            0.05 <= r['r_value'] <= 0.50
        ))
        print(f"Tous les objectifs atteints: {all_targets_met}/{valid_count}")

    plot_extended_generation_results(condition_4_results, "Objectif Combin√© - Tous Crit√®res")

    # Final summary across all objectives.
    print(f"\n{'='*80}")
    print("R√âSUM√â FINAL DES R√âSULTATS")
    print(f"{'='*80}")

    all_results = [
        ("Objectif 1 - LogP ‚â§ 3", condition_1_results),
        ("Objectif 2 - Structuraux [13]", condition_2_results),
        ("Objectif 3 - Lipinski Complet", condition_3_results),
        ("Objectif Combin√© - Tous Crit√®res", condition_4_results)
    ]

    for name, results in all_results:
        valid_mols = [r for r in results if r['status'] == 'Valide']
        total = len(results)
        valid_count = len(valid_mols)

        print(f"\n{name}:")
        print(f"  - Mol√©cules valides: {valid_count}/{total} ({valid_count/total*100:.1f}%)")

        if valid_mols:
            avg_logp = sum(r['logp'] for r in valid_mols) / len(valid_mols)
            avg_mw = sum(r['mw'] for r in valid_mols) / len(valid_mols)
            print(f"  - LogP moyen: {avg_logp:.2f}")
            print(f"  - MW moyen: {avg_mw:.1f}")
            print(f"  - Compliance Lipinski: {sum(r['lipinski_compliant'] for r in valid_mols)}/{len(valid_mols)}")

    print(f"\n‚úì Tous les fichiers sauvegard√©s dans : {DRIVE_PATH}")
    print("‚úì Entra√Ænement et g√©n√©ration termin√©s avec succ√®s !")
    print("‚úì Tous les objectifs ont √©t√© int√©gr√©s et test√©s !")

Temperature sampling

    condition_vector = CONDITIONS[4]["get_vector"](None)

Here `4` selects Condition 4 (all objectives combined). To generate molecules for another objective, replace `4` with that objective's number (e.g. `1` for the first objective), provided a matching entry exists in `CONDITIONS`.

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
SCRIPT : ANALYSE DE SENSIBILITÉ À LA TEMPÉRATURE - CONDITION 4
Trouve la température optimale pour l'équilibre validité/nouveauté/unicité/IntDiv avec la Condition 4
"""

import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from dataclasses import dataclass
import os
import json
from tqdm import tqdm
import numpy as np
from collections import Counter
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem, DataStructs, Crippen, Lipinski
from rdkit import rdBase
rdBase.DisableLog('rdApp.error')
import matplotlib.pyplot as plt

# --- CONFIGURATION ---
DRIVE_PATH = '/content/drive/MyDrive/cond_gpt_model1'
# Every artefact lives under DRIVE_PATH, so derive all paths from it.
VOCAB_FILE = os.path.join(DRIVE_PATH, 'vocab_dataset.json')
CHECKPOINT_FILE = os.path.join(DRIVE_PATH, 'checkpoints', 'cond_gpt_categorical_extended.pth')
DATA_FILE = os.path.join(DRIVE_PATH, 's_100_str_+1M_fixed.txt')
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# --- FONCTIONS IDENTIQUES ---
# SMARTS for the functional groups that count as "present", in this order:
# hydroxyl [OH]; C(=O)O whose single-bonded O carries no H (esters, but note
# this also matches e.g. carboxylate anions); carboxylic acid; primary amine.
# [OH] already matches the acid OH, so the third query is redundant, but it is kept so the
# labels stay identical to the ones used at training time.
_FUNCTIONAL_GROUP_SMARTS = ('[OH]', '[#6]C(=O)[O;H0]', 'C(=O)[OH]', '[NH2]')
# Compiled query molecules, filled on first use. This is called once per
# generated molecule, so the SMARTS are parsed once instead of on every call.
_FUNCTIONAL_GROUP_QUERIES = []


def has_functional_group(mol):
    """Return 1.0 if `mol` matches any of the functional-group SMARTS, else 0.0.

    Args:
        mol: an RDKit Mol (must not be None).

    Returns:
        1.0 or 0.0 (a float flag, as stored in the condition vectors).
    """
    if not _FUNCTIONAL_GROUP_QUERIES:
        _FUNCTIONAL_GROUP_QUERIES.extend(Chem.MolFromSmarts(smarts) for smarts in _FUNCTIONAL_GROUP_SMARTS)
    for query in _FUNCTIONAL_GROUP_QUERIES:
        if mol.HasSubstructMatch(query):
            return 1.0
    return 0.0

def calculate_r_value(mol):
    """Return the R-value of `mol`: Crippen logP divided by (molecular weight / 100).

    Returns 0.0 when the molecular weight is not positive, or when RDKit raises
    while computing the descriptors (e.g. on an unusable molecule).
    Only ordinary exceptions are swallowed; KeyboardInterrupt and SystemExit
    propagate (the previous bare ``except:`` swallowed them too).
    """
    try:
        mol_wt = Descriptors.MolWt(mol)
        logp = Crippen.MolLogP(mol)
    except Exception:
        # Best effort: a molecule RDKit cannot describe gets a neutral score.
        return 0.0
    if mol_wt > 0:
        return logp / (mol_wt / 100)
    return 0.0

# --- D√âFINITION DES CONDITIONS ---
# Category vector for Condition 4, ordered LogP, MW, HBD, HBA, RotBonds,
# AromaticRings, NonAromaticRings, HasFunctionalGroup, R_Value, LipinskiCompliant.
_CONDITION_4_VECTOR = (1.0, 1.0, 3.0, 3.0, 3.0, 2.0, 1.0, 1.0, 1.0, 1.0)

CONDITIONS = {
    4: dict(
        name="Condition 4: Structural + Lipinski",
        description="Combination of conditions 2 and 3",
        # The molecule argument is ignored. A fresh list is built on every call,
        # so callers may mutate the returned vector safely.
        get_vector=lambda mol: list(_CONDITION_4_VECTOR),
    ),
}

# --- ARCHITECTURE MOD√àLE ---
@dataclass
class GPTConfig:
    """Hyper-parameters of ConditionalDrugGPT.

    Field order is part of the positional constructor, so it must not change.
    The values must match the checkpoint being loaded.
    """
    block_size: int = 128      # maximum sequence length (size of the causal mask)
    vocab_size: int = 57       # NOTE(review): must equal the saved vocabulary size — confirm against VOCAB_FILE
    n_layer: int = 4           # number of transformer blocks
    n_head: int = 4            # attention heads per block (must divide n_embd)
    n_embd: int = 128          # embedding / hidden width
    dropout: float = 0.1       # dropout probability used throughout the model
    condition_dim: int = 10    # length of the property-category condition vector

class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention: position t attends only to positions <= t.

    The attribute names c_attn, c_proj and the buffer "bias" are state_dict keys
    of saved checkpoints and must stay unchanged.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)  # fused Q, K, V projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)      # output projection
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # (1, 1, block_size, block_size) lower-triangular mask; 1 marks an allowed position.
        size = config.block_size
        causal_mask = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
        self.register_buffer("bias", causal_mask)

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q_proj, k_proj, v_proj = self.c_attn(x).split(self.n_embd, dim=2)
        query, key, value = to_heads(q_proj), to_heads(k_proj), to_heads(v_proj)

        scale = 1.0 / math.sqrt(head_dim)
        scores = (query @ key.transpose(-2, -1)) * scale
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # Merge heads back: (B, n_head, T, head_dim) -> (B, T, C)
        context = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(context))

class MLP(nn.Module):
    """Position-wise feed-forward layer: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width)
        self.c_proj = nn.Linear(hidden_width, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        activated = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(activated))

class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual attention, then residual MLP."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))

class ConditionalDrugGPT(nn.Module):
    """GPT-style SMILES language model conditioned on a property-category vector.

    The condition vector is projected (Linear + ReLU) to the embedding width and
    added to every position's token + position embedding before the transformer
    blocks. Submodule names form the state_dict keys of saved checkpoints and
    must stay unchanged.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.condition_projector = nn.Sequential(
            nn.Linear(config.condition_dim, config.n_embd),
            nn.ReLU()
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output head share one matrix.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, conditions=None):
        """Run the model on token ids `idx` under `conditions`.

        Args:
            idx: (B, T) token ids with T <= config.block_size.
            targets: optional (B, T) token ids. When given, logits for every
                position and the mean cross-entropy loss are returned.
            conditions: 2-D tensor (B or 1, condition_dim). Required. It is moved
                to idx's device and the embedding dtype, so e.g. a CPU or float64
                tensor is accepted.

        Returns:
            (logits, loss): with targets, logits are (B, T, vocab_size) and loss
            is a scalar. Without targets, logits are (B, 1, vocab_size) for the
            last position only and loss is None.

        Raises:
            ValueError: if T exceeds block_size or `conditions` is None. Explicit
                raises are used because `assert` is stripped under `python -O`.
        """
        device = idx.device
        B, T = idx.shape
        if T > self.config.block_size:
            raise ValueError(f"S√©quence trop longue: {T}")
        if conditions is None:
            raise ValueError("Les conditions doivent √™tre fournies !")

        pos = torch.arange(0, T, dtype=torch.long, device=device).unsqueeze(0)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        # Match the device/dtype of the embeddings; a no-op when they already match.
        conditions = conditions.to(device=device, dtype=tok_emb.dtype)
        cond_emb = self.condition_projector(conditions)
        # (B, n_embd) -> (B, 1, n_embd): the same condition is added at every position.
        x = self.transformer.drop(tok_emb + pos_emb + cond_emb.unsqueeze(1))
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            # Inference: only the last position's logits are needed for sampling.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss

# --- FONCTIONS POUR CALCULER LES M√âTRIQUES ---
def calculate_fingerprint(mol):
    """Return the 1024-bit Morgan fingerprint (radius 2) of `mol`, or None on failure.

    Only ordinary exceptions are treated as "no fingerprint". KeyboardInterrupt
    and SystemExit propagate (the previous bare ``except:`` swallowed them).
    """
    try:
        return AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
    except Exception:
        return None

def calculate_intdiv(smiles_list):
    """Return the internal diversity (1 - mean pairwise Tanimoto similarity) of `smiles_list`.

    Unparseable SMILES, and molecules whose fingerprint cannot be computed, are
    skipped. Returns 0.0 when fewer than two fingerprints are available.
    """
    if len(smiles_list) < 2:
        return 0.0

    fingerprints = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            fp = calculate_fingerprint(mol)
            if fp is not None:
                fingerprints.append(fp)

    if len(fingerprints) < 2:
        return 0.0

    # BulkTanimotoSimilarity compares one fingerprint against a list in native
    # code instead of one Python call per pair (O(n^2) calls for n molecules).
    # Pairs are visited in the same i < j order as a nested loop would use.
    similarities = []
    for i, fp in enumerate(fingerprints[:-1]):
        similarities.extend(DataStructs.BulkTanimotoSimilarity(fp, fingerprints[i + 1:]))

    return 1 - np.mean(similarities)

# Canonical reference SMILES, keyed by full file path. Reading and RDKit-parsing
# a large reference file is the expensive part of the novelty computation, so
# it is done once per path and reused by later calls. The cache is not
# invalidated if the file changes on disk during the session.
_REFERENCE_SMILES_CACHE = {}


def _load_reference_smiles(full_reference_path, display_name):
    """Return the set of canonical SMILES in `full_reference_path`, reading the file at most once."""
    cached = _REFERENCE_SMILES_CACHE.get(full_reference_path)
    if cached is not None:
        return cached

    reference_smiles = set()
    valid_reference_count = 0

    print(f" Chargement du fichier de r√©f√©rence: {display_name}")
    with open(full_reference_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Keep only lines RDKit can parse; store the canonical form.
            mol = Chem.MolFromSmiles(line)
            if mol:
                reference_smiles.add(Chem.MolToSmiles(mol))
                valid_reference_count += 1

    print(f"‚úì SMILES de r√©f√©rence charg√©s: {len(reference_smiles)} (valides: {valid_reference_count})")
    _REFERENCE_SMILES_CACHE[full_reference_path] = reference_smiles
    return reference_smiles


def calculate_novelty(generated_smiles, reference_smiles_file):
    """Return (novelty_percent, novel_count) of `generated_smiles` versus a reference file.

    `reference_smiles_file` is joined onto DRIVE_PATH; an absolute path is used
    as-is by os.path.join. Novelty is the percentage of valid generated SMILES
    whose canonical form is absent from the reference set.

    Special cases:
      - reference file missing, or it contains no valid SMILES:
        (100.0, len(generated_smiles)) — note this counts every generated
        string, valid or not;
      - no valid generated SMILES: (0.0, 0);
      - any exception: the error is printed and (0.0, 0) is returned.
    """
    try:
        full_reference_path = os.path.join(DRIVE_PATH, reference_smiles_file)
        if not os.path.exists(full_reference_path):
            print(f" Fichier de r√©f√©rence non trouv√©: {full_reference_path}")
            print(f" Recherch√© dans: {DRIVE_PATH}")
            return 100.0, len(generated_smiles)  # no reference: everything counts as new

        reference_smiles = _load_reference_smiles(full_reference_path, reference_smiles_file)

        if not reference_smiles:
            print("  Aucun SMILES valide dans le fichier de r√©f√©rence")
            return 100.0, len(generated_smiles)

        novel_count = 0
        valid_generated_count = 0

        for smiles in tqdm(generated_smiles, desc="Calcul nouveaut√©", leave=False):
            mol = Chem.MolFromSmiles(smiles)
            if mol:
                valid_generated_count += 1
                if Chem.MolToSmiles(mol) not in reference_smiles:
                    novel_count += 1

        print(f"‚úì Mol√©cules g√©n√©r√©es valides: {valid_generated_count}/{len(generated_smiles)}")

        if valid_generated_count == 0:
            return 0.0, 0

        novelty = (novel_count / valid_generated_count) * 100
        return novelty, novel_count

    except Exception as e:
        # Best-effort metric: report the problem instead of aborting the sweep.
        print(f" Erreur lors du calcul de la nouveaut√©: {e}")
        return 0.0, 0

def calculate_uniqueness(smiles_list):
    """Return (uniqueness_percent, unique_count) for `smiles_list`.

    `unique_count` is the number of distinct canonical SMILES among the entries
    RDKit can parse. The percentage divides it by the TOTAL number of entries
    (invalid ones included), not by the number of valid ones.
    An empty input gives (0.0, 0).
    """
    if not smiles_list:
        return 0.0, 0

    distinct_canonical = set()
    for smi in tqdm(smiles_list, desc="Calcul unicit√©", leave=False):
        parsed = Chem.MolFromSmiles(smi)
        if parsed:
            distinct_canonical.add(Chem.MolToSmiles(parsed))

    n_unique = len(distinct_canonical)
    return (n_unique / len(smiles_list)) * 100, n_unique

def calculate_validity(smiles_list):
    """Return (validity_percent, valid_count): the share of `smiles_list` RDKit can parse.

    An empty input gives (0.0, 0).
    """
    if not smiles_list:
        return 0.0, 0

    n_valid = sum(
        1
        for smi in tqdm(smiles_list, desc="Calcul validit√©", leave=False)
        if Chem.MolFromSmiles(smi) is not None
    )
    return (n_valid / len(smiles_list)) * 100, n_valid

def calculate_balanced_score(validity, novelty, uniqueness, intdiv):
    """Fold the four generation metrics into one trade-off score.

    ``validity``, ``novelty`` and ``uniqueness`` are percentages (0-100).
    ``intdiv`` is on a 0-1 scale and is multiplied by 100 before use.

    The score is the weighted mean of the four values (validity weight 1.2,
    the others 1.0). From it is subtracted 0.1 times their (population)
    standard deviation, which penalises lopsided results. The result is
    clipped at 0. If any metric is not strictly positive, the score is 0.0.
    """
    scaled_intdiv = intdiv * 100  # bring IntDiv onto the 0-100 scale
    metrics = [validity, novelty, uniqueness, scaled_intdiv]
    weights = [1.2, 1.0, 1.0, 1.0]  # validity slightly favoured

    if not all(value > 0 for value in metrics):
        return 0.0

    weighted_mean = (
        weights[0] * validity
        + weights[1] * novelty
        + weights[2] * uniqueness
        + weights[3] * scaled_intdiv
    ) / sum(weights)

    imbalance_penalty = np.std(metrics) * 0.1  # soft penalty on disparity
    return max(0, weighted_mean - imbalance_penalty)

# --- G√âN√âRATION AVEC DIFF√âRENTES TEMP√âRATURES ---
@torch.no_grad()
def generate_with_temperature(model, condition_tensor, stoi, itos, start_idx, end_idx,
                            temperature=0.5, num_molecules=1000, max_attempts=None):
    """Sample ``num_molecules`` non-empty token strings at a given temperature.

    Each attempt starts from ``start_idx`` and samples up to 80 tokens with
    top-k (k=30) filtering. The start token is never sampled, and the end
    token is blocked for the first 8 steps to enforce a minimum length.

    Args:
        model: conditional GPT, called as ``model(idx, conditions=...)`` and
            returning ``(logits, loss)``; the logits at position -1 are used.
        condition_tensor: condition tensor passed to the model (moved to DEVICE).
        stoi: unused here; kept for interface compatibility.
        itos: mapping from ``str(token_id)`` to token text.
        start_idx: id of the start token.
        end_idx: id of the end token.
        temperature: softmax temperature, clamped below at 0.1.
        num_molecules: number of non-empty strings to return.
        max_attempts: cap on sampling attempts (default ``100 * num_molecules``).
            It keeps a model that keeps decoding to empty strings from looping
            forever.

    Returns:
        list[str]: decoded strings. They are not validated and may not be
        valid SMILES. The list is shorter than ``num_molecules`` only if
        ``max_attempts`` is exhausted.

    NOTE(review): any special token other than start/end that is present in
    ``itos`` (e.g. a padding token, if the vocabulary has one) is decoded
    verbatim into the string.
    """
    top_k = 30
    max_new_tokens = 80
    min_steps_before_end = 8
    context_len = 128
    condition_local = condition_tensor.to(DEVICE)  # loop-invariant: move once
    if max_attempts is None:
        max_attempts = 100 * max(num_molecules, 1)

    generated_smiles = []
    attempts = 0

    with tqdm(total=num_molecules, desc=f"Temp {temperature}", leave=False) as pbar:
        while len(generated_smiles) < num_molecules and attempts < max_attempts:
            attempts += 1
            idx = torch.tensor([[start_idx]], dtype=torch.long, device=DEVICE)

            for step in range(max_new_tokens):
                idx_cond = idx if idx.size(1) <= context_len else idx[:, -context_len:]
                logits, _ = model(idx_cond, conditions=condition_local)
                logits = logits[:, -1, :] / max(temperature, 0.1)

                # Top-k filtering: everything below the k-th largest logit is removed.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)
                probs_mod = probs.clone()
                probs_mod[0, start_idx] = 0.0
                if step < min_steps_before_end:
                    probs_mod[0, end_idx] = 0.0

                # If masking removed all probability mass (e.g. softmax underflow
                # at low temperature), fall back to the unmasked distribution.
                if probs_mod.sum() > 0:
                    probs_mod = probs_mod / probs_mod.sum()
                else:
                    probs_mod = probs

                idx_next = torch.multinomial(probs_mod, num_samples=1)
                if idx_next.item() == end_idx and step >= min_steps_before_end:
                    break

                idx = torch.cat((idx, idx_next), dim=1)

            # Drop the start token and cut at the first end token, if any.
            tokens = idx[0].tolist()[1:]
            if end_idx in tokens:
                tokens = tokens[:tokens.index(end_idx)]

            smiles = ''.join(itos[str(i)] for i in tokens if str(i) in itos)

            if smiles:
                generated_smiles.append(smiles)
                pbar.update(1)

    if len(generated_smiles) < num_molecules:
        print(f"  ‚ö† Seulement {len(generated_smiles)}/{num_molecules} SMILES non vides "
              f"apr√®s {attempts} tentatives")

    return generated_smiles

def evaluate_temperature_performance(model, condition_tensor, stoi, itos, start_idx, end_idx,
                                   temperature, num_molecules=1000):
    """Generate a batch at ``temperature`` and score it.

    Returns a dict with the temperature, the four metrics (validity, novelty,
    uniqueness, IntDiv), the combined balanced score and the raw counts.
    Novelty is measured against ``DATA_FILE``.
    """
    batch = generate_with_temperature(
        model, condition_tensor, stoi, itos, start_idx, end_idx,
        temperature, num_molecules
    )

    print(f"    Calcul des m√©triques pour temp√©rature {temperature}...")

    # Metric order is kept (IntDiv last) so console output and RNG use are unchanged.
    validity, valid_count = calculate_validity(batch)
    novelty, novel_count = calculate_novelty(batch, DATA_FILE)
    uniqueness, unique_count = calculate_uniqueness(batch)
    intdiv = calculate_intdiv(batch)

    return dict(
        temperature=temperature,
        validity=validity,
        novelty=novelty,
        uniqueness=uniqueness,
        intdiv=intdiv,
        balanced_score=calculate_balanced_score(validity, novelty, uniqueness, intdiv),
        valid_count=valid_count,
        novel_count=novel_count,
        unique_count=unique_count,
        total_generated=len(batch),
    )

# --- ANALYSE DE SENSIBILIT√â ---
def temperature_sensitivity_analysis(model, condition_tensor, stoi, itos, start_idx, end_idx,
                                     temperatures=None, num_molecules=1000):
    """Evaluate generation quality over a sweep of sampling temperatures.

    Args:
        model, condition_tensor, stoi, itos, start_idx, end_idx: forwarded to
            ``evaluate_temperature_performance``.
        temperatures: temperatures to test (default 0.2, 0.4, 0.5, 0.6, 0.8, 1.0).
        num_molecules: molecules generated per temperature.

    Returns:
        list[dict]: one result dict per temperature, in sweep order.
    """
    print("üéØ ANALYSE DE SENSIBILIT√â - TEMP√âRATURE (CONDITION 4)")
    print("="*80)
    print("Objectif: Trouver un compromis optimal entre les 4 m√©triques")
    print("Condition: Structural + Lipinski (combinaison des conditions 2 et 3)")
    print("="*80)

    if temperatures is None:
        temperatures = [0.2, 0.4, 0.5, 0.6, 0.8, 1.0]
    temperatures = list(temperatures)
    results = []

    print(f"\nüß™ Test de {len(temperatures)} temp√©ratures sp√©cifiques")
    print(f"Temp√©ratures: {temperatures}")
    print(f"G√©n√©ration de {num_molecules} mol√©cules par temp√©rature...")
    print("M√©triques √©valu√©es: Validit√©, Novelty, Unicit√©, IntDiv")

    # NOTE(review): calculate_novelty receives DATA_FILE as-is, while this check
    # joins it with DRIVE_PATH. They agree only if DATA_FILE is absolute (join
    # then returns DATA_FILE) or the working directory is DRIVE_PATH.
    reference_path = os.path.join(DRIVE_PATH, DATA_FILE)
    if not os.path.exists(reference_path):
        print(f"‚ùå ATTENTION: Fichier de r√©f√©rence non trouv√©: {reference_path}")
        print("‚ö†Ô∏è  La nouveaut√© sera calcul√©e comme 100% (toutes les mol√©cules consid√©r√©es comme nouvelles)")
    else:
        print(f"‚úì Fichier de r√©f√©rence trouv√©: {DATA_FILE}")

    for temp in temperatures:
        print(f"\nüìä Temp√©rature: {temp}")
        result = evaluate_temperature_performance(
            model, condition_tensor, stoi, itos, start_idx, end_idx, temp, num_molecules
        )
        results.append(result)

        total = result['total_generated']
        valid = result['valid_count']
        novel = result['novel_count']
        # Novelty is a fraction of the VALID molecules. When no reference is
        # available, calculate_novelty reports every generated string as novel,
        # so in that case fall back to the total count.
        novel_base = valid if novel <= valid else total

        print(f"   ‚úì Validit√©: {result['validity']:.1f}% ({valid}/{total})")
        print(f"   ‚úì Novelty: {result['novelty']:.1f}% ({novel}/{novel_base})")
        print(f"   ‚úì Unicit√©: {result['uniqueness']:.1f}% ({result['unique_count']}/{total})")
        print(f"   ‚úì IntDiv: {result['intdiv']:.3f}")
        print(f"     Score de compromis: {result['balanced_score']:.1f}")

    return results

def display_comprehensive_results(results):
    """Print a fixed-width table with one row per evaluated temperature."""
    double_rule = "=" * 100
    row_template = (
        "{temperature:<6} {validity:<10.1f} {novelty:<10.1f} "
        "{uniqueness:<10.1f} {intdiv:<10.3f} {balanced_score:<12.1f} "
        "{valid_count:<8} {novel_count:<9} {unique_count:<8}"
    )

    print(f"\n{double_rule}")
    print("üìä TABLEAU COMPLET DES R√âSULTATS - CONDITION 4")
    print(double_rule)
    print(f"{'Temp':<6} {'Validit√©':<10} {'Novelty':<10} {'Unicit√©':<10} {'IntDiv':<10} {'Compromis':<12} {'Valides':<8} {'Nouvelles':<9} {'Uniques':<8}")
    print("-" * 100)

    for entry in results:
        # Extra keys in the dict (e.g. 'total_generated') are ignored by format().
        print(row_template.format(**entry))

def find_optimal_temperature(results):
    """Return the temperature of the entry with the highest 'balanced_score'.

    Ties resolve to the earliest entry. A summary of the winning entry is
    printed. Returns 0.5 when ``results`` is empty.
    """
    if not results:
        return 0.5

    entries = iter(results)
    winner = next(entries)
    for candidate in entries:
        if candidate['balanced_score'] > winner['balanced_score']:
            winner = candidate

    print(f"\nüéØ TEMP√âRATURE OPTIMALE TROUV√âE:")
    print(f"   Temp√©rature: {winner['temperature']}")
    print(f"   Score de compromis: {winner['balanced_score']:.1f}")
    print(f"   Validit√©: {winner['validity']:.1f}%")
    print(f"   Novelty: {winner['novelty']:.1f}%")
    print(f"   Unicit√©: {winner['uniqueness']:.1f}%")
    print(f"   IntDiv: {winner['intdiv']:.3f}")

    return winner['temperature']

def plot_temperature_analysis(results):
    """Plot each metric against temperature, then save and show the figure.

    Panels: validity, novelty, uniqueness, IntDiv (scaled x100) and the
    balanced score. The figure is saved as a PNG in ``DRIVE_PATH`` and closed
    after display so that repeated calls do not accumulate open figures.
    """
    if not results:
        print(" Aucune donn√©e √† visualiser")
        return

    temperatures = [r['temperature'] for r in results]
    # (values, y label, title, colour) for each panel.
    panels = [
        ([r['validity'] for r in results], 'Validit√© (%)', 'Validit√© vs Temp√©rature', 'blue'),
        ([r['novelty'] for r in results], 'Novelty (%)', 'Novelty vs Temp√©rature', 'red'),
        ([r['uniqueness'] for r in results], 'Unicit√© (%)', 'Unicit√© vs Temp√©rature', 'green'),
        ([r['intdiv'] * 100 for r in results], 'IntDiv x 100', 'IntDiv vs Temp√©rature', 'purple'),
        ([r['balanced_score'] for r in results], 'Score de Compromis',
         'Score de Compromis vs Temp√©rature', 'orange'),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    fig.suptitle('Analyse de Sensibilit√© - Temp√©rature (Condition 4)', fontsize=16, fontweight='bold')

    flat_axes = axes.ravel()
    for ax, (values, ylabel, title, color) in zip(flat_axes, panels):
        ax.plot(temperatures, values, marker='o', linewidth=2, color=color)
        ax.set_xlabel('Temp√©rature')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
    for ax in flat_axes[len(panels):]:
        ax.axis('off')  # hide the unused slot of the 2x3 grid

    plt.tight_layout()

    output_path = os.path.join(DRIVE_PATH, "temperature_sensitivity_condition4.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f" Graphique sauvegard√©: {output_path}")

    plt.show()
    plt.close(fig)

# --- MAIN ---
def main():
    """Run the temperature sweep for Condition 4 and save a text report.

    Steps: load the vocabulary and checkpoint, evaluate each temperature,
    print the summary table, pick the best balanced score, plot, and write
    ``temperature_optimization_condition4_compromise.txt`` to ``DRIVE_PATH``.
    """
    print(" ANALYSE DE SENSIBILIT√â - TEMP√âRATURE OPTIMALE (CONDITION 4)")
    print("="*80)
    print("Objectif: Trouver la temp√©rature qui offre le meilleur COMPROMIS entre:")
    print("  ‚Ä¢ Validit√©  ‚Ä¢ Novelty  ‚Ä¢ Unicit√©  ‚Ä¢ IntDiv")
    print("Condition: Structural + Lipinski")
    print("  - LogP‚â§3, MW‚â§480, HBA‚â§3, HBD‚â§3, RotB‚â§3 (Lipinski Ro3)")
    print("  - 2 aromatic rings, 1 non-aromatic, functional groups, R-value [0.05-0.50]")
    print("="*80)

    # Vocabulary
    print("\nüìö Chargement du vocabulaire...")
    with open(VOCAB_FILE, 'r', encoding='utf-8') as f:
        vocab_data = json.load(f)
    stoi = vocab_data['stoi']
    itos = vocab_data['itos']

    start_token = stoi['<start>']
    end_token = stoi['<end>']
    print(f"‚úì Vocabulaire charg√© (Tokens: Start={start_token}, End={end_token})")

    # Model
    print("\nü§ñ Chargement du mod√®le...")
    config = GPTConfig(vocab_size=len(stoi))
    model = ConditionalDrugGPT(config)
    model.to(DEVICE)

    # Try the restricted (tensor-only) loader first. Fall back to full
    # unpickling only if that load itself fails. State-dict mismatches are
    # no longer masked by a second load attempt.
    # NOTE(review): weights_only=False can execute arbitrary pickled code;
    # only use it with trusted checkpoint files.
    try:
        checkpoint = torch.load(CHECKPOINT_FILE, map_location=DEVICE, weights_only=True)
    except Exception:
        checkpoint = torch.load(CHECKPOINT_FILE, map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    print("‚úì Mod√®le charg√© et pr√™t")

    # Condition 4 (Structural + Lipinski).
    # NOTE(review): the CONDITIONS table in the later analysis cell is keyed
    # 0-3; confirm that the table in scope for this cell has key 4.
    condition_vector = CONDITIONS[4]["get_vector"](None)
    condition_tensor = torch.tensor([condition_vector], dtype=torch.float32)

    print(f"\nüéØ Condition de g√©n√©ration: {CONDITIONS[4]['name']}")
    print(f"üìù Description: {CONDITIONS[4]['description']}")

    # Sensitivity analysis
    print(f"\n{'='*80}")
    print("üéØ PHASE 1 : ANALYSE DE SENSIBILIT√â")
    print(f"{'='*80}")

    results = temperature_sensitivity_analysis(
        model, condition_tensor, stoi, itos, start_token, end_token
    )

    display_comprehensive_results(results)
    optimal_temp = find_optimal_temperature(results)
    plot_temperature_analysis(results)

    best_score = max(r['balanced_score'] for r in results) if results else 0.0

    # Text report
    output_file = os.path.join(DRIVE_PATH, "temperature_optimization_condition4_compromise.txt")
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("ANALYSE DE SENSIBILIT√â - TEMP√âRATURE OPTIMALE (CONDITION 4 - COMPROMIS)\n")
        f.write("="*100 + "\n\n")
        f.write(f"Condition: {CONDITIONS[4]['name']}\n")
        f.write(f"Description: {CONDITIONS[4]['description']}\n\n")
        f.write(f"Temp√©rature optimale trouv√©e: {optimal_temp}\n")
        f.write(f"Score de compromis: {best_score:.1f}\n\n")

        f.write("R√âSULTATS D√âTAILL√âS:\n")
        f.write("-"*100 + "\n")
        f.write(f"{'Temp':<6} {'Validit√©':<10} {'Novelty':<10} {'Unicit√©':<10} {'IntDiv':<10} {'Compromis':<12} {'Valides':<8} {'Nouvelles':<9} {'Uniques':<8}\n")
        f.write("-"*100 + "\n")

        for result in results:
            f.write(f"{result['temperature']:<6} {result['validity']:<10.1f} {result['novelty']:<10.1f} "
                   f"{result['uniqueness']:<10.1f} {result['intdiv']:<10.3f} {result['balanced_score']:<12.1f} "
                   f"{result['valid_count']:<8} {result['novel_count']:<9} {result['unique_count']:<8}\n")

    print(f"\nüíæ R√©sultats sauvegard√©s: {output_file}")

    print("\n‚úÖ ANALYSE TERMIN√âE !")
    print("üéØ La temp√©rature optimale a √©t√© trouv√©e pour la Condition 4")

# Entry point: run the full temperature sweep defined in main() (load vocab and
# model, evaluate each temperature, plot, and write the report).
if __name__ == "__main__":
    main()

Condition satisfaction

You can choose the generation condition here, e.g. `condition_vector = CONDITIONS[2]["get_vector"](None)` (keys 0–3 correspond to Conditions 1–4).
<br>

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
SCRIPT : ANALYSE HI√âRARCHIQUE COMPL√àTE
Valide ‚Üí Novel ‚Üí Unique ‚Üí Conditions sur Uniques
"""

import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from dataclasses import dataclass
import os
import json
from tqdm import tqdm
import numpy as np
from collections import Counter
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem, DataStructs, Crippen, Lipinski
from rdkit import rdBase
rdBase.DisableLog('rdApp.error')

# --- CONFIGURATION ---
# All artefacts live in one folder on the mounted Google Drive.
DRIVE_PATH = '/content/drive/MyDrive/cond_gpt_model1'
VOCAB_FILE = os.path.join(DRIVE_PATH, 'vocab_dataset.json')
CHECKPOINT_FILE = os.path.join(DRIVE_PATH, 'checkpoints', 'cond_gpt_categorical_extended.pth')
# Reference SMILES (one per line) used for the novelty check.
DATA_FILE = os.path.join(DRIVE_PATH, 's_100_str_+1M_fixed.txt')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- FONCTIONS IDENTIQUES ---
# SMARTS for the functional groups required by the structural objective:
# hydroxyl, C-C(=O)-O with no H on that O (e.g. esters), carboxylic acid OH,
# and NH2 groups.
_FUNCTIONAL_GROUP_SMARTS = (
    '[OH]',
    '[#6]C(=O)[O;H0]',
    'C(=O)[OH]',
    '[NH2]',
)
# Compiled lazily on first use so the SMARTS are parsed once, not once per molecule.
_FUNCTIONAL_GROUP_QUERIES = []


def has_functional_group(mol):
    """Return 1.0 if ``mol`` matches any pattern in ``_FUNCTIONAL_GROUP_SMARTS``, else 0.0.

    A float is returned (not a bool) because callers both store it as a
    property value and test it in boolean context.
    """
    if not _FUNCTIONAL_GROUP_QUERIES:
        # Build the full list before extending so a failure cannot leave a partial cache.
        _FUNCTIONAL_GROUP_QUERIES.extend([Chem.MolFromSmarts(s) for s in _FUNCTIONAL_GROUP_SMARTS])
    return 1.0 if any(mol.HasSubstructMatch(q) for q in _FUNCTIONAL_GROUP_QUERIES) else 0.0

def calculate_r_value(mol):
    """Return the ratio ``logP / (MW / 100)`` for ``mol``.

    Returns 0.0 when the molecular weight is not strictly positive or when
    the RDKit descriptor calls raise. Callers can therefore range-check the
    value without their own try/except.
    """
    try:
        mol_wt = Descriptors.MolWt(mol)
        logp = Crippen.MolLogP(mol)
    except Exception:  # was a bare except, which also swallowed KeyboardInterrupt
        return 0.0
    if mol_wt > 0:
        return logp / (mol_wt / 100)
    return 0.0

# --- D√âFINITION DES CONDITIONS ---
# Generation conditions, keyed 0-3 (display names are "Condition 1".."Condition 4").
# Each "get_vector" ignores its ``mol`` argument and returns a constant
# 10-element condition vector that is fed to the model's condition projector.
# The slot order presumably follows the CONDITION_DIM comment of the training
# script: LogP_cat, MW_cat, HBD_cat, HBA_cat, RotBonds_cat, AromaticRings,
# NonAromaticRings, HasFunctionalGroup, R_Value, LipinskiCompliant.
# TODO confirm against the encoder used during training.
# Key 3 is the element-wise combination of keys 1 and 2.
CONDITIONS = {
    0: {
        "name": "Condition 1: LogP ‚â§ 3",
        "description": "Single objective: logP ‚â§ 3",
        "get_vector": lambda mol: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    },
    1: {
        "name": "Condition 2: Structural",
        "description": "2 aromatic rings, 1 non-aromatic, functional groups, R-value [0.05-0.50]",
        "get_vector": lambda mol: [0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 1.0, 1.0, 1.0, 0.0]
    },
    2: {
        "name": "Condition 3: Lipinski Ro3",
        "description": "LogP‚â§3, MW‚â§480, HBA‚â§3, HBD‚â§3, RotB‚â§3",
        "get_vector": lambda mol: [1.0, 1.0, 3.0, 3.0, 3.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    },
    3: {
        "name": "Condition 4: Structural + Lipinski",
        "description": "Combination of conditions 2 and 3",
        "get_vector": lambda mol: [1.0, 1.0, 3.0, 3.0, 3.0, 2.0, 1.0, 1.0, 1.0, 1.0]
    }
}

def _meets_structural_objectives(mol):
    """Return True if ``mol`` has exactly 2 aromatic and 1 aliphatic ring, a
    functional group, and an R-value in [0.05, 0.50] (Condition 2)."""
    if Lipinski.NumAromaticRings(mol) != 2 or Lipinski.NumAliphaticRings(mol) != 1:
        return False
    if not has_functional_group(mol):
        return False
    return 0.05 <= calculate_r_value(mol) <= 0.50


def _meets_lipinski_ro3(mol):
    """Return True if ``mol`` passes this project's Ro3-style filter (Condition 3):
    logP <= 3, MW <= 480, HBD <= 3, HBA <= 3, rotatable bonds <= 3."""
    return (
        Crippen.MolLogP(mol) <= 3.0
        and Descriptors.MolWt(mol) <= 480
        and Lipinski.NumHDonors(mol) <= 3
        and Lipinski.NumHAcceptors(mol) <= 3
        and Lipinski.NumRotatableBonds(mol) <= 3
    )


def evaluate_condition(mol, condition_idx):
    """Return True if ``mol`` satisfies the condition at ``condition_idx`` (0-3).

    0: logP <= 3; 1: structural objectives; 2: Ro3-style filter; 3: both 1 and 2.
    Unknown indices and any RDKit error yield False. The result is always a
    bool; previously a 0.0 float could leak from the functional-group check.
    """
    try:
        if condition_idx == 0:
            return Crippen.MolLogP(mol) <= 3.0
        if condition_idx == 1:
            return _meets_structural_objectives(mol)
        if condition_idx == 2:
            return _meets_lipinski_ro3(mol)
        if condition_idx == 3:
            return _meets_structural_objectives(mol) and _meets_lipinski_ro3(mol)
        return False
    except Exception:  # was a bare except, which also swallowed KeyboardInterrupt
        return False

# --- ARCHITECTURE MOD√àLE ---
@dataclass
class GPTConfig:
    """Architecture hyper-parameters for ConditionalDrugGPT.

    NOTE(review): these defaults presumably mirror the training script's
    BLOCK_SIZE/N_LAYER/N_HEAD/N_EMBD/DROPOUT/CONDITION_DIM. They must match the
    saved checkpoint, otherwise ``load_state_dict`` fails on shape mismatches.
    """
    block_size: int = 128     # max sequence length: size of the position table and causal mask
    vocab_size: int = 57      # token vocabulary size; presumably overridden from the vocab file
    n_layer: int = 4          # number of transformer blocks
    n_head: int = 4           # attention heads (n_embd must be divisible by n_head)
    n_embd: int = 128         # embedding / hidden width
    dropout: float = 0.1      # dropout probability used throughout
    condition_dim: int = 10   # length of the condition vector fed to the projector

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position only attends to itself and earlier positions.

    The lower-triangular ``bias`` buffer (block_size x block_size) provides the
    causal mask. Layer creation order is kept so that parameter initialisation
    and state_dict keys match existing checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)  # fused Q, K, V projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        size = config.block_size
        causal_mask = torch.tril(torch.ones(size, size))
        self.register_buffer("bias", causal_mask.view(1, 1, size, size))

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)
        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        scores = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1)))
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))

class MLP(nn.Module):
    """Position-wise feed-forward layer: Linear(4x expand) -> GELU -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        # Creation order (c_fc, then c_proj) is kept for init order and state_dict keys.
        self.c_fc = nn.Linear(config.n_embd, hidden_width)
        self.c_proj = nn.Linear(hidden_width, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))

class Block(nn.Module):
    """Pre-norm transformer layer: residual causal attention followed by a residual MLP."""

    def __init__(self, config):
        super().__init__()
        # Creation order is kept: it fixes parameter init order and state_dict keys.
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))

class ConditionalDrugGPT(nn.Module):
    """GPT-style decoder conditioned on a property vector.

    The condition vector (length ``config.condition_dim``) is projected to
    ``n_embd`` with Linear+ReLU and added to every position's token + position
    embedding. The output head shares its weight with the token embedding.
    Module creation order is unchanged, so init and state_dict keys match
    existing checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.condition_projector = nn.Sequential(
            nn.Linear(config.condition_dim, config.n_embd),
            nn.ReLU()
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight  # weight tying
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, conditions=None):
        """Compute next-token logits (and optionally the loss).

        Args:
            idx: LongTensor of token ids, shape (B, T) with T <= block_size.
            targets: optional LongTensor (B, T) of next-token ids.
            conditions: float tensor (B, condition_dim); required. A batch of 1
                broadcasts over B.

        Returns:
            (logits, loss). With targets, logits are (B, T, vocab) and loss is
            the cross-entropy. Without targets, only the last position is
            projected, giving (B, 1, vocab), and loss is None.

        Raises:
            ValueError: if ``conditions`` is None or T exceeds block_size.
            These used to be ``assert``s, which are stripped under ``python -O``.
        """
        if conditions is None:
            raise ValueError("Les conditions doivent √™tre fournies !")
        B, T = idx.shape
        if T > self.config.block_size:
            raise ValueError(f"S√©quence trop longue: {T}")

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        cond_emb = self.condition_projector(conditions)
        # (B, n_embd) -> (B, 1, n_embd) so the condition is added at every position.
        x = self.transformer.drop(tok_emb + pos_emb + cond_emb.unsqueeze(1))
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            # Inference: only the last position is needed for sampling.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss

# --- ANALYSE HI√âRARCHIQUE COMPL√àTE ---
def _safe_percentage(part, whole):
    """Return ``100 * part / whole``, or 0.0 when ``whole`` is zero."""
    return (part / whole) * 100 if whole else 0.0


def _load_reference_canonical_smiles(reference_smiles_file):
    """Return the set of canonical SMILES in ``reference_smiles_file`` (one per line).

    Blank and unparsable lines are skipped. A missing file yields an empty
    set, with a warning, so every valid generated molecule counts as novel.
    """
    reference_smiles = set()
    if not os.path.exists(reference_smiles_file):
        print(f"‚ö† Fichier de r√©f√©rence introuvable: {reference_smiles_file} "
              "(novelty calcul√©e sans r√©f√©rence)")
        return reference_smiles

    print("üìñ Chargement du dataset de r√©f√©rence...")
    with open(reference_smiles_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Chargement r√©f√©rence"):
            line = line.strip()
            if line:
                mol = Chem.MolFromSmiles(line)
                if mol:
                    reference_smiles.add(Chem.MolToSmiles(mol))
    print(f"‚úì R√©f√©rence charg√©e: {len(reference_smiles)} mol√©cules")
    return reference_smiles


def comprehensive_hierarchical_analysis(generated_smiles, reference_smiles_file):
    """Filter generated SMILES as Valid -> Novel -> Unique, then score conditions.

    Args:
        generated_smiles: list of generated SMILES strings.
        reference_smiles_file: path to the training/reference SMILES (one per line).

    Returns:
        dict with counts and percentages for each level, per-condition results
        (counts, percentages, up to 3 examples and their properties), and
        ``unique_novel_molecules`` (canonical SMILES -> RDKit Mol). Each
        percentage is 0.0 when its denominator is empty, instead of raising
        ZeroDivisionError.
    """
    print("üîç ANALYSE HI√âRARCHIQUE COMPL√àTE")
    print("=" * 80)

    reference_smiles = _load_reference_canonical_smiles(reference_smiles_file)

    # Level 1: validity. Canonicalize once here and reuse it in later levels.
    print(f"\nüéØ NIVEAU 1: VALIDATION")
    print("-" * 40)
    valid_records = []  # (mol, canonical SMILES)
    for smiles in tqdm(generated_smiles, desc="Validation SMILES"):
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            valid_records.append((mol, Chem.MolToSmiles(mol)))

    total_generated = len(generated_smiles)
    total_valid = len(valid_records)
    validity_percentage = _safe_percentage(total_valid, total_generated)

    print(f"‚Ä¢ G√©n√©r√©es: {total_generated}")
    print(f"‚Ä¢ Valides: {total_valid} ({validity_percentage:.1f}%)")

    # Level 2: novelty with respect to the reference set.
    print(f"\nüéØ NIVEAU 2: NOVELTY")
    print("-" * 40)
    novel_records = [
        (mol, canon)
        for mol, canon in tqdm(valid_records, desc="V√©rification novelty", total=total_valid)
        if canon not in reference_smiles
    ]

    total_novel = len(novel_records)
    novelty_percentage = _safe_percentage(total_novel, total_valid)

    print(f"‚Ä¢ Valides: {total_valid}")
    print(f"‚Ä¢ Novel: {total_novel} ({novelty_percentage:.1f}%)")

    # Level 3: uniqueness among the novel molecules (the first occurrence is kept).
    print(f"\nüéØ NIVEAU 3: UNICIT√â (parmi les Novel)")
    print("-" * 40)
    unique_novel_molecules = {}
    for mol, canon in tqdm(novel_records, desc="D√©duplication novel", total=total_novel):
        unique_novel_molecules.setdefault(canon, mol)

    total_unique_novel = len(unique_novel_molecules)
    uniqueness_percentage = _safe_percentage(total_unique_novel, total_novel)

    print(f"‚Ä¢ Novel: {total_novel}")
    print(f"‚Ä¢ Uniques parmi novel: {total_unique_novel} ({uniqueness_percentage:.1f}%)")

    # Level 4: condition satisfaction on the unique novel molecules.
    print(f"\nüéØ NIVEAU 4: SATISFACTION DES CONDITIONS (sur Uniques Novel)")
    print("-" * 40)

    condition_results = {
        'total_unique_novel': total_unique_novel,
        'condition_counts': {0: 0, 1: 0, 2: 0, 3: 0},
        'condition_percentages': {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0},
        'examples_per_condition': {0: [], 1: [], 2: [], 3: []},
        'properties_per_condition': {0: [], 1: [], 2: [], 3: []}
    }

    for condition_idx in range(4):
        print(f"\nüîç {CONDITIONS[condition_idx]['name']}")
        condition_count = 0

        for canon_smiles, mol in tqdm(unique_novel_molecules.items(),
                                    desc=f"Condition {condition_idx+1}",
                                    leave=False):
            if evaluate_condition(mol, condition_idx):
                condition_count += 1

                # Keep up to 3 examples with their properties.
                if len(condition_results['examples_per_condition'][condition_idx]) < 3:
                    condition_results['examples_per_condition'][condition_idx].append(canon_smiles)

                    props = {
                        'LogP': Crippen.MolLogP(mol),
                        'MW': Descriptors.MolWt(mol),
                        'HBD': Lipinski.NumHDonors(mol),
                        'HBA': Lipinski.NumHAcceptors(mol),
                        'RotB': Lipinski.NumRotatableBonds(mol),
                    }
                    if condition_idx in [1, 3]:  # structural conditions
                        props.update({
                            'Aromatic': Lipinski.NumAromaticRings(mol),
                            'Aliphatic': Lipinski.NumAliphaticRings(mol),
                            'R-value': calculate_r_value(mol),
                            'Functional': has_functional_group(mol)
                        })
                    condition_results['properties_per_condition'][condition_idx].append(props)

        condition_results['condition_counts'][condition_idx] = condition_count
        condition_results['condition_percentages'][condition_idx] = _safe_percentage(
            condition_count, total_unique_novel
        )

        print(f"   ‚úì Satisfait: {condition_count}/{total_unique_novel} ({condition_results['condition_percentages'][condition_idx]:.1f}%)")

    return {
        'total_generated': total_generated,
        'total_valid': total_valid,
        'validity_percentage': validity_percentage,
        'total_novel': total_novel,
        'novelty_percentage': novelty_percentage,
        'total_unique_novel': total_unique_novel,
        'uniqueness_percentage': uniqueness_percentage,
        'condition_results': condition_results,
        'unique_novel_molecules': unique_novel_molecules
    }

def display_hierarchical_results(results):
    """Print the Valid -> Novel -> Unique funnel and per-condition satisfaction.

    ``results`` is the dict returned by ``comprehensive_hierarchical_analysis``.
    The final score (unique novel valid molecules as a percentage of all
    generated ones) is printed only when at least one such molecule exists.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 50

    print("\n" + heavy_rule)
    print("üìä R√âSULTATS HI√âRARCHIQUES COMPLETS")
    print(heavy_rule)

    print("\nüéØ HI√âRARCHIE DE QUALIT√â:")
    print(light_rule)
    print(f"1. G√©n√©r√©es:      {results['total_generated']:>6} mol√©cules (100.0%)")
    print(f"2. Valides:       {results['total_valid']:>6} mol√©cules ({results['validity_percentage']:5.1f}% des g√©n√©r√©es)")
    print(f"3. Novel:         {results['total_novel']:>6} mol√©cules ({results['novelty_percentage']:5.1f}% des valides)")
    print(f"4. Uniques Novel: {results['total_unique_novel']:>6} mol√©cules ({results['uniqueness_percentage']:5.1f}% des novel)")

    print(f"\nüéØ SATISFACTION DES CONDITIONS (sur {results['total_unique_novel']} uniques novel):")
    print(light_rule)
    per_condition = results['condition_results']
    for idx in range(4):
        n_ok = per_condition['condition_counts'][idx]
        share = per_condition['condition_percentages'][idx]
        print(f"‚Ä¢ {CONDITIONS[idx]['name']}:")
        print(f"     {n_ok:>3} mol√©cules ({share:5.1f}% des uniques novel)")

    n_unique_novel = results['total_unique_novel']
    if not n_unique_novel > 0:
        return
    final_score = (n_unique_novel / results['total_generated']) * 100
    print(f"\n‚öñÔ∏è  SCORE FINAL: {final_score:.2f}% des mol√©cules g√©n√©r√©es sont UNIQUES, NOUVELLES et VALIDES")

def calculate_intdiv(smiles_list, sample_size=1000):
    """Return the internal diversity (IntDiv) of a list of SMILES.

    IntDiv = 1 - mean pairwise Tanimoto similarity of Morgan fingerprints
    (radius 2, 1024 bits). When more than ``sample_size`` SMILES are given,
    a random subset of ``sample_size`` is drawn without replacement, using
    the global NumPy RNG. Unparsable SMILES are skipped.

    Args:
        smiles_list: sequence of SMILES strings.
        sample_size: maximum number of molecules compared. The number of
            pairs grows quadratically with it.

    Returns:
        float: IntDiv, or 0.0 when fewer than two molecules can be fingerprinted.
    """
    print("üîç Calcul de la diversit√© interne (IntDiv)...")

    if len(smiles_list) < 2:
        return 0.0

    if len(smiles_list) > sample_size:
        indices = np.random.choice(len(smiles_list), sample_size, replace=False)
        sample_smiles = [smiles_list[i] for i in indices]
    else:
        sample_smiles = smiles_list

    fingerprints = []
    for smiles in tqdm(sample_smiles, desc="Fingerprints"):
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            fingerprints.append(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024))

    if len(fingerprints) < 2:
        return 0.0

    # BulkTanimotoSimilarity compares one fingerprint against a list in C++,
    # moving the quadratic pair loop out of Python. The values are the same
    # as pairwise TanimotoSimilarity.
    similarities = []
    for i in tqdm(range(len(fingerprints) - 1), desc="Similarit√©s"):
        similarities.extend(DataStructs.BulkTanimotoSimilarity(fingerprints[i], fingerprints[i + 1:]))

    intdiv = 1 - np.mean(similarities)
    # Report the number of molecules actually compared (unparsable ones excluded).
    print(f"‚úì IntDiv: {intdiv:.4f} (sur {len(fingerprints)} mol√©cules)")
    return intdiv

# --- GENERATION ---
@torch.no_grad()
def generate_molecules(model, condition_tensor, stoi, itos, start_idx, end_idx,
                       num_molecules=10000, temperature=0.6, *,
                       top_k=30, max_new_tokens=80, min_new_tokens=8,
                       block_size=128):
    """Sample SMILES strings one at a time from a conditional language model.

    For each molecule, sampling starts from ``start_idx``. At each step:
      1. The logits are scaled by ``temperature`` and filtered to the top-k tokens.
      2. The start token is masked out.
      3. During the first ``min_new_tokens`` steps, the end token is also masked out.
    If masking removes all probability mass, the unmasked distribution is used.
    Sampling stops when ``end_idx`` is drawn (after the minimum length) or after
    ``max_new_tokens`` steps.

    Tokens are decoded through ``itos`` up to the first end token; ids missing
    from ``itos`` are dropped. Empty strings are discarded, and sampling
    continues until ``num_molecules`` non-empty strings are collected. No
    chemical validity check is performed here.

    Args:
        model: Called as ``model(idx, conditions=...)`` and expected to return
            ``(logits, _)``. Logits are indexed as ``[:, -1, :]``, so presumably
            shaped (batch, seq, vocab).
        condition_tensor: Condition tensor passed to the model. It is moved to
            DEVICE once.
        stoi: Unused; kept for interface compatibility.
        itos: Mapping from ``str(token_id)`` to token text (JSON-style string keys).
        start_idx: Id of the start token.
        end_idx: Id of the end token.
        num_molecules: Number of non-empty SMILES strings to return.
        temperature: Softmax temperature; must be strictly positive.
        top_k: Number of most likely tokens kept at each step; must be >= 1.
        max_new_tokens: Maximum number of sampling steps per molecule.
        min_new_tokens: Number of initial steps during which the end token is masked.
        block_size: Maximum context length fed to the model.

    Returns:
        A list of ``num_molecules`` SMILES strings (not necessarily valid).

    Raises:
        ValueError: If ``temperature`` <= 0 or ``top_k`` < 1.
    """
    # Fail fast with a clear message instead of producing inf/NaN probabilities
    # (temperature <= 0) or an IndexError from an empty top-k (top_k < 1).
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    print(f"üéØ G√©n√©ration de {num_molecules} mol√©cules (temp√©rature: {temperature})")

    generated_smiles = []
    # Loop-invariant: move the condition to the device once, not per molecule.
    condition_local = condition_tensor.to(DEVICE)

    with tqdm(total=num_molecules, desc="G√©n√©ration") as pbar:
        while len(generated_smiles) < num_molecules:
            idx = torch.tensor([[start_idx]], dtype=torch.long, device=DEVICE)

            for step in range(max_new_tokens):
                # Crop the context to the model's maximum length.
                idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
                logits, _ = model(idx_cond, conditions=condition_local)
                logits = logits[:, -1, :] / temperature

                # Top-k filtering: everything below the k-th largest logit is excluded.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

                probs = F.softmax(logits, dim=-1)
                probs_mod = probs.clone()
                probs_mod[0, start_idx] = 0.0
                if step < min_new_tokens:
                    # Forbid ending before the minimum length.
                    probs_mod[0, end_idx] = 0.0

                # Renormalise after masking; fall back to the unmasked
                # distribution if masking removed all probability mass.
                mass = probs_mod.sum()
                probs_mod = probs_mod / mass if mass > 0 else probs

                idx_next = torch.multinomial(probs_mod, num_samples=1)
                if idx_next.item() == end_idx and step >= min_new_tokens:
                    break

                idx = torch.cat((idx, idx_next), dim=1)

            # Drop the start token and cut at the first end token, if any
            # (one can only be present via the unmasked fallback above).
            tokens_to_decode = idx[0, 1:].tolist()
            if end_idx in tokens_to_decode:
                tokens_to_decode = tokens_to_decode[:tokens_to_decode.index(end_idx)]

            smiles = ''.join(itos[key] for key in map(str, tokens_to_decode) if key in itos)

            if smiles:
                generated_smiles.append(smiles)
                pbar.update(1)

    print(f"‚úÖ {len(generated_smiles)} mol√©cules g√©n√©r√©es")
    return generated_smiles

# --- MAIN ---
def main():
    """Generate molecules under one condition and print a hierarchical analysis.

    Steps:
      1. Load the ``stoi``/``itos`` vocabulary from VOCAB_FILE (JSON).
      2. Build ConditionalDrugGPT and load ``model_state_dict`` from CHECKPOINT_FILE.
      3. Sample molecules conditioned on ``CONDITIONS[condition_index]``.
      4. Run comprehensive_hierarchical_analysis against DATA_FILE.
      5. Compute IntDiv on the unique novel molecules and print all results.
    """
    import pickle  # local: only needed to recognise weights-only load failures

    # Generation settings are declared once, so the printed summary always
    # matches what is actually passed to generate_molecules.
    condition_index = 2
    temperature = 0.6
    num_molecules = 10000

    print("üöÄ ANALYSE HI√âRARCHIQUE COMPL√àTE")
    print("=" * 80)
    print("Hi√©rarchie: G√©n√©r√©es ‚Üí Valides ‚Üí Novel ‚Üí Uniques Novel ‚Üí Conditions")
    print("=" * 80)

    # Vocabulary
    print("\nüìö Chargement du vocabulaire...")
    with open(VOCAB_FILE, 'r', encoding='utf-8') as f:
        vocab_data = json.load(f)
    stoi = vocab_data['stoi']
    itos = vocab_data['itos']
    start_token = stoi['<start>']
    end_token = stoi['<end>']

    # Model
    print("ü§ñ Chargement du mod√®le...")
    model = ConditionalDrugGPT(GPTConfig(vocab_size=len(stoi)))
    model.to(DEVICE)

    try:
        checkpoint = torch.load(CHECKPOINT_FILE, map_location=DEVICE, weights_only=True)
    except pickle.UnpicklingError:
        # The safe weights-only loader rejects checkpoints that contain arbitrary
        # Python objects. SECURITY: weights_only=False unpickles arbitrary
        # objects and can execute code, so only use it with trusted checkpoints.
        checkpoint = torch.load(CHECKPOINT_FILE, map_location=DEVICE, weights_only=False)
    # Called once, outside the try: a state-dict mismatch is a real error and
    # must not trigger a silent reload.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Generation condition
    condition = CONDITIONS[condition_index]
    condition_tensor = torch.tensor([condition["get_vector"](None)], dtype=torch.float32)

    print("\nüéØ G√©n√©ration avec:")
    print(f"‚Ä¢ Temp√©rature: {temperature}")
    print(f"‚Ä¢ Condition: {condition['name']}")
    print(f"‚Ä¢ Mol√©cules: {num_molecules:,}")

    # Generation
    generated_smiles = generate_molecules(
        model, condition_tensor, stoi, itos, start_token, end_token,
        num_molecules=num_molecules, temperature=temperature
    )

    # Hierarchical analysis
    results = comprehensive_hierarchical_analysis(generated_smiles, DATA_FILE)

    # IntDiv is computed on the unique novel molecules only.
    unique_novel_smiles = list(results['unique_novel_molecules'])
    intdiv = calculate_intdiv(unique_novel_smiles)
    results['intdiv'] = intdiv

    # Display results
    display_hierarchical_results(results)

    print(f"\nüìà Diversit√© (IntDiv) des uniques novel: {intdiv:.4f}")
    print("\n‚úÖ ANALYSE TERMIN√âE !")

if __name__ == "__main__":
    main()