# download dataset

In [1]:
import kagglehub

# Fetch the latest version of the Kaggle dataset; kagglehub returns the local path.
# NOTE(review): DATA_FILE in the configuration cell is a bare relative file name,
# so this `path` is not used to locate the data — confirm the file is reachable.
path = kagglehub.dataset_download("ouical/data-gen")
print(f"Path to dataset files: {path}")

BackendError: POST failed with: {"errors":["New Datasets cannot be attached in non-interactive sessions. Found no versions attached for Dataset [ouical/data-gen]."],"error":{"code":9},"wasSuccessful":false}

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
SCRIPT COMPLET AMÉLIORÉ : TRANSFORMATEUR CONDITIONNEL (Catégoriel) -
"""

# --- PARTIE 1 : IMPORTS ET CONFIGURATION ---
import matplotlib.pyplot as plt
import re
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import math
from dataclasses import dataclass
import time
import os
import json
from tqdm import tqdm
import gc
from collections import Counter
from torch.utils.data import WeightedRandomSampler

try:
    from rdkit import Chem
    from rdkit.Chem import Descriptors
    from rdkit import rdBase
    # Silence RDKit parse errors: many generated SMILES are expected to be invalid.
    rdBase.DisableLog('rdApp.error')
except ImportError:
    print("Erreur : RDKit n'est pas installé.")
    print("Veuillez l'installer avec : pip install rdkit")
    # exit() does not reliably stop execution inside a Jupyter kernel; re-raising
    # halts this cell with a clear traceback instead of failing later on `Chem`.
    raise


In [None]:
# --- Improved configuration ---

# Training schedule
BATCH_SIZE = 32
MAX_ITERS = 15000        # increased
EVAL_INTERVAL = 500      # evaluation period, in iterations
EVAL_ITERS = 200         # batches averaged per split by estimate_loss
LEARNING_RATE = 1e-4     # reduced

# Model size
BLOCK_SIZE = 128         # context length in tokens
N_EMBD = 256             # increased
N_HEAD = 8               # increased
N_LAYER = 8              # increased
DROPOUT = 0.1
CONDITION_DIM = 3        # [logp_cat, mw_cat, hbd_cat]

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Input / cache files
DATA_FILE = 's_100_str_1M_fixed.txt'
VOCAB_FILE = 'vocab_dataset.json'
DATA_CACHE_FILE = 'data_cache_categorical_improved.pt'

# Checkpoints (directory created eagerly so saving never fails on a missing folder)
CHECKPOINT_DIR = 'checkpoints'
CHECKPOINT_FILE = os.path.join(CHECKPOINT_DIR, 'cond_gpt_categorical_improved.pth')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Configuration : Périphérique={DEVICE}, Batch={BATCH_SIZE}, Contexte={BLOCK_SIZE}")
torch.manual_seed(1337)

In [None]:
# ------------------ UTILITAIRE LECTURE SMILES ------------------
def yield_smiles_from_file(filepath, encoding='utf-8'):
    """Yield every SMILES string found in a text file.

    Two line formats are accepted:
      * one plain SMILES per line, e.g. ``CCO``;
      * bracketed SMILES concatenated on one line, e.g. ``<CCO><c1ccccc1>``.
        A line holding a single bracketed entry such as ``<CCO>`` is unwrapped
        too (it used to be yielded with its brackets, which RDKit rejects).
    Blank lines and empty entries are skipped; surrounding whitespace is removed.

    Args:
        filepath: path of the text file to read.
        encoding: text encoding of the file (default UTF-8).

    Yields:
        str: one SMILES per item, in file order.
    """
    with open(filepath, 'r', encoding=encoding) as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            is_bracketed = line.startswith('<') and ('><' in line or line.endswith('>'))
            if not is_bracketed:
                yield line
                continue
            for smi in line.strip('<>').split('><'):
                smi = smi.strip()
                if smi:
                    yield smi


In [None]:
# --- PARTIE 2 : CONSTRUCTION DU VOCABULAIRE AMÉLIORÉE ---

print(f"Construction du vocabulaire à partir de '{DATA_FILE}'...")

PAD_TOKEN = '<pad>'
START_TOKEN = '<start>'
END_TOKEN = '<end>'

# Character-level vocabulary over RDKit-canonical SMILES. Note: every character
# is its own token, so two-letter elements such as "Cl"/"Br" are split in two.
char_set = set()
n_seen = 0               # every SMILES read from the file
valid_smiles_count = 0   # those RDKit could parse and canonicalise

for smiles in yield_smiles_from_file(DATA_FILE):
    # Counted first so molecules rejected below are still part of the total.
    n_seen += 1
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        continue
    try:
        canon_smiles = Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        continue
    char_set.update(canon_smiles)
    valid_smiles_count += 1

# Special tokens first, so <pad> is always id 0.
special_tokens = [PAD_TOKEN, START_TOKEN, END_TOKEN]
vocabulary = special_tokens + sorted(char_set)

stoi = {ch: i for i, ch in enumerate(vocabulary)}
itos = {i: ch for ch, i in stoi.items()}
vocab_size = len(vocabulary)

# json stores the integer keys of itos as strings; convert with int() when reloading.
with open(VOCAB_FILE, 'w', encoding='utf-8') as f:
    json.dump({'stoi': stoi, 'itos': itos}, f, indent=2)

print(f"Taille vocabulaire : {vocab_size} | SMILES valides: {valid_smiles_count}/{n_seen}")


def encode(s):
    """Map a string to token ids; characters missing from the vocabulary are dropped."""
    return [stoi[c] for c in s if c in stoi]


def decode(l):
    """Map token ids back to text; unknown ids are dropped, special tokens are kept verbatim."""
    return ''.join([itos[i] for i in l if i in itos])


# Quick round-trip sanity check
print("\n=== TEST VOCABULAIRE ===")
test_smiles = "CCO"
encoded = encode(test_smiles)
decoded = decode(encoded)
print(f"Test: {test_smiles} -> Encoded: {encoded} -> Decoded: {decoded} -> Match: {test_smiles == decoded}")

In [None]:
# --- PARTIE 3 : PRÉPARATION DES DONNÉES AMÉLIORÉE ---

# Inclusive upper bounds of each category. A value above the last bound falls
# into an extra overflow category, so N bounds yield categories 0..N.
LOGP_BINS = [0.0, 3.0, 5.0]
MW_BINS   = [250.0, 480.0, 650.0]
HBD_BINS  = [0.0, 1.0, 2.0, 3.0]

def get_category(value, bins):
    """Return the index (as a float) of the first bound in `bins` that is >= value.

    Values greater than every bound (or NaN) map to float(len(bins)).
    """
    first_fit = next((position for position, bound in enumerate(bins) if value <= bound), len(bins))
    return float(first_fit)

def load_and_process_data_improved(filepath, stoi, max_len=BLOCK_SIZE, cache_file=DATA_CACHE_FILE):
    """Load the encoded training set from cache, or build it from a SMILES file.

    Every molecule that RDKit can parse is canonicalised. Its LogP, molecular
    weight and H-bond donor count are binned with get_category, and the
    canonical string is encoded as <start> SMILES <end> padded to max_len.

    Args:
        filepath: SMILES text file (see yield_smiles_from_file for the format).
        stoi: token -> id mapping; must contain PAD/START/END tokens.
        max_len: fixed sequence length; longer molecules are skipped.
        cache_file: torch.save file reused on later calls when non-empty.

    Returns:
        list of (x, y, condition) tuples. x and y are LongTensors of shape
        (max_len,), with y equal to x shifted left by one and <pad>-filled.
        condition is a float32 tensor [logp_cat, mw_cat, hbd_cat]. The list is
        empty when no molecule is usable.

    Raises:
        FileNotFoundError: if filepath does not exist and no usable cache exists.
    """
    if os.path.exists(cache_file):
        print(f"Chargement des données depuis le cache '{cache_file}'...")
        # A list of tuples of tensors loads fine even with weights_only=True.
        data = torch.load(cache_file)
        if isinstance(data, list) and len(data) > 0:
            print(f"Données chargées ({len(data)} exemples).")
            return data
        else:
            print("Cache vide/corrompu, re-traitement...")

    print("Traitement des données SMILES avec normalisation...")
    data_processed = []
    pad_idx = stoi[PAD_TOKEN]
    start_idx = stoi[START_TOKEN]
    end_idx = stoi[END_TOKEN]

    try:
        for i, smiles in enumerate(tqdm(yield_smiles_from_file(filepath), desc="Traitement")):
            # Periodic collection runs for every line, including skipped ones.
            if i > 0 and i % 50000 == 0:
                gc.collect()

            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            try:
                canon_smiles = Chem.MolToSmiles(mol, canonical=True)

                # Room for <start> and <end>; every character must be in the vocabulary.
                if len(canon_smiles) > max_len - 2 or not all(c in stoi for c in canon_smiles):
                    continue

                logp = Descriptors.MolLogP(mol)
                mw   = Descriptors.MolWt(mol)
                hbd  = Descriptors.NumHDonors(mol)

                condition_vector = torch.tensor(
                    [get_category(logp, LOGP_BINS),
                     get_category(mw, MW_BINS),
                     get_category(hbd, HBD_BINS)],
                    dtype=torch.float32,
                )

                # Encode with the `stoi` passed in, not the global encoder, so the
                # function is consistent with its own vocabulary argument.
                token_ids = [start_idx] + [stoi[c] for c in canon_smiles] + [end_idx]
                seq_len = len(token_ids)

                x = torch.full((max_len,), pad_idx, dtype=torch.long)
                y = torch.full((max_len,), pad_idx, dtype=torch.long)
                x[:seq_len] = torch.tensor(token_ids, dtype=torch.long)
                y[:seq_len - 1] = torch.tensor(token_ids[1:], dtype=torch.long)

                data_processed.append((x, y, condition_vector))

            except Exception:
                # Best effort: skip molecules RDKit fails to describe.
                continue

    except FileNotFoundError:
        print(f"ERREUR: Fichier '{filepath}' introuvable.")
        # exit() does not reliably stop a Jupyter cell; propagate the error instead.
        raise

    print(f"\nNombre total de molécules valides : {len(data_processed)}")

    if len(data_processed) == 0:
        print("⚠️ Aucun exemple valide.")
        return []

    print(f"Sauvegarde dans le cache '{cache_file}'...")
    torch.save(data_processed, cache_file)
    return data_processed

class SMILESDataset(Dataset):
    """Map-style Dataset over a pre-built list of (x, y, condition) samples."""

    def __init__(self, data):
        super().__init__()
        self.data = data  # any indexable sequence; items are returned unchanged

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position idx."""
        sample = self.data[idx]
        return sample


In [None]:
# --- PARTIE 4 : ARCHITECTURE DU MODÈLE (INCHANGÉE MAIS AVEC CONFIG AMÉLIORÉE) ---

@dataclass
class GPTConfig:
    """Hyper-parameters of ConditionalDrugGPT.

    Defaults are read from module-level globals when the class is defined.
    vocab_size defaults to the size of the vocabulary built earlier, so the
    vocabulary cell must have run before this one.
    """
    block_size: int = BLOCK_SIZE        # maximum sequence length (number of positions)
    vocab_size: int = vocab_size        # number of token ids (special tokens + characters)
    n_layer: int = N_LAYER              # number of transformer blocks
    n_head: int = N_HEAD                # attention heads per block; must divide n_embd
    n_embd: int = N_EMBD                # embedding / hidden width
    dropout: float = DROPOUT            # dropout probability used throughout the model
    condition_dim: int = CONDITION_DIM  # length of the condition vector fed to the model

class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention: position t attends only to positions <= t."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One projection yields queries, keys and values side by side.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular mask of shape (1, 1, block_size, block_size). It stays
        # a buffer named "bias" so existing state_dicts keep loading.
        causal_mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("bias", causal_mask.view(1, 1, config.block_size, config.block_size))

    def _split_heads(self, t, batch, seq_len, width):
        """Reshape (B, T, C) into (B, n_head, T, C // n_head)."""
        return t.view(batch, seq_len, self.n_head, width // self.n_head).transpose(1, 2)

    def forward(self, x):
        batch, seq_len, width = x.size()
        q, k, v = (self._split_heads(part, batch, seq_len, width)
                   for part in self.c_attn(x).split(self.n_embd, dim=2))

        scale = 1.0 / math.sqrt(k.size(-1))
        scores = (q @ k.transpose(-2, -1)) * scale
        future = self.bias[:, :, :seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear(n_embd -> 4*n_embd), GELU, Linear back, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width)
        self.c_proj = nn.Linear(hidden_width, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))

class Block(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then MLP, each wrapped in a residual."""

    def __init__(self, config):
        super().__init__()
        # Registration order is kept so initialisation order and state_dict keys are unchanged.
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))

class ConditionalDrugGPT(nn.Module):
    """GPT-style character-level SMILES model conditioned on property categories.

    The condition vector (config.condition_dim floats) goes through
    Linear + ReLU to n_embd. It is added to every token and position embedding
    before the transformer stack. The token embedding matrix is tied to lm_head.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Registration order is unchanged so random initialisation and
        # state_dict keys stay compatible with existing checkpoints.
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),   # token embeddings
            'wpe': nn.Embedding(config.block_size, config.n_embd),   # learned positions
            'drop': nn.Dropout(config.dropout),
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embd),
        })

        self.condition_projector = nn.Sequential(
            nn.Linear(config.condition_dim, config.n_embd),
            nn.ReLU(),
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the input embedding and output projection share one matrix.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Set weights of Linear/Embedding modules to N(0, 0.02) and Linear biases to zero."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None, conditions=None):
        """Run the model.

        Args:
            idx: LongTensor (B, T) of token ids, T <= block_size.
            targets: optional LongTensor (B, T); <pad> positions are ignored in the loss.
            conditions: required float tensor (B, condition_dim), broadcast over T.

        Returns:
            (logits, loss). With targets: logits (B, T, vocab) and the mean
            cross-entropy. Without: logits for the last position only
            (B, 1, vocab) and None.
        """
        _, seq_len = idx.shape
        assert seq_len <= self.config.block_size, f"Séquence trop longue: {seq_len}"
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device).unsqueeze(0)

        token_embeddings = self.transformer.wte(idx)
        position_embeddings = self.transformer.wpe(positions)

        assert conditions is not None, "Conditions requises !"
        condition_embeddings = self.condition_projector(conditions).unsqueeze(1)

        hidden = self.transformer.drop(token_embeddings + position_embeddings + condition_embeddings)
        for block in self.transformer.h:
            hidden = block(hidden)
        hidden = self.transformer.ln_f(hidden)

        if targets is None:
            # Sampling only needs the distribution of the next token.
            return self.lm_head(hidden[:, [-1], :]), None

        logits = self.lm_head(hidden)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=stoi[PAD_TOKEN],
        )
        return logits, loss

In [None]:
# --- PARTIE 5 : FONCTIONS UTILITAIRES AMÉLIORÉES ---

def save_checkpoint(model, optimizer, iter_num, best_val_loss, config, filepath):
    """Persist model/optimizer state and training progress to `filepath`.

    The checkpoint is first written to "<filepath>.tmp" and then moved into
    place with os.replace. An interrupted save (e.g. a notebook session timing
    out mid-write) therefore never leaves a truncated checkpoint that would
    break the next resume.

    Args:
        model, optimizer: objects whose state_dict() is saved.
        iter_num: iteration index stored for resuming.
        best_val_loss: best validation loss so far.
        config: model configuration object, pickled as-is.
        filepath: destination path.
    """
    print(f"Sauvegarde du checkpoint dans {filepath}...")
    tmp_path = os.fspath(filepath) + '.tmp'
    torch.save({
        'iter_num': iter_num,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
        'config': config,
    }, tmp_path)
    os.replace(tmp_path, filepath)

def load_checkpoint(filepath, model, optimizer, map_location=None):
    """Restore model and optimizer state from `filepath` when it exists.

    Args:
        filepath: checkpoint produced by save_checkpoint.
        model, optimizer: objects whose state is restored in place.
        map_location: device for the loaded tensors; defaults to DEVICE.

    Returns:
        (iter_num, best_val_loss) read from the checkpoint, or (0, inf) when
        no checkpoint file exists.
    """
    if not os.path.exists(filepath):
        print("Aucun checkpoint trouvé. Démarrage d'un nouvel entraînement.")
        return 0, float('inf')

    print(f"Chargement du checkpoint depuis {filepath}...")
    # save_checkpoint pickles the config object (a GPTConfig dataclass).
    # torch.load defaults to weights_only=True since PyTorch 2.6 and would
    # refuse to unpickle it. Full unpickling can execute code, so only load
    # checkpoints you produced yourself.
    checkpoint = torch.load(
        filepath,
        map_location=map_location if map_location is not None else DEVICE,
        weights_only=False,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
    print(f"Reprise à l'itération {iter_num} (meilleure perte val: {best_val_loss:.4f})")
    return iter_num, best_val_loss

@torch.no_grad()
def estimate_loss(model, train_loader, val_loader, eval_iters=EVAL_ITERS):
    """Average the loss over `eval_iters` batches of each split.

    A loader that runs out of batches is restarted. The model is put in eval
    mode for the measurement and switched back to train mode afterwards.

    Returns:
        {'train': tensor, 'val': tensor}, each a 0-dim mean loss.
    """
    model.eval()
    out = {}
    for split, loader in (('train', train_loader), ('val', val_loader)):
        losses = torch.zeros(eval_iters)
        batches = iter(loader)
        for k in range(eval_iters):
            batch = next(batches, None)
            if batch is None:
                batches = iter(loader)
                batch = next(batches)
            x, y, c = (t.to(DEVICE) for t in batch)
            _, loss = model(x, y, c)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [None]:
# --- PARTIE 6 : GÉNÉRATION AMÉLIORÉE ---

def validate_and_repair_smiles(smiles):
    """Return the canonical form of `smiles`, attempting a light repair if it does not parse.

    Repairs, applied in order:
      1. append the missing ')' characters;
      2. append each single-digit ring label that occurs an odd number of times.
         Digits inside bracket atoms ([13CH3], [NH3+]) and two-digit '%nn'
         labels are not ring-bond labels, so they are ignored when counting.

    Returns:
        Canonical SMILES, or None for empty/"[VIDE]" input or when repair fails.
    """
    if not smiles or smiles == "[VIDE]":
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return Chem.MolToSmiles(mol, canonical=True)

    try:
        missing_close = smiles.count('(') - smiles.count(')')
        if missing_close > 0:
            smiles += ')' * missing_close

        ring_skeleton = re.sub(r"\[[^\]]*\]|%\d\d", "", smiles)
        # sorted() keeps the appended labels in a deterministic order.
        for ring_label in sorted({ch for ch in ring_skeleton if ch.isdigit()}):
            if ring_skeleton.count(ring_label) % 2 != 0:
                smiles += ring_label

        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        # Best effort: any RDKit failure simply means "not repairable".
        pass

    return None

@torch.no_grad()
def generate_conditional_improved(model, condition_tensor, stoi, itos, max_new_tokens=80, temperature=0.8, top_k=12):
    """Autoregressively sample one SMILES string for a condition vector.

    Args:
        model: conditional GPT; switched to eval mode and left there.
        condition_tensor: float tensor [logp_cat, mw_cat, hbd_cat] of shape
            (1, 3) or (3,).
        stoi, itos: vocabulary maps (token -> id, id -> token).
        max_new_tokens: maximum number of tokens generated after <start>.
        temperature: softmax temperature, floored at 0.1.
        top_k: sample only among the k most likely tokens (None disables it).

    Returns:
        The decoded (possibly invalid) SMILES, or "[VIDE]" if nothing was generated.
    """
    model.eval()
    start_idx = stoi[START_TOKEN]
    end_idx = stoi[END_TOKEN]
    # <pad> and <start> are never a (non-ignored) training target. Sampling them
    # would put literal "<pad>"/"<start>" text into the decoded SMILES.
    banned_ids = [stoi[tok] for tok in (PAD_TOKEN, START_TOKEN) if tok in stoi]

    condition_tensor = condition_tensor.to(DEVICE)
    if condition_tensor.dim() == 1:
        # The model expects (batch, condition_dim).
        condition_tensor = condition_tensor.unsqueeze(0)

    idx = torch.tensor([[start_idx]], dtype=torch.long, device=DEVICE)

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -BLOCK_SIZE:]  # crop to the model's context window

        logits, _ = model(idx_cond, conditions=condition_tensor)
        logits = logits[:, -1, :] / max(temperature, 0.1)  # avoid division by ~0
        if banned_ids:
            logits[:, banned_ids] = -float('Inf')

        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(logits, dim=-1)

        # Fallback for degenerate distributions.
        if torch.isnan(probs).any() or torch.max(probs) < 0.01:
            idx_next = torch.tensor([[stoi['C']]], device=DEVICE)
        else:
            idx_next = torch.multinomial(probs, num_samples=1)

        if idx_next.item() == end_idx:
            break

        idx = torch.cat((idx, idx_next), dim=1)

    # <end> is never appended, so everything after <start> is the molecule.
    generated_smiles = ''.join(itos[t] for t in idx[0, 1:].tolist() if t in itos)
    return generated_smiles if generated_smiles else "[VIDE]"

def generate_with_retry(model, condition_tensor, stoi, itos, max_retries=3):
    """Sample up to `max_retries` molecules, loosening the sampling at each attempt.

    Attempt k uses temperature 0.7 + 0.2*k and top_k 10 + 5*k. Each sample goes
    through validate_and_repair_smiles. The first valid SMILES whose length is
    in [10, 60] is returned immediately. Otherwise the function returns the
    valid SMILES whose length is closest to 35, or "[VIDE]" if no attempt
    produced a valid one.
    """
    fallback = None
    fallback_len = 0

    for attempt in range(max_retries):
        raw = generate_conditional_improved(
            model, condition_tensor, stoi, itos,
            temperature=0.7 + attempt * 0.2,
            top_k=10 + attempt * 5,
        )
        candidate = validate_and_repair_smiles(raw)
        if not candidate or not Chem.MolFromSmiles(candidate):
            continue

        if 10 <= len(candidate) <= 60:
            return candidate
        if fallback is None or abs(len(candidate) - 35) < abs(fallback_len - 35):
            fallback, fallback_len = candidate, len(candidate)

    return fallback if fallback else "[VIDE]"

def check_mol_3_props(smiles):
    """Compute LogP, molecular weight and H-bond donor count of a SMILES string.

    Returns:
        (status, logp, mw, hbd). status is one of:
          * "Valide" on success;
          * "Vide" for empty/None/"[VIDE]" input;
          * "Invalide" when RDKit cannot parse the string;
          * "Erreur" when a descriptor computation raises.
        The numeric fields are 0 whenever status is not "Valide".
    """
    if not smiles or smiles == "[VIDE]":
        return "Vide", 0.0, 0.0, 0
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return "Invalide", 0.0, 0.0, 0
    try:
        logp = Descriptors.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
        hbd = Descriptors.NumHDonors(mol)
    except Exception:
        # Not a bare except: KeyboardInterrupt must still stop long evaluation loops.
        return "Erreur", 0.0, 0.0, 0
    return "Valide", logp, mw, hbd

def evaluate_generation_quality(generated_smiles, target_conditions):
    """Score a batch of generated SMILES against a target condition vector.

    Args:
        generated_smiles: iterable of SMILES strings (may contain "[VIDE]").
        target_conditions: indexable [logp_cat, mw_cat, hbd_cat] target categories.

    Returns:
        (results, valid_smiles). results holds:
          * 'valid': number of valid molecules;
          * 'matching_conditions': valid molecules whose three binned
            properties all equal the target;
          * 'unique': number of distinct valid SMILES;
          * 'avg_length': mean character length of valid SMILES (0 if none).
        valid_smiles lists the valid strings in input order.
    """
    valid_smiles = []
    n_matching = 0
    total_length = 0

    for smi in generated_smiles:
        status, logp, mw, hbd = check_mol_3_props(smi)
        if status != "Valide":
            continue

        valid_smiles.append(smi)
        total_length += len(smi)

        observed = (
            get_category(logp, LOGP_BINS),
            get_category(mw, MW_BINS),
            get_category(hbd, HBD_BINS),
        )
        if (observed[0] == target_conditions[0]
                and observed[1] == target_conditions[1]
                and observed[2] == target_conditions[2]):
            n_matching += 1

    n_valid = len(valid_smiles)
    results = {
        'valid': n_valid,
        'matching_conditions': n_matching,
        'unique': len(set(valid_smiles)),
        'avg_length': total_length / n_valid if n_valid > 0 else 0,
    }
    return results, valid_smiles

In [None]:
# --- PARTIE 7 : SCRIPT PRINCIPAL AMÉLIORÉ ---

if __name__ == "__main__":

    # 1. Load (or rebuild) the encoded dataset.
    full_data = load_and_process_data_improved(DATA_FILE, stoi, cache_file=DATA_CACHE_FILE)
    if len(full_data) == 0:
        print("Arrêt : aucun échantillon valide.")
        # exit() does not reliably stop a Jupyter cell; raising does.
        raise RuntimeError("Arrêt : aucun échantillon valide.")

    # 2. Class balancing: weight each sample by the inverse frequency of its
    #    (logp_cat, mw_cat, hbd_cat) combination.
    print("\n⚖️ Calcul des poids d'échantillonnage...")
    condition_keys = [tuple(cond.tolist()) for _, _, cond in full_data]
    combo_counts = Counter(condition_keys)
    total_samples = len(full_data)
    num_classes = len(combo_counts)
    weights = [total_samples / (num_classes * combo_counts[key]) for key in condition_keys]

    print(f"✅ Combinaisons de classes : {num_classes}")
    print(f"Exemple de poids : {weights[:5]}")

    # 3. Dataset and DataLoaders (90/10 split; balanced sampling on train only).
    dataset = SMILESDataset(full_data)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_data, val_data = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_weights = [weights[i] for i in train_data.indices]
    sampler = WeightedRandomSampler(train_weights, num_samples=len(train_weights), replacement=True)

    pin_mem = DEVICE == 'cuda'
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, sampler=sampler,
                              num_workers=0, pin_memory=pin_mem)
    val_loader   = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=0, pin_memory=pin_mem)

    print("✅ Échantillonnage équilibré activé.")

    # 4. Model
    config = GPTConfig(vocab_size=vocab_size)
    model = ConditionalDrugGPT(config).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

    param_count = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"Nombre de paramètres : {param_count:.2f} M")

    # 5. Resume from checkpoint if available
    start_iter, best_val_loss = load_checkpoint(CHECKPOINT_FILE, model, optimizer)

    # 6. Training
    print(f"🚀 Début de l'entraînement sur {DEVICE}...")
    start_time = time.time()
    train_iter = iter(train_loader)

    for iter_num in range(start_iter, MAX_ITERS):
        try:
            xb, yb, cb = next(train_iter)
        except StopIteration:
            # The sampler yields a fixed number of samples per pass; start a new pass.
            train_iter = iter(train_loader)
            xb, yb, cb = next(train_iter)

        xb, yb, cb = xb.to(DEVICE), yb.to(DEVICE), cb.to(DEVICE)
        logits, loss = model(xb, targets=yb, conditions=cb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Periodic evaluation in the "Étape N: train X, val Y, temps Ts" format
        # parsed by the learning-curve cell; the best model is checkpointed.
        is_last_iter = iter_num == MAX_ITERS - 1
        if (iter_num > 0 and iter_num % EVAL_INTERVAL == 0) or is_last_iter:
            losses = estimate_loss(model, train_loader, val_loader)
            train_loss, val_loss = float(losses['train']), float(losses['val'])
            print(f"Étape {iter_num}: train {train_loss:.4f}, val {val_loss:.4f}, "
                  f"temps {time.time() - start_time:.1f}s")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Store the next iteration index so a resumed run does not repeat this step.
                save_checkpoint(model, optimizer, iter_num + 1, best_val_loss, config, CHECKPOINT_FILE)

    # 7. Conditional generation tests
    def test_generation_complete(label, condition_values, n_samples=30):
        """Sample `n_samples` molecules for one condition vector and report quality metrics.

        Returns the dict from evaluate_generation_quality plus an 'n_samples' entry.
        """
        print(f"\n🧪 Génération pour {label}")
        cond_tensor = torch.tensor(condition_values, dtype=torch.float32).unsqueeze(0)
        generated = [generate_with_retry(model, cond_tensor, stoi, itos) for _ in range(n_samples)]
        results, valid_smiles = evaluate_generation_quality(generated, condition_values)
        results['n_samples'] = n_samples
        print(f"  Valides: {results['valid']}/{n_samples} | "
              f"Conditions respectées: {results['matching_conditions']}/{n_samples} | "
              f"Uniques: {results['unique']}/{n_samples} | "
              f"Longueur moyenne: {results['avg_length']:.1f}")
        for smi in valid_smiles[:5]:
            print(f"    {smi}")
        return results

    results1 = test_generation_complete("[LogP(1), MW(1), HBD(3)]", [1.0, 1.0, 3.0])
    results2 = test_generation_complete("[LogP(3), MW(3), HBD(0)]", [3.0, 3.0, 0.0])
    results3 = test_generation_complete("[LogP(2), MW(2), HBD(1)]", [2.0, 2.0, 1.0])

    # Final summary, averaged over the actual number of generated molecules.
    all_results = [results1, results2, results3]
    total_generated = sum(r['n_samples'] for r in all_results)

    def _mean_pct(key):
        """Percentage of all generated molecules counted under `key`."""
        return sum(r[key] for r in all_results) / total_generated * 100

    print("\n" + "="*60)
    print("📈 RÉSUMÉ FINAL DES PERFORMANCES")
    print("="*60)
    print(f"Validité moyenne: {_mean_pct('valid'):.1f}%")
    print(f"Correspondance moyenne: {_mean_pct('matching_conditions'):.1f}%")
    print(f"Unicité moyenne: {_mean_pct('unique'):.1f}%")

    print("\n✅ Script terminé avec succès !")

In [None]:
# --- Paste the training log text here ---
# Each line must look like "Étape <iter>: train <loss>, val <loss>, temps <seconds>s";
# the next block extracts the iteration, train loss and val loss with a regex.
log_text = """
Étape 500: train 1.2334, val 1.2766, temps 67.2s
Étape 1000: train 0.9699, val 1.0209, temps 132.1s
Étape 1500: train 0.8488, val 0.9194, temps 195.2s
Étape 2000: train 0.7945, val 0.8735, temps 261.1s
Étape 2500: train 0.7526, val 0.8458, temps 324.2s
Étape 3000: train 0.7040, val 0.8088, temps 389.5s
Étape 3500: train 0.6828, val 0.7919, temps 452.7s
Étape 4000: train 0.6517, val 0.7646, temps 517.5s
Étape 4500: train 0.6294, val 0.7514, temps 580.6s
Étape 5000: train 0.6255, val 0.7525, temps 645.6s
Étape 5500: train 0.6045, val 0.7314, temps 708.7s
Étape 6000: train 0.5819, val 0.7215, temps 773.1s
Étape 6500: train 0.5725, val 0.7142, temps 836.3s
Étape 7000: train 0.5620, val 0.7078, temps 900.9s
Étape 7500: train 0.5445, val 0.6998, temps 964.2s
Étape 8000: train 0.5407, val 0.6941, temps 1029.0s
Étape 8500: train 0.5390, val 0.6958, temps 1092.5s
Étape 9000: train 0.5285, val 0.6848, temps 1156.7s
Étape 9500: train 0.5146, val 0.6801, temps 1220.0s
Étape 10000: train 0.5198, val 0.6785, temps 1284.6s
Étape 10500: train 0.4981, val 0.6715, temps 1348.0s
Étape 11000: train 0.4952, val 0.6720, temps 1412.5s
Étape 11500: train 0.4886, val 0.6643, temps 1475.8s
Étape 12000: train 0.4793, val 0.6595, temps 1540.8s
Étape 12500: train 0.4781, val 0.6618, temps 1604.3s
Étape 13000: train 0.4640, val 0.6522, temps 1669.2s
Étape 13500: train 0.4686, val 0.6523, temps 1732.5s
Étape 14000: train 0.4658, val 0.6482, temps 1797.6s
Étape 14500: train 0.4593, val 0.6496, temps 1860.8s
Étape 14999: train 0.4441, val 0.6414, temps 1924.1s
"""

# --- Extract (step, train loss, val loss) from the pasted log ---
pattern = r"Étape (\d+): train ([0-9.]+), val ([0-9.]+)"
matches = re.findall(pattern, log_text)
# Fail loudly when the pasted text does not match the training-log format,
# instead of silently drawing an empty figure.
assert matches, "Aucune ligne 'Étape N: train X, val Y' trouvée dans log_text."

steps = [int(m[0]) for m in matches]
train_losses = [float(m[1]) for m in matches]
val_losses = [float(m[2]) for m in matches]

# --- Tabulate ---
df = pd.DataFrame({
    "step": steps,
    "train_loss": train_losses,
    "val_loss": val_losses,
})
print(df.head())

# --- Learning curve (explicit Figure/Axes interface) ---
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(df["step"], df["train_loss"], label="Train", marker='o', linewidth=2)
ax.plot(df["step"], df["val_loss"], label="Validation", marker='s', linewidth=2)
ax.set_title("Courbe d'apprentissage (train vs validation)")
ax.set_xlabel("Itérations")
ax.set_ylabel("Perte moyenne (cross-entropy)")
ax.legend()
ax.grid(True, linestyle='--', alpha=0.7)
fig.tight_layout()
fig.savefig("train_val_loss.png", dpi=150)
plt.show()


In [None]:
def evaluate_model_with_metrics(model, stoi, itos, dataset, conditions_list, n_samples=100):
    """Evaluate conditional generation on several target conditions; print a table and plot it.

    For each (label, condition_vector) in `conditions_list`, `n_samples`
    SMILES are generated with generate_with_retry and scored:
      * Validité: share parsed by RDKit;
      * Unicité: distinct valid SMILES / n_samples;
      * Nouveauté: distinct valid SMILES absent from the training set / n_samples;
      * Respect conditions: valid SMILES whose binned LogP/MW/HBD equal the
        target / n_samples.

    Args:
        model: conditional generator passed to generate_with_retry.
        stoi, itos: vocabulary maps.
        dataset: iterable of encoded (x, y, condition) samples, where x is
            <start> SMILES <end> followed by <pad> ids.
        conditions_list: list of (label, [logp_cat, mw_cat, hbd_cat]).
        n_samples: molecules generated per condition (must be > 0).

    Returns:
        pandas DataFrame with one row per condition.

    Raises:
        ValueError: if n_samples <= 0.
    """
    if n_samples <= 0:
        raise ValueError("n_samples must be positive")

    start_idx, end_idx, pad_idx = stoi[START_TOKEN], stoi[END_TOKEN], stoi[PAD_TOKEN]

    def _decode_training_row(x):
        """Recover the SMILES text of one encoded row: drop <start>, cut at <end>, skip <pad>."""
        token_ids = x.tolist()
        if token_ids and token_ids[0] == start_idx:
            token_ids = token_ids[1:]
        if end_idx in token_ids:
            token_ids = token_ids[:token_ids.index(end_idx)]
        return ''.join(itos[t] for t in token_ids if t != pad_idx and t in itos)

    # Canonical training SMILES for the novelty metric. Decoding must stop at
    # <end>: keeping the "<end>"/"<pad>" text makes RDKit reject almost every
    # row, leaving this set empty and inflating novelty to ~100%.
    seen_smiles = set()
    for x, _, _ in dataset:
        mol = Chem.MolFromSmiles(_decode_training_row(x))
        if mol is not None:
            seen_smiles.add(Chem.MolToSmiles(mol))

    results_all = []
    for cond_label, cond_vec in conditions_list:
        cond_tensor = torch.tensor(cond_vec, dtype=torch.float32).unsqueeze(0)
        generated = [generate_with_retry(model, cond_tensor, stoi, itos, max_retries=3)
                     for _ in range(n_samples)]

        valid = [s for s in generated if s != "[VIDE]" and Chem.MolFromSmiles(s) is not None]
        unique = set(valid)
        novel = unique - seen_smiles

        matches = 0
        for s in valid:
            status, logp, mw, hbd = check_mol_3_props(s)
            if (status == "Valide"
                    and get_category(logp, LOGP_BINS) == cond_vec[0]
                    and get_category(mw, MW_BINS) == cond_vec[1]
                    and get_category(hbd, HBD_BINS) == cond_vec[2]):
                matches += 1

        results_all.append({
            "Condition": cond_label,
            "Validité (%)": len(valid)/n_samples*100,
            "Unicité (%)": len(unique)/n_samples*100,
            "Nouveauté (%)": len(novel)/n_samples*100,
            "Respect conditions (%)": matches/n_samples*100
        })

    df = pd.DataFrame(results_all)
    print("\n📊 Tableau récapitulatif :")
    print(df.round(2))

    # Draw into an explicit Axes: DataFrame.plot() without ax= opens its own
    # figure, which left a separately created figsize=(8, 5) figure empty.
    metric_cols = ["Validité (%)", "Unicité (%)", "Nouveauté (%)", "Respect conditions (%)"]
    fig, ax = plt.subplots(figsize=(8, 5))
    df.set_index("Condition")[metric_cols].plot(kind="bar", ax=ax)
    ax.set_title("Évaluation du modèle de génération moléculaire")
    ax.set_ylabel("Pourcentage (%)")
    plt.setp(ax.get_xticklabels(), rotation=30, ha='right')
    ax.grid(True, axis='y', linestyle='--', alpha=0.7)
    fig.tight_layout()
    plt.show()

    return df


In [None]:
# Target property categories to probe: (label, [logp_cat, mw_cat, hbd_cat]).
conditions_to_test = [
    ("[LogP(1), MW(1), HBD(3)]", [1.0, 1.0, 3.0]),
    ("[LogP(3), MW(3), HBD(0)]", [3.0, 3.0, 0.0]),
    ("[LogP(2), MW(2), HBD(1)]", [2.0, 2.0, 1.0]),
]

# Relies on `model` and `dataset` created by the training cell.
df_results = evaluate_model_with_metrics(
    model, stoi, itos, dataset,
    conditions_list=conditions_to_test,
    n_samples=1000,
)
