#Генеративная модель

In [None]:
import matplotlib
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors, SanitizeFlags
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect
from rdkit.DataStructs import TanimotoSimilarity
import random
import warnings
from rdkit import RDLogger
from tqdm import tqdm
from functools import lru_cache
import multiprocessing as mp
from rdkit.Chem import BRICS
import matplotlib.pyplot as plt
import dill
import logging
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
import os

###Установка параметров алгоритма

In [None]:
def _configure_logging():
    """Attach both GA log files to the root logger, once.

    Two consecutive ``logging.basicConfig`` calls do not create two log
    files: the second call does nothing once the root logger already has a
    handler, so 'ga_debug.log' was never written and DEBUG records were
    dropped. Explicit handlers give each file its own threshold.
    """
    root = logging.getLogger()
    # The root level must be the most permissive one; each handler filters further.
    root.setLevel(logging.DEBUG)
    already_attached = {
        os.path.basename(getattr(handler, 'baseFilename', ''))
        for handler in root.handlers
    }
    for filename, level in (('ga_errors.log', logging.INFO),
                            ('ga_debug.log', logging.DEBUG)):
        if filename in already_attached:
            # Re-running the cell must not duplicate every log line.
            continue
        handler = logging.FileHandler(filename)
        handler.setLevel(level)
        handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
        root.addHandler(handler)


_configure_logging()


# Genetic-algorithm settings.
POPULATION_SIZE = 800
GENERATIONS = 150
MUTATION_RATE = 0.8
CROSSOVER_RATE = 0.5
ELITE_SIZE = 50
TOURNAMENT_SIZE = 30
N_CORES = mp.cpu_count()


# Morgan fingerprint settings (radius and bit-vector length).
FP_RADIUS = 2
FP_BITS = 2048

### Кэширование для быстрой обработки

In [None]:
@lru_cache(maxsize=10_000)
def get_cached_mol(smiles):
    """Parse a SMILES string with RDKit, memoising up to 10,000 results.

    Returns whatever ``Chem.MolFromSmiles`` returns for ``smiles``. The
    same object is handed back on every cache hit.
    """
    parsed = Chem.MolFromSmiles(smiles)
    return parsed

@lru_cache(maxsize=10_000)
def get_cached_fingerprint(smiles):
    """Return the cached Morgan bit-vector fingerprint for ``smiles``.

    The molecule comes from ``get_cached_mol``; the fingerprint uses
    ``FP_RADIUS`` and ``FP_BITS``. Returns ``None`` when the parsed
    molecule is falsy.
    """
    molecule = get_cached_mol(smiles)
    if not molecule:
        return None
    return GetMorganFingerprintAsBitVect(molecule, FP_RADIUS, FP_BITS)

###Вспомогательные вычисления для основной модели

In [None]:
# Ignore every Python warning raised from this point on.
warnings.filterwarnings("ignore")
# Disable RDKit log output for all 'rdApp.*' channels.
RDLogger.DisableLog('rdApp.*')

def parallel_calculate(args):
    """Score one candidate; intended as the worker function for a process pool.

    Args:
        args: a ``(smiles, target_fp)`` pair.

    Returns:
        ``0.6 * Tanimoto(candidate, target) + 0.4 * (1 - |logP - 1.5| / 2)``,
        or ``0.0`` when the SMILES does not parse or any exception occurs
        (the exception is printed, not raised).
    """
    try:
        candidate, reference_fp = args
        molecule = get_cached_mol(candidate)
        if molecule:
            candidate_fp = get_cached_fingerprint(candidate)
            tanimoto = TanimotoSimilarity(candidate_fp, reference_fp)
            log_p = Descriptors.MolLogP(molecule)
            return 0.6 * tanimoto + 0.4 * (1 - abs(log_p - 1.5)/2)
        return 0.0
    except Exception as e:
        print(f"Ошибка в parallel_calculate: {e}")
        return 0.0

###Функции для основного класса

In [None]:
def calculate_fitness_parallel(population, target_fp):
    """Score every SMILES in ``population`` against ``target_fp`` in a process pool.

    Args:
        population: iterable of SMILES strings.
        target_fp: fingerprint passed to ``parallel_calculate`` with each SMILES.

    Returns:
        List of fitness values in the same order as ``population``
        (empty list for an empty population).

    The earlier ``ctx.reducer = dill.Reduce`` assignment was removed: a
    context's ``reducer`` must be a module-like reduction object, so that
    assignment could only break the pool. The default pickler is enough
    here because the worker is a module-level function.
    NOTE(review): with the 'spawn' start method the worker function must be
    importable in the child process; confirm this holds when the code runs
    from a notebook.
    """
    population = list(population)
    if not population:
        # Nothing to score; avoid starting worker processes for no work.
        return []

    ctx = mp.get_context('spawn')
    workers = max(1, min(N_CORES, len(population)))
    args = [(smiles, target_fp) for smiles in population]
    # A few chunks per worker keeps inter-process traffic low for cheap tasks.
    chunksize = max(1, len(args) // (workers * 4))
    with ctx.Pool(workers) as pool:
        return pool.map(parallel_calculate, args, chunksize)

def sanitize_mol(mol):
    """Sanitize ``mol`` in place with all RDKit sanitization steps.

    Returns:
        The same ``mol`` object when sanitization succeeds, ``None`` when
        RDKit raises (including when ``mol`` is not a molecule).
    """
    try:
        Chem.SanitizeMol(mol, SanitizeFlags.SANITIZE_ALL)
    except Exception:
        # ``except Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt/SystemExit still stop a long GA run.
        return None
    return mol


def validate_molecule(smiles):
    """Return True when ``smiles`` parses and no atom exceeds its default valence.

    Args:
        smiles: SMILES string. Any input RDKit cannot parse, including
            non-string values, yields ``False`` instead of raising.

    Returns:
        ``True`` if the molecule parses and every atom's explicit valence is
        at most the periodic table's default valence for its element.

    NOTE(review): comparing against the *default* valence looks stricter
    than RDKit's own sanitization. It would presumably reject hypervalent
    sulfur/phosphorus and positively charged nitrogen. Confirm that this
    filter is intended.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return False
        # Fetch the table once instead of once per atom.
        periodic_table = Chem.GetPeriodicTable()
        return all(
            atom.GetExplicitValence()
            <= periodic_table.GetDefaultValence(atom.GetAtomicNum())
            for atom in mol.GetAtoms()
        )
    except Exception:
        return False

def combine_with_scaffold(mol, scaffold_frag):
    """Join ``mol`` and a scaffold fragment with one random single bond.

    Args:
        mol: RDKit molecule to extend.
        scaffold_frag: RDKit molecule **or** SMILES string. Strings are
            parsed first; previously a string made ``CombineMols`` raise and
            the error was silently swallowed, so nothing was ever attached.

    Returns:
        The sanitized combined molecule, ``None`` if the combined structure
        fails sanitization, or the unchanged ``mol`` when the fragment is
        unusable or any RDKit call raises.
    """
    try:
        if isinstance(scaffold_frag, str):
            scaffold_frag = Chem.MolFromSmiles(scaffold_frag)
        if scaffold_frag is None:
            return mol

        n_mol_atoms = mol.GetNumAtoms()
        if n_mol_atoms == 0 or scaffold_frag.GetNumAtoms() == 0:
            # No atom to anchor the new bond on.
            return mol

        combined = Chem.CombineMols(mol, scaffold_frag)
        editable = Chem.EditableMol(combined)
        # CombineMols puts the atoms of ``mol`` first, then the fragment's.
        anchor_in_mol = random.randint(0, n_mol_atoms - 1)
        anchor_in_frag = random.randint(n_mol_atoms, combined.GetNumAtoms() - 1)
        editable.AddBond(anchor_in_mol, anchor_in_frag, Chem.BondType.SINGLE)
        return sanitize_mol(editable.GetMol())
    except Exception:
        return mol





def crossover(smiles1, smiles2):
    """Recombine two parents through BRICS fragments; always return two SMILES.

    One BRICS fragment (minimum 4 atoms) is picked at random from each
    parent and the pair is passed to ``BRICS.BRICSBuild``. Up to two
    sanitizable products become the children; missing children are filled
    with the corresponding parent.

    Fixes:
      * ``BRICS.CombineFragments`` does not appear in RDKit's BRICS API, so
        the old code always landed in the ``except`` branch and returned the
        parents unchanged. NOTE(review): confirm against the installed
        RDKit version.
      * the old success path returned a 1-tuple, although the caller unpacks
        two children.

    Returns:
        A 2-tuple of SMILES strings; ``(smiles1, smiles2)`` when no
        recombination is possible or an error occurs.
    """
    parents = (smiles1, smiles2)
    try:
        mol1 = Chem.MolFromSmiles(smiles1)
        mol2 = Chem.MolFromSmiles(smiles2)
        if mol1 is None or mol2 is None:
            return parents

        frags1 = list(BRICS.BRICSDecompose(mol1, minFragmentSize=4))
        frags2 = list(BRICS.BRICSDecompose(mol2, minFragmentSize=4))
        if not frags1 or not frags2:
            return parents

        # BRICSDecompose yields SMILES by default; BRICSBuild needs molecules.
        building_blocks = [
            Chem.MolFromSmiles(random.choice(frags1)),
            Chem.MolFromSmiles(random.choice(frags2)),
        ]
        if any(block is None for block in building_blocks):
            return parents

        children = []
        for product in BRICS.BRICSBuild(building_blocks):
            try:
                # BRICSBuild products are not sanitized.
                Chem.SanitizeMol(product)
            except Exception:
                continue
            children.append(Chem.MolToSmiles(product))
            if len(children) == 2:
                break

        # Fill any missing child with the corresponding parent.
        children.extend(parents[len(children):])
        return tuple(children)
    except Exception:
        return parents


def rank_selection(population, fitness):
    """Return the best-scoring individuals, highest fitness first.

    The number kept is ``int(0.2 * POPULATION_SIZE)``, i.e. it depends on
    the module constant, not on ``len(population)``.
    """
    pairs = list(zip(population, fitness))
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    keep = int(0.2*POPULATION_SIZE)
    return [individual for individual, _score in pairs[:keep]]



def tournament_selection(population, fitness, tournament_size=5):
    """Pick one individual by tournament selection.

    Args:
        population: sequence of individuals.
        fitness: fitness values aligned with ``population``.
        tournament_size: number of randomly drawn contenders (default 5,
            the value that used to be hard-coded). It is clamped to the
            population size, so populations smaller than the tournament no
            longer make ``random.sample`` raise.

    Returns:
        The contender with the highest fitness.

    Raises:
        ValueError: if the population is empty.
    """
    contenders = list(zip(population, fitness))
    if not contenders:
        raise ValueError("tournament_selection requires a non-empty population")
    draw_size = max(1, min(tournament_size, len(contenders)))
    drawn = random.sample(contenders, draw_size)
    return max(drawn, key=lambda pair: pair[1])[0]

##Основной класс работы

In [None]:
class MolecularOptimizer:
    """Genetic-algorithm optimizer that evolves SMILES strings toward a target molecule."""

    def __init__(self, target_smiles, population_size=2000):
        """Parse the target and set up empty optimizer state.

        Args:
            target_smiles: SMILES of the target molecule.
            population_size: size of the working population (default 2000).

        Raises:
            ValueError: if ``target_smiles`` cannot be parsed by RDKit.
        """
        self.target_mol = Chem.MolFromSmiles(target_smiles)
        if not self.target_mol:
            raise ValueError(f"Invalid target SMILES: {target_smiles}")

        # Reference fingerprint that candidates are compared against.
        self.target_fp = GetMorganFingerprintAsBitVect(self.target_mol, FP_RADIUS, FP_BITS)
        # Materialised result of BRICSDecompose on the target.
        decomposition = BRICS.BRICSDecompose(self.target_mol)
        self.target_frags = list(decomposition)

        self.population_size = population_size
        self.population, self.fitness_history = [], []
        self.params = dict(
            generations=GENERATIONS,
            mutation_rate=MUTATION_RATE,
            crossover_rate=CROSSOVER_RATE,
        )

###Функция сохранения модели

In [None]:
    def save(self, filename):
        import pickle
        with open(filename, 'wb') as f:
            pickle.dump({
                'target_fp': self.target_fp,
                'population': self.population,
                'fitness_history': self.fitness_history,
                'params': self.params
            }, f)

    @classmethod
    def load(cls, filename):
        import pickle
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        instance = cls.__new__(cls)
        instance.__dict__.update(data)
        return instance

###Функция подсчета Fitness rate

In [None]:
    def calculate_fitness(self, smiles):
        """Score one SMILES string against the target.

        The score is
        ``0.7 * Tanimoto + 0.2 * (h_bond_bonus + pi_pi_bonus) - mw_penalty``:

        * ``h_bond_bonus``: 0.1 per H-bond donor/acceptor, counting at most 4;
        * ``pi_pi_bonus``: 0.15 per aromatic ring, only with 2 or more rings;
        * ``mw_penalty``: ``0.02 * MolWt / 500`` above 500 Da, else 0.

        Returns:
            The score, or ``0.0`` when the SMILES fails validation or parsing.
        """
        if not validate_molecule(smiles):
            return 0.0

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Do not rely on validate_molecule alone to guarantee a parse.
            return 0.0

        # Use the module fingerprint settings so this fingerprint always has
        # the same length as self.target_fp.
        fp = GetMorganFingerprintAsBitVect(mol, FP_RADIUS, nBits=FP_BITS)
        similarity = TanimotoSimilarity(fp, self.target_fp)

        h_donors = Descriptors.NumHDonors(mol)
        h_acceptors = Descriptors.NumHAcceptors(mol)
        aromatic_rings = Descriptors.NumAromaticRings(mol)

        h_bond_bonus = 0.1 * min(h_donors + h_acceptors, 4)  # at most 4 interactions
        pi_pi_bonus = 0.15 * aromatic_rings if aromatic_rings >= 2 else 0

        mol_weight = Descriptors.MolWt(mol)  # computed once
        mw_penalty = 0.02 * (mol_weight / 500) if mol_weight > 500 else 0.0

        return 0.7 * similarity + 0.2 * (h_bond_bonus + pi_pi_bonus) - mw_penalty

###Эволюционная оптимизация

In [None]:
    def optimize(self, initial_data_path):
        """Run the genetic algorithm and return the final population.

        Args:
            initial_data_path: source of the starting SMILES: a CSV path, an
                iterable of SMILES strings, or ``None`` to use a module-level
                ``smiles_list`` if one exists (see ``_load_seed_smiles``).

        Returns:
            The final population (list of SMILES). Side effects: appends the
            best fitness of each generation to ``self.fitness_history``,
            prints progress, writes 'fitness_progress.png' and an image of
            the last generation's top 5.

        Raises:
            ValueError: if no starting SMILES can be found.

        Fixes relative to the previous version:
          * the argument was ignored and an undefined global was read;
          * fitness is computed after filter/refill, so ``fitness[i]`` always
            belongs to ``population[i]``; leader and top 5 come from that
            scored population, not from the already replaced one;
          * the adaptive mutation rate was computed but a literal 0.8 was
            used; the computed rate now drives elite mutation;
          * BRICS fragments stored as SMILES strings are accepted;
          * de-duplication keeps order, so elites survive truncation;
          * the top-5 image code was unreachable after ``return``.
        """
        seed_pool = self._load_seed_smiles(initial_data_path)
        if not seed_pool:
            raise ValueError("No initial SMILES available to seed the population")
        self.population = random.sample(
            seed_pool, min(self.population_size, len(seed_pool)))

        # The target fragments do not change between generations.
        fragment_smiles = self._target_fragment_smiles()
        top_molecules = []

        for generation in tqdm(range(self.params['generations']), desc="Evolution"):
            # Filter and refill BEFORE scoring to keep fitness aligned.
            self.population = [s for s in self.population if validate_molecule(s)]
            if len(self.population) < 0.5 * self.population_size:
                refill = random.sample(
                    seed_pool, min(len(seed_pool), self.population_size // 3))
                self.population = list(
                    dict.fromkeys(self.population + refill))[:self.population_size]
            if not self.population:
                logging.error("Population became empty at generation %s", generation)
                break

            fitness = [self.calculate_fitness(s) for s in self.population]
            current_best = max(fitness)
            self.fitness_history.append(current_best)

            scored = sorted(zip(self.population, fitness),
                            key=lambda x: x[1], reverse=True)
            leader_smiles = scored[0][0]
            top_molecules = scored[:5]

            # Adaptive mutation rate (same rules as before, now applied).
            if (len(self.fitness_history) > 5
                    and (self.fitness_history[-1] - self.fitness_history[-5]) < 0.01):
                current_mutation_rate = 0.8  # stagnation
            else:
                current_mutation_rate = self.params['mutation_rate']
            if current_best < 0.5:
                current_mutation_rate = min(0.6, current_mutation_rate + 0.1)
            else:
                current_mutation_rate = max(0.2, current_mutation_rate - 0.05)

            elites = [s for s, _ in scored[:ELITE_SIZE]]

            offspring = []
            while len(offspring) < self.population_size - ELITE_SIZE:
                parent1 = self._select_parent(fitness)
                parent2 = self._select_parent(fitness)
                # Accept any number of children; fall back to the parents so
                # the loop always makes progress.
                children = tuple(crossover(parent1, parent2)) or (parent1, parent2)
                offspring.extend(children)

            mutated_elites = [
                self.mutate(s) if random.random() < current_mutation_rate else s
                for s in elites
            ]
            candidates = mutated_elites + offspring
            # Order-preserving de-duplication keeps the elites at the front.
            self.population = [
                s for s in dict.fromkeys(candidates) if validate_molecule(s)
            ][:self.population_size]

            if generation % 3 == 0 and fragment_smiles:
                self._inject_target_fragments(fragment_smiles, leader_smiles)

            print(f"Gen {generation}: Best {current_best:.2f} Diversity {len(self.population)}")

            print("Топ-5 молекул:")
            for idx, (sm, fit) in enumerate(top_molecules):
                print(f"  {idx + 1}. {sm} | Fitness: {fit:.2f}")

        if len(self.fitness_history) > 0:
            plt.plot(self.fitness_history)
            plt.savefig('fitness_progress.png')

        if top_molecules:
            self._save_top_molecules_image(top_molecules, generation)

        return self.population

    def _load_seed_smiles(self, initial_data_path):
        """Return the de-duplicated list of starting SMILES.

        NOTE(review): the input file format is not visible in this file. A
        path is read as CSV and the 'smiles' column is used (first column if
        absent). If reading fails, a module-level ``smiles_list`` is used
        when defined, which is what the previous code relied on.
        """
        legacy_pool = globals().get('smiles_list')
        if initial_data_path is None:
            source = legacy_pool
        elif isinstance(initial_data_path, (str, os.PathLike)):
            try:
                frame = pd.read_csv(initial_data_path)
                column = 'smiles' if 'smiles' in frame.columns else frame.columns[0]
                source = frame[column].dropna().tolist()
            except (OSError, ValueError, IndexError) as exc:
                logging.warning("Could not read %s (%s); using smiles_list",
                                initial_data_path, exc)
                source = legacy_pool
        else:
            source = initial_data_path

        if source is None:
            return []
        return list(dict.fromkeys(s for s in source if isinstance(s, str) and s))

    def _target_fragment_smiles(self):
        """Return the target's fragments as validated SMILES strings.

        NOTE(review): BRICS fragments carry dummy attachment atoms, which
        ``validate_molecule`` may reject; confirm fragments survive it.
        """
        fragment_smiles = []
        for frag in (getattr(self, 'target_frags', None) or []):
            try:
                smiles = frag if isinstance(frag, str) else Chem.MolToSmiles(frag)
                if validate_molecule(smiles):
                    fragment_smiles.append(smiles)
            except Exception as e:
                logging.error(f"Ошибка преобразования фрагмента: {e}")
                continue
        return fragment_smiles

    def _select_parent(self, fitness):
        """Pick a parent by tournament; random choice for tiny populations."""
        if len(self.population) >= 5:
            return tournament_selection(self.population, fitness)
        return random.choice(self.population)

    def _inject_target_fragments(self, fragment_smiles, leader_smiles):
        """Add target fragments, weighted by similarity to the current leader."""
        leader_fp = get_cached_fingerprint(leader_smiles)
        if leader_fp is None:
            return
        usable, weights = [], []
        for smiles in fragment_smiles:
            fp = get_cached_fingerprint(smiles)
            if fp is None:
                continue
            usable.append(smiles)
            weights.append(TanimotoSimilarity(fp, leader_fp))
        if not usable:
            return
        if sum(weights) <= 0:
            # random.choices rejects all-zero weights; sample uniformly.
            weights = None
        new_samples = random.choices(
            usable, weights=weights, k=int(0.1 * self.population_size))
        self.population = list(
            dict.fromkeys(self.population + new_samples))[:self.population_size]

    def _save_top_molecules_image(self, top_molecules, generation):
        """Write a grid image of the top molecules; never abort the run."""
        try:
            top_mols = [
                mol for mol in
                (Chem.MolFromSmiles(sm) for sm, _ in top_molecules[:5])
                if mol is not None
            ]
            if not top_mols:
                return
            img = Draw.MolsToGridImage(top_mols, molsPerRow=5, subImgSize=(300, 300))
            img.save(f'gen_{generation}_top5.png')
        except Exception as e:
            logging.error("Could not save top-5 image: %s", e)

###Функция мутаций

In [None]:
    def mutate(self, smiles):
        """Apply one random structural mutation and return the resulting SMILES.

        One of seven mutation types is chosen at random: replace an atom,
        add a bond, remove an atom, attach a benzene ring, attach a single
        O/N/F/Cl atom, attach a target fragment, or change a bond order.

        Returns:
            The mutated SMILES when the result sanitizes and passes
            ``validate_molecule``; otherwise the input ``smiles`` unchanged.

        Fixes relative to the previous version:
          * "replace_atom" chose a new element but never applied it;
          * "change_bond_type" was listed but had no implementation;
          * "add_bond" could pick an already bonded pair and raise;
          * target fragments stored as SMILES strings are parsed before use;
          * a failed sanitization no longer goes through ``MolToSmiles(None)``.
        """
        mol = Chem.MolFromSmiles(smiles)
        if not mol or mol.GetNumAtoms() == 0:
            return smiles

        new_mol = Chem.RWMol(mol)
        try:
            mutation_type = random.choice([
                "replace_atom", "add_bond", "remove_atom",
                "add_ring", "add_functional_group", "scaffold_hopping",
                "change_bond_type"])

            if mutation_type == "replace_atom":
                atom_idx = random.choice(range(new_mol.GetNumAtoms()))
                new_element = random.choice(['C', 'N', 'O'])
                new_atomic_num = Chem.GetPeriodicTable().GetAtomicNumber(new_element)
                new_mol.GetAtomWithIdx(atom_idx).SetAtomicNum(new_atomic_num)

            elif mutation_type == "add_bond":
                atoms = [atom.GetIdx() for atom in new_mol.GetAtoms()]
                if len(atoms) >= 2:
                    first, second = random.sample(atoms, 2)
                    # Adding a bond between already bonded atoms raises.
                    if new_mol.GetBondBetweenAtoms(first, second) is None:
                        new_mol.AddBond(first, second, Chem.BondType.SINGLE)

            elif mutation_type == "remove_atom":
                if new_mol.GetNumAtoms() > 1:
                    atom_idx = random.choice(range(new_mol.GetNumAtoms()))
                    new_mol.RemoveAtom(atom_idx)

            elif mutation_type == "add_ring":
                anchor_idx = random.choice(range(new_mol.GetNumAtoms()))
                ring = Chem.MolFromSmiles("C1=CC=CC=C1")  # benzene
                if not ring:
                    return smiles

                combined = Chem.CombineMols(new_mol, ring)
                ed_combined = Chem.EditableMol(combined)
                # The first ring atom sits right after the original atoms.
                ed_combined.AddBond(anchor_idx, new_mol.GetNumAtoms(),
                                    Chem.BondType.SINGLE)

                modified_mol = sanitize_mol(ed_combined.GetMol())
                if modified_mol and validate_molecule(Chem.MolToSmiles(modified_mol)):
                    return Chem.MolToSmiles(modified_mol)

            elif mutation_type == "add_functional_group":
                atom_idx = random.choice(range(new_mol.GetNumAtoms()))
                group = random.choice(["O", "N", "F", "Cl"])  # single atoms
                new_mol.AddAtom(Chem.Atom(group))
                new_mol.AddBond(atom_idx, new_mol.GetNumAtoms() - 1, Chem.BondType.SINGLE)

            elif mutation_type == "scaffold_hopping":
                fragments = getattr(self, 'target_frags', None) or []
                if new_mol.GetNumAtoms() > 5 and fragments:
                    scaffold_frag = random.choice(fragments)
                    if isinstance(scaffold_frag, str):
                        # Fragments may be stored as SMILES strings.
                        scaffold_frag = Chem.MolFromSmiles(scaffold_frag)
                    if scaffold_frag is not None:
                        modified_mol = combine_with_scaffold(new_mol, scaffold_frag)
                        if modified_mol and validate_molecule(Chem.MolToSmiles(modified_mol)):
                            return Chem.MolToSmiles(modified_mol)

            elif mutation_type == "change_bond_type":
                # Aromatic bonds are skipped; changing them would only break
                # kekulization during sanitization.
                bonds = [b for b in new_mol.GetBonds() if not b.GetIsAromatic()]
                if bonds:
                    bond = random.choice(bonds)
                    alternatives = [
                        bond_type for bond_type in (Chem.BondType.SINGLE,
                                                    Chem.BondType.DOUBLE,
                                                    Chem.BondType.TRIPLE)
                        if bond_type != bond.GetBondType()
                    ]
                    bond.SetBondType(random.choice(alternatives))

            sanitized = sanitize_mol(new_mol)
            if sanitized is None:
                # An invalid structure is an expected outcome of a random edit.
                logging.debug(f"Mutation produced an invalid molecule for {smiles}")
                return smiles
            new_smiles = Chem.MolToSmiles(sanitized)
            return new_smiles if validate_molecule(new_smiles) else smiles
        except Exception as e:
            logging.debug(f"Mutation failed for {smiles}: {str(e)}")
            print(f"Ошибка в mutate: {e}")
            return smiles

###Проверка валидности

In [None]:
# Validity: percentage of generated SMILES that RDKit can parse.
# G counts non-null entries, so only non-null entries are parsed too;
# a NaN entry would otherwise make Chem.MolFromSmiles raise.
generated_smiles = res['smiles'].dropna()
G = res['smiles'].count()
V = sum(
    1 for smi in generated_smiles
    if isinstance(smi, str) and Chem.MolFromSmiles(smi) is not None
)

# Avoid a division by zero when nothing was generated.
Validity = V / G * 100 if G else 0.0
Validity

###Проверка новизны

In [None]:
# Novelty: percentage of valid generated SMILES absent from the reference set.
# ``x in series`` tests the Series *index labels*, not its values, so the old
# check never matched and novelty was always 100 %. Compare against a set of
# the column's values instead.
reference_column = 'CC1(C)C(=O)NC1S(=O)(=O)C1=CC=CC=C1'
known_smiles = set(real[reference_column].dropna())
# NOTE(review): the column label is itself a SMILES string, which suggests the
# reference file was read without a header row; it is counted as known. Confirm.
known_smiles.add(reference_column)
# NOTE(review): this is exact string matching; canonicalise both sides if the
# two sources may write the same molecule differently.

N = V
for i in res['smiles']:
    if i in known_smiles:
        N -= 1

# Avoid a division by zero when there are no valid molecules.
Novelty = N / V * 100 if V else 0.0
Novelty

###Проверка уникальности

In [None]:
# Uniqueness: distinct generated SMILES relative to the valid count V, in percent.
distinct_generated = res['smiles'].drop_duplicates()
U = distinct_generated.count()

Unique = U / V * 100
Unique

#Метрики

In [None]:
# Target coformers

# Share of generated molecules whose SA score is at most 3.15.
smiles_list = list(res['smiles'])
all_scores = [calculate_sa_score(smi) for smi in smiles_list]
# calculate_sa_score returns 0 on failure, while real scores are clamped to
# the range 1-10. The failure marker must not pass the "<= 3.15" threshold.
score = [s for s in all_scores if 0 < s <= 3.15]

S = len(score)
# Avoid a division by zero when nothing was generated.
Target = S / G * 100 if G else 0.0

##Анализ механических свойств

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors

class ValidationResult(tuple):
    """``(is_valid, reason, mol)`` triple whose truth value is ``is_valid``.

    This ``validate_molecule`` replaces an earlier one-argument version that
    returned a plain bool and is used as ``if validate_molecule(s):``. A
    plain 3-tuple is always truthy, so after this definition runs every
    such check would pass. Subclassing tuple keeps unpacking and indexing
    working while making boolean tests behave as before.
    """

    __slots__ = ()

    def __new__(cls, is_valid, reason, mol):
        return super().__new__(cls, (is_valid, reason, mol))

    def __bool__(self):
        return bool(self[0])


def validate_molecule(smiles, verbose=False):
    """Validate a SMILES string and explain any failure.

    Checks, in order: RDKit can parse it, it has at least one atom, all
    elements are among H, C, N, O, F, P, S, Cl, and sanitization succeeds.
    A molecular weight above 1000 only prints a warning when ``verbose``.

    Args:
        smiles: SMILES string to check.
        verbose: print the molecular-weight warning.

    Returns:
        ``ValidationResult(is_valid, reason, mol)``: ``reason`` is ``None``
        and ``mol`` is the parsed molecule on success; on failure ``reason``
        is a message and ``mol`` is ``None``. The result is truthy exactly
        when ``is_valid`` is true.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return ValidationResult(False, "Invalid SMILES syntax", None)

        if mol.GetNumAtoms() == 0:
            return ValidationResult(False, "No atoms in molecule", None)

        allowed_elements = {1, 6, 7, 8, 9, 15, 16, 17}  # H,C,N,O,F,P,S,Cl
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() not in allowed_elements:
                return ValidationResult(
                    False, f"Disallowed element: {atom.GetSymbol()}", None)

        try:
            Chem.SanitizeMol(mol)
        except ValueError as e:
            return ValidationResult(False, f"Valence error: {str(e)}", None)

        if verbose and Descriptors.MolWt(mol) > 1000:
            print("Warning: Molecular weight > 1000")

        return ValidationResult(True, None, mol)

    except Exception as e:
        return ValidationResult(False, f"Validation error: {str(e)}", None)



# Sample inputs for a quick manual check of validate_molecule.
test_smiles = [
    "CCO",          # Ethanol (valid)
    "C1=CC=CC=C1",  # Benzene (valid)
    "C(C)(C)(C)C",  # Neopentane: central carbon has four bonds. NOTE(review): the original comment called this an invalid valence; it should pass
    "InvalidSMILES",# Malformed SMILES
    "[Au]",         # Element outside the allowed set
    "[H][H]"        # Hydrogen only
]

# Print one line per sample: the SMILES (first 20 chars), the validity flag and the reason.
for smi in test_smiles:
    is_valid, reason, mol = validate_molecule(smi, verbose=True)
    print(f"{smi[:20]:<20} | Valid: {is_valid} | Reason: {reason or '-'}")

###SA score

In [None]:
def calculate_sa_score(smiles):
    """Return a heuristic synthetic-accessibility-style score for ``smiles``.

    The score is ``3 * frag_contrib + ring_penalty + mw_penalty`` clamped to
    the range 1-10 and rounded to two decimals, where

    * ``frag_contrib`` = 1 - (sum of Morgan radius-2 feature counts)
      / (100 * number of atoms);
    * ``ring_penalty`` = 0.5 for every ring with more than 6 atoms;
    * ``mw_penalty``   = log10(molecular weight) / 10.

    Fix: the molecule was never created from ``smiles``. The body read a
    name ``mol`` that is not defined in the function, so it either scored
    an unrelated module-level molecule or failed and returned 0.

    Returns:
        The score, or ``0`` as a failure marker when the SMILES cannot be
        parsed or any computation raises.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            raise ValueError("Invalid SMILES")

        fp = AllChem.GetMorganFingerprint(mol, radius=2)

        frag_contrib = 1 - sum(fp.GetNonzeroElements().values()) / (100 * mol.GetNumAtoms())
        ring_penalty = sum(0.5 for ring in mol.GetRingInfo().AtomRings() if len(ring) > 6)
        mw_penalty = np.log10(Descriptors.MolWt(mol)) / 10

        sa_score = min(10, max(1, 3 * frag_contrib + ring_penalty + mw_penalty))
        return round(sa_score, 2)

    except Exception:
        return 0