In [1]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw
import os
from IPython.display import display
from sklearn.preprocessing import MinMaxScaler

In [2]:
def is_valid_smiles(smiles):
    """Return True if RDKit parses `smiles` into a molecule with at least one atom.

    Empty / whitespace-only strings and non-string inputs are rejected up front:
    RDKit parses "" into an empty (zero-atom) molecule instead of returning None,
    so without this guard an empty generation would be counted as a valid SMILES.
    """
    if not isinstance(smiles, str) or not smiles.strip():
        return False
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        # Parser errors only; a bare `except:` would also swallow KeyboardInterrupt.
        return False
    return mol is not None and mol.GetNumAtoms() > 0

In [3]:
MAX_LEN = 128
BATCH_SIZE = 32
EMB_DIM = 256
LATENT_DIM = 128
N_HEADS = 8
FF_DIM = 512
NUM_LAYERS = 4
VOCAB_SPECIAL = ['<pad>', '<bos>', '<eos>', '<unk>']

# The decoder's cross-attention memory is the concatenation [z, feature embedding]
# (2 * LATENT_DIM wide) and is fed to a decoder of width EMB_DIM without projection.
assert 2 * LATENT_DIM == EMB_DIM, "EMB_DIM must equal 2 * LATENT_DIM"

# Seed torch (weight init, DataLoader shuffling, reparameterisation noise, sampling
# during generation) and NumPy so that Restart & Run All reproduces the same run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Используем устройство: {device}")

Используем устройство: cuda


In [4]:
df = pd.read_csv("polymer_bigdata_merged.csv")
df = df.dropna(subset=["polymer_smiles"])

# Filter by SMILES length to drop junk and overly long strings;
# the upper bound leaves room for the <bos>/<eos> tokens added during encoding.
df = df[df["polymer_smiles"].str.len() > 2]
df = df[df["polymer_smiles"].str.len() < MAX_LEN-2]

# Train on the first N_TRAIN_SAMPLES rows (or fewer if the filtered data is smaller).
N_TRAIN_SAMPLES = 10_000
subset_df = df.iloc[:N_TRAIN_SAMPLES].copy()

NUM_FEATURES = [
    "Enthalpy of Polymerization (kJ/mol)",
    "Glass Transition Temperature (K)",
    "Specific Heat Capacity (J {gK}^{-1})",
    "Tensile Strength at break (MPa)",
    "Thermal Decomposition Temperature (K)",
    "Youngs Modulus (GPa)"
]

# Only missing SMILES were dropped above. Missing property values would pass through
# MinMaxScaler unchanged (it ignores NaN when fitting and keeps it when transforming)
# and then turn the feature embedding, and hence the logits/loss, into NaN.
# Impute with the column median; a column that is entirely NaN falls back to 0.
missing_counts = subset_df[NUM_FEATURES].isna().sum()
if missing_counts.any():
    print(f"Warning: imputing missing feature values with column medians:\n{missing_counts[missing_counts > 0]}")
subset_df[NUM_FEATURES] = subset_df[NUM_FEATURES].fillna(subset_df[NUM_FEATURES].median()).fillna(0.0)

# A constant feature carries no signal; replace it with zeros explicitly
# (defensive: MinMaxScaler itself maps a zero-range column to 0 rather than NaN).
for f in NUM_FEATURES:
    if subset_df[f].max() == subset_df[f].min():
        print(f"Warning: feature {f} is constant. Filling with zeros.")
        subset_df[f] = 0.0

In [5]:
# Fit a [0, 1] min-max scaler on the training subset and replace the raw feature
# columns with their scaled values; `scaler` keeps the fitted per-column min/max.
scaler = MinMaxScaler().fit(subset_df[NUM_FEATURES])
subset_df[NUM_FEATURES] = scaler.transform(subset_df[NUM_FEATURES])

In [6]:
def tokenize(smiles):
    """Character-level SMILES tokenizer: every character becomes its own token.

    Multi-character atoms such as "Cl" or "[nH]" are therefore split into
    single characters.
    """
    return [char for char in smiles]

# Vocabulary: the special tokens come first (so '<pad>' gets id 0), followed by every
# character occurring in the training SMILES, sorted so the ids are deterministic.
all_tokens = {tok for smi in subset_df["polymer_smiles"] for tok in tokenize(smi)}
tokens = VOCAB_SPECIAL + sorted(all_tokens)
token2idx = dict(zip(tokens, range(len(tokens))))
idx2token = {idx: tok for tok, idx in token2idx.items()}
vocab_size = len(token2idx)

In [7]:
class ConditionalSMILESDataset(Dataset):
    """Map-style dataset yielding teacher-forcing triples for the conditional VAE.

    Each item is ``(tgt_input, tgt_output, feats)``:
      * the SMILES string is tokenized, wrapped as ``<bos> ... <eos>``, mapped to
        vocabulary ids (unknown characters -> ``<unk>``) and right-padded with
        ``<pad>`` to ``MAX_LEN``;
      * ``tgt_input`` is that sequence without its last position and ``tgt_output``
        is the same sequence shifted left by one (both of length ``MAX_LEN - 1``);
      * ``feats`` is the row's ``NUM_FEATURES`` values as a float32 tensor.
    """

    def __init__(self, dataframe):
        self.data = dataframe["polymer_smiles"].tolist()
        # (n_rows, len(NUM_FEATURES)) float32 condition matrix
        self.features = dataframe[NUM_FEATURES].values.astype(np.float32)

    def __len__(self):
        return len(self.data)

    def encode_smiles(self, smiles):
        """Encode one SMILES string as a LongTensor of exactly ``MAX_LEN`` token ids."""
        pad_idx = token2idx['<pad>']
        unk_idx = token2idx['<unk>']
        # Use the shared tokenizer so encoding always matches the vocabulary build.
        tokens_seq = ['<bos>'] + tokenize(smiles) + ['<eos>']
        token_ids = [token2idx.get(t, unk_idx) for t in tokens_seq]
        if len(token_ids) > MAX_LEN:
            # Truncate but keep the sequence terminated: plain slicing would drop
            # <eos>, so the decoder would never learn where over-long inputs end.
            token_ids = token_ids[:MAX_LEN - 1] + [token2idx['<eos>']]
        token_ids += [pad_idx] * (MAX_LEN - len(token_ids))
        return torch.tensor(token_ids, dtype=torch.long)

    def __getitem__(self, idx):
        x = self.encode_smiles(self.data[idx])
        tgt_input = x[:-1]   # <bos> ... (decoder input)
        tgt_output = x[1:]   # ... <eos> <pad>* (next-token targets)
        feats = torch.tensor(self.features[idx])
        return tgt_input, tgt_output, feats

In [8]:
class ConditionalTransformerVAE(nn.Module):
    """Conditional Transformer VAE over SMILES token ids.

    Encoder: token + learned positional embeddings -> nn.TransformerEncoder. The
    encoder output at position 0 is concatenated with a linear projection of the
    condition features and mapped to the posterior mean and log-variance.

    Decoder: the latent sample z is concatenated with the same feature projection and
    used as cross-attention memory for a causal nn.TransformerDecoder that produces
    next-token logits.

    All Transformer modules use PyTorch's default sequence-first layout
    (batch_first=False), hence the transposes around encoder/decoder calls.

    Args:
        vocab_size: number of token ids.
        emb_dim: model width (d_model). Must equal 2 * latent_dim, because the decoder
            memory [z, feat_emb] is fed to the decoder without a projection.
        latent_dim: size of z and of the feature projection.
        num_heads, ff_dim, num_layers: Transformer hyper-parameters.
        max_len: longest sequence supported by the positional embedding.
        feature_dim: number of condition features.
        pad_idx: padding token id; defaults to the notebook vocabulary's '<pad>' id.

    Raises:
        ValueError: if 2 * latent_dim != emb_dim.
    """

    def __init__(self, vocab_size, emb_dim, latent_dim, num_heads, ff_dim, num_layers, max_len, feature_dim, pad_idx=None):
        super().__init__()
        if 2 * latent_dim != emb_dim:
            raise ValueError(
                f"decoder memory [z, feat_emb] has size 2*latent_dim={2 * latent_dim}, "
                f"but the decoder width is emb_dim={emb_dim}; they must be equal"
            )
        self.pad_idx = token2idx['<pad>'] if pad_idx is None else pad_idx

        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=self.pad_idx)
        self.pos_emb = nn.Parameter(torch.randn(1, max_len, emb_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.feat_encoder = nn.Linear(feature_dim, latent_dim)
        self.fc_mu = nn.Linear(emb_dim + latent_dim, latent_dim)
        self.fc_logvar = nn.Linear(emb_dim + latent_dim, latent_dim)
        # NOTE(review): decoder_proj is never called -- decode() feeds [z, feat_emb] to the
        # decoder directly. It is kept so existing checkpoints' state_dict keys still match.
        self.decoder_proj = nn.Linear(latent_dim + latent_dim, emb_dim)

        decoder_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.out = nn.Linear(emb_dim, vocab_size)

        # Xavier init of every weight matrix, for training stability.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # The loop above also overwrote the <pad> embedding row, which nn.Embedding
        # zero-initialises and never updates (padding_idx receives no gradient).
        with torch.no_grad():
            self.emb.weight[self.pad_idx].zero_()

    @staticmethod
    def _causal_mask(size, device):
        """Bool (size, size) mask; True above the diagonal blocks attention to later positions."""
        return torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)

    def encode(self, src, feats):
        """Return (mu, logvar) of q(z | src, feats), each of shape (batch, latent_dim).

        src: (batch, seq) token ids; feats: (batch, feature_dim) condition features.
        """
        src_key_padding_mask = src == self.pad_idx  # True marks padding positions
        src_emb = self.emb(src) + self.pos_emb[:, :src.size(1)]
        src_enc = self.encoder(src_emb.transpose(0, 1), src_key_padding_mask=src_key_padding_mask)
        # Encoder output at position 0 summarises the sequence (the <bos> slot for
        # sequences built by ConditionalSMILESDataset).
        cls_token = src_enc[0]
        combined = torch.cat([cls_token, self.feat_encoder(feats)], dim=1)
        # logvar is deliberately left unclamped here; it is clamped where it is used.
        return self.fc_mu(combined), self.fc_logvar(combined)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std with eps ~ N(0, I)."""
        # Clamp logvar to a sane range before exp() to avoid overflow/underflow.
        logvar = torch.clamp(logvar, min=-10, max=4)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, feats, tgt_inp):
        """Return next-token logits of shape (batch, seq, vocab_size) under teacher forcing."""
        seq_len = tgt_inp.size(1)
        tgt_key_padding_mask = tgt_inp == self.pad_idx
        # The same [z, feat_emb] memory vector is repeated at every target position.
        memory = torch.cat([z, self.feat_encoder(feats)], dim=1).unsqueeze(1).repeat(1, seq_len, 1)
        tgt_emb = self.emb(tgt_inp) + self.pos_emb[:, :seq_len]
        # Bool causal mask, so its type matches the bool key-padding mask.
        tgt_mask = self._causal_mask(seq_len, tgt_inp.device)
        out = self.decoder(
            tgt=tgt_emb.transpose(0, 1),
            memory=memory.transpose(0, 1),
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
        return self.out(out.transpose(0, 1))

    def forward(self, src, feats, tgt_inp):
        """Encode `src`, sample z, and decode `tgt_inp`; returns (logits, mu, logvar)."""
        mu, logvar = self.encode(src, feats)
        z = self.reparameterize(mu, logvar)
        logits = self.decode(z, feats, tgt_inp)
        return logits, mu, logvar

In [9]:
def conditional_vae_loss(logits, targets, mu, logvar, beta=0.1, pad_idx=None):
    """Beta-VAE objective: token cross-entropy + beta * KL(q(z|x,c) || N(0, I)).

    Args:
        logits: (..., vocab) next-token logits; need not be contiguous.
        targets: token ids with the same leading shape as `logits`.
        mu, logvar: (batch, latent_dim) posterior parameters.
        beta: weight of the KL term.
        pad_idx: target id ignored by the reconstruction loss; defaults to the
            notebook vocabulary's '<pad>' id (looked up at call time).

    Returns:
        (total_loss, recon_loss, kl_loss) scalar tensors. recon_loss is the mean over
        non-pad tokens; kl_loss is the per-sample KL summed over latent dims, averaged
        over the batch.
    """
    if pad_idx is None:
        pad_idx = token2idx['<pad>']
    # reshape (not view): logits coming out of a transpose may be non-contiguous.
    recon_loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=pad_idx
    )
    # Same logvar clamp as in reparameterize(), for numerical stability.
    clamped_logvar = torch.clamp(logvar, min=-10, max=4)
    kl_per_sample = -0.5 * torch.sum(1 + clamped_logvar - mu.pow(2) - clamped_logvar.exp(), dim=1)
    kl_loss = kl_per_sample.mean()
    total_loss = recon_loss + beta * kl_loss
    return total_loss, recon_loss, kl_loss

In [10]:
def generate_smiles_conditional(model, target_feats, num_samples=5, temperature=1.0):
    """Sample SMILES strings from the decoder, conditioned on property vectors.

    Args:
        model: a ConditionalTransformerVAE (its feat_encoder, emb, pos_emb, decoder
            and out modules are used directly).
        target_feats: float tensor of condition features in the same (scaled) space as
            training, shaped (num_samples, n_features); a single row (1, n_features)
            or a 1-D vector (n_features,) is broadcast to all samples.
        num_samples: number of sequences to sample.
        temperature: softmax temperature, must be > 0 (lower = greedier).

    Returns:
        list[str] of length num_samples: sampled tokens up to the first <eos>, with
        special tokens removed. The strings are not validated.

    Raises:
        ValueError: if temperature <= 0.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    bos_idx = token2idx['<bos>']
    eos_idx = token2idx['<eos>']
    # <pad>/<bos>/<unk> are never next-token targets in training and are stripped
    # from the output anyway, so never sample them.
    banned_ids = [token2idx['<pad>'], bos_idx, token2idx['<unk>']]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_samples, LATENT_DIM).to(device)
            target_feats = target_feats.to(device)
            if target_feats.dim() == 1:
                target_feats = target_feats.unsqueeze(0)
            if target_feats.size(0) == 1 and num_samples > 1:
                target_feats = target_feats.expand(num_samples, -1)

            # (1, num_samples, D) memory: one position instead of the per-step repeat
            # used in training (cross-attention over identical keys gives the same result).
            memory = torch.cat([z, model.feat_encoder(target_feats)], dim=1).unsqueeze(0)
            sequences = torch.full((num_samples, 1), bos_idx, dtype=torch.long, device=device)
            finished = torch.zeros(num_samples, dtype=torch.bool, device=device)

            for _ in range(MAX_LEN - 1):
                seq_len = sequences.size(1)
                emb = model.emb(sequences) + model.pos_emb[:, :seq_len]
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device)
                out = model.decoder(emb.transpose(0, 1), memory, tgt_mask=tgt_mask)
                # Logits for the last position only: (num_samples, vocab)
                logits = model.out(out[-1]) / temperature
                # Any NaN or +/-inf would make softmax produce NaN; sanitise first.
                if not torch.isfinite(logits).all():
                    logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)
                logits[:, banned_ids] = float('-inf')
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
                sequences = torch.cat([sequences, next_token.unsqueeze(1)], dim=1)
                finished |= next_token == eos_idx
                if finished.all():
                    break
    finally:
        model.train(was_training)

    smiles_list = []
    for seq in sequences.tolist():
        tokens_seq = [idx2token[i] for i in seq]
        if '<eos>' in tokens_seq:
            tokens_seq = tokens_seq[:tokens_seq.index('<eos>')]
        smiles_list.append(''.join(t for t in tokens_seq if t not in VOCAB_SPECIAL))
    return smiles_list

In [11]:
dataset = ConditionalSMILESDataset(subset_df)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

model = ConditionalTransformerVAE(
    vocab_size=vocab_size,
    emb_dim=EMB_DIM,
    latent_dim=LATENT_DIM,
    num_heads=N_HEADS,
    ff_dim=FF_DIM,
    num_layers=NUM_LAYERS,
    max_len=MAX_LEN,
    feature_dim=len(NUM_FEATURES)
).to(device)

# Reduced learning rate for extra stability.
# NOTE(review): 5e-6 is far below the usual AdamW range for Transformers trained from
# scratch (~1e-4) and looks like a workaround for NaN losses; revisit once the input
# features are confirmed NaN-free.
LEARNING_RATE = 5e-6
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)



In [None]:
checkpoint_dir = 'checkpoints_VAE'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_conditional.pth')

start_epoch = 0
best_loss = float('inf')  # best un-annealed objective (recon + KL) seen so far
loss_history = []         # per-epoch mean of the beta-weighted training loss

EPOCHS = 50
KL_ANNEAL_EPOCHS = 10  # beta ramps linearly up to 1.0 over this many epochs


def _train_one_epoch(model, loader, optimizer, beta, desc):
    """Run one optimisation pass over `loader` with KL weight `beta`.

    Batches whose logits contain NaN or whose loss is NaN/Inf are skipped without
    an optimizer step.

    Returns:
        (sum_loss, sum_recon, sum_kl, n_used, n_skipped): sums of the per-batch loss
        terms over the batches that were optimised, plus used/skipped batch counts.
    """
    model.train()
    sum_loss = sum_recon = sum_kl = 0.0
    n_used = n_skipped = 0

    for tgt_inp, tgt_out, feats in tqdm(loader, desc=desc):
        tgt_inp, tgt_out, feats = tgt_inp.to(device), tgt_out.to(device), feats.to(device)
        optimizer.zero_grad()
        # The decoder input sequence doubles as the encoder source.
        logits, mu, logvar = model(tgt_inp, feats, tgt_inp)

        if torch.isnan(logits).any():
            n_skipped += 1
            continue

        loss, recon_loss, kl_loss = conditional_vae_loss(logits, tgt_out, mu, logvar, beta=beta)
        if not torch.isfinite(loss):
            n_skipped += 1
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        sum_loss += loss.item()
        sum_recon += recon_loss.item()
        sum_kl += kl_loss.item()
        n_used += 1

    return sum_loss, sum_recon, sum_kl, n_used, n_skipped


for epoch in range(start_epoch, start_epoch + EPOCHS):
    # KL beta-annealing: smooth ramp-up to 1.0
    beta = min(1.0, (epoch + 1) / KL_ANNEAL_EPOCHS)
    sum_loss, sum_recon, sum_kl, n_used, n_skipped = _train_one_epoch(
        model, loader, optimizer, beta, desc=f"Epoch {epoch+1}"
    )

    # Average only over batches that were actually optimised. Dividing by len(loader)
    # biases the means towards 0 when batches are skipped, and an all-skipped epoch
    # would report 0.0 and be saved as the "best" checkpoint. NaN (no usable batch)
    # never compares < best_loss, so nothing is saved in that case.
    if n_used:
        avg_loss, avg_recon, avg_kl = sum_loss / n_used, sum_recon / n_used, sum_kl / n_used
    else:
        avg_loss = avg_recon = avg_kl = float('nan')
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | Recon: {avg_recon:.4f} | KL: {avg_kl:.4f} | beta: {beta:.3f}")
    if n_skipped:
        print(f"Warning: skipped {n_skipped}/{len(loader)} batches with NaN/Inf logits or loss")

    # Select checkpoints on recon + KL (i.e. beta = 1). The beta-weighted loss is not
    # comparable across annealing epochs: raising beta increases it for a fixed model.
    selection_loss = avg_recon + avg_kl
    if selection_loss < best_loss:
        best_loss = selection_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'selection_loss': selection_loss,
            'beta': beta,
        }, checkpoint_path)
        print(f"Checkpoint сохранён на {checkpoint_path} (лучший по recon + KL)")

Epoch 1:  71%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                    | 221/313 [00:32<00:08, 11.24it/s]

In [None]:
# Per-epoch mean of the beta-weighted training loss. The x-axis uses the same 1-based
# epoch numbers the training log prints; markers keep a single-epoch run visible.
epochs_axis = range(start_epoch + 1, start_epoch + len(loss_history) + 1)
fig, ax = plt.subplots()
ax.plot(epochs_axis, loss_history, marker="o")
ax.set_title("Обучение Conditional VAE")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss (recon + β·KL)")
ax.grid(True)
plt.show()

In [None]:
# Weights as they are after the last training epoch (not necessarily the best
# checkpoint written by the training loop).
torch.save(obj=model.state_dict(), f="transformer_smiles_conditional.pth")

In [None]:
# Generate 100 molecules conditioned on the mean (scaled) feature vector of the training subset
model.eval()
num_generate = 100
mean_feats = torch.tensor(
    subset_df[NUM_FEATURES].mean().values, dtype=torch.float32
).unsqueeze(0)
target_feats = mean_feats.repeat(num_generate, 1)

generated_smiles = generate_smiles_conditional(
    model, target_feats, num_samples=num_generate, temperature=0.8
)
valid_smiles = list(filter(is_valid_smiles, generated_smiles))
print(f"Сгенерировано {len(valid_smiles)} валидных SMILES из {num_generate}")

# Draw the first 10 valid molecules as a 5-per-row grid
mols = [Chem.MolFromSmiles(smi) for smi in valid_smiles[:10]]
if mols:
    img = Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(200, 200))
    display(img)

# Print every valid generated SMILES, numbered from 1
print("Все сгенерированные валидные SMILES:")
for i, s in enumerate(valid_smiles, start=1):
    print(f"{i}: {s}")

# Example condition: maximise thermal decomposition. Start from the mean (scaled)
# feature vector and push "Thermal Decomposition Temperature (K)" to 1.0, the top of
# MinMaxScaler's default [0, 1] range, i.e. the largest value in the training subset.
# copy=True: under pandas Copy-on-Write the array behind `.values` can be a read-only
# view, which would make the in-place assignment below raise.
example_feats = subset_df[NUM_FEATURES].mean().to_numpy(copy=True)
example_feats[NUM_FEATURES.index("Thermal Decomposition Temperature (K)")] = 1.0
example_feats = torch.tensor(example_feats, dtype=torch.float32).unsqueeze(0)
target_feats = example_feats.repeat(num_generate, 1)