In [1]:
from __future__ import annotations
import os
from pathlib import Path
from typing import Iterable, List, Tuple

import numpy as np
import pandas as pd
import torch

from tqdm import tqdm

In [5]:
def load_rnafm(device=None):
    import sys
    sys.path.append(r"C:\Users\User\RNA-FM-main")  # путь к папке, где лежит fm

    from fm import pretrained  # локальный импорт

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model, alphabet = pretrained.rna_fm_t12()
    model = model.to(device).eval()
    batch_converter = alphabet.get_batch_converter()
    return model, batch_converter, device


In [6]:
import sys
from pathlib import Path
import torch
import ptflops
# Путь к корню репозитория RNA-FM (где лежит папка fm)
RNAFM_PATH = Path(r"C:\Users\User\RNA-FM-main")
sys.path.insert(0, str(RNAFM_PATH))

try:
    from fm import pretrained
    print("✅ Пакет 'fm' успешно импортирован!")
except ImportError as e:
    print("❌ Ошибка импорта 'fm':", e)
    raise

# Определяем устройство
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    model, alphabet = pretrained.rna_fm_t12()
    model = model.to(device).eval()
    print(f"✅ Модель RNA-FM загружена и готова! Устройство: {device}")
except Exception as e:
    print("❌ Ошибка при загрузке модели:", e)


✅ Пакет 'fm' успешно импортирован!
✅ Модель RNA-FM загружена и готова! Устройство: cpu


In [7]:
def load_rnafm(device=None):
    import sys
    sys.path.append(r"C:\Users\User\RNA-FM-main")  # путь к папке, где лежит fm

    from fm import pretrained  # локальный импорт

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model, alphabet = pretrained.rna_fm_t12()
    model = model.to(device).eval()
    batch_converter = alphabet.get_batch_converter()
    return model, batch_converter, device


In [8]:
# ==== удобства и стабильность ====
import math, random, os
from typing import List, Optional
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [27]:
df = pd.read_csv(r"C:\Users\User\Desktop\ИТМО\Проект\работа с базой DT_Curated\Data_ML.csv")

In [10]:
# ---- конфиг (оставь свои пути/параметры как есть) ----
DATA_DIR   = Path(r"C:\Users\User\Desktop\ИТМО\Проект\работа с базой DT_Curated")
CSV_NAME   = "Data_ML.csv"
EMB_NPY    = "embeddings.npy"
MODEL_NAME = "seyonec/ChemBERTa-zinc-base-v1"

DEVICE   = "cuda" if torch.cuda.is_available() else "cpu"
AMP      = torch.cuda.is_available()
EPOCHS   = 1
BATCH_SIZE = 8         # можно 8 — быстрее чем 2, но всё ещё лёгкий
MAX_LEN  = 64
FREEZE_N = 10
NUM_HEADS = 8
LR       = 2e-5
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
WARMUP_FR = 0.06
SEED     = 42

# ---- фиксация сидов ----
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed()

In [11]:
# ---- датасет (SMILES + RNA эмбеддинги) ----
class SmilesRnaDataset(Dataset):
    def __init__(self, smiles: List[str], rna_embs: np.ndarray, tokenizer, max_len: int):
        assert len(smiles) == len(rna_embs), f"Разные длины: {len(smiles)} vs {len(rna_embs)}"
        self.smiles = smiles
        self.rna = torch.tensor(rna_embs, dtype=torch.float32)
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self): return len(self.smiles)

    def __getitem__(self, idx):
        s = self.smiles[idx]
        enc = self.tok(
            s,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            add_special_tokens=True,
        )
        input_ids = enc["input_ids"].squeeze(0)         # (L,)
        attention = enc["attention_mask"].squeeze(0)    # (L,)
        rna = self.rna[idx]                              # (d_rna,)

        # labels = input_ids, но паддинги -> -100
        labels = input_ids.clone()
        labels[labels == self.tok.pad_token_id] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention,
            "rna": rna,
            "labels": labels,
        }


In [14]:
# ---- модель: ChemBERTa + cross-attn (твоя версия; forward просто возвращает logits) ----
class ChemBERTaCrossAttentionLM(nn.Module):
    def __init__(self, model_name: str, rna_dim: int, freeze_n_layers: int = 10, num_heads: int = 8):
        super().__init__()
        self.chem = AutoModel.from_pretrained(model_name)
        hidden = self.chem.config.hidden_size

        # заморозка первых N слоёв (если есть encoder.layer)
        if hasattr(self.chem, "encoder") and hasattr(self.chem.encoder, "layer"):
            n_layers = len(self.chem.encoder.layer)
            freeze_n = min(freeze_n_layers, n_layers)
            for layer in self.chem.encoder.layer[:freeze_n]:
                for p in layer.parameters():
                    p.requires_grad = False

        self.rna_proj = nn.Linear(rna_dim, hidden)
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden)
        self.ffn = nn.Sequential(
            nn.Linear(hidden, hidden * 4),
            nn.GELU(),
            nn.Linear(hidden * 4, hidden)
        )
        self.lm_head = nn.Linear(hidden, self.chem.config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask, rna):
        enc = self.chem(input_ids=input_ids, attention_mask=attention_mask)
        h = enc.last_hidden_state                       # (B, L, H)

        # RNA -> (B, 1, H)
        rna_tok = self.rna_proj(rna).unsqueeze(1)

        # cross-attn: Q=h(L токенов), K/V=один RNA токен
        attn_out, _ = self.cross_attn(query=h, key=rna_tok, value=rna_tok)  # (B, L, H)

        # residual + FFN
        h = self.norm(h + attn_out)
        h = h + self.ffn(h)

        logits = self.lm_head(h)                        # (B, L, V)
        return logits

    @torch.no_grad()
    def generate(self,
                 tokenizer,
                 rna: torch.Tensor,
                 max_new_tokens: int = 64,
                 prefix: Optional[str] = None,
                 temperature: float = 1.0,
                 top_k: Optional[int] = None,
                 top_p: Optional[float] = 0.9):
        """
        Автогенерация токенов (encoder-conditioned).
        - корректные attention_mask;
        - ограничение по max_new_tokens;
        - temperature, top_k, top_p (nucleus).
        """
        self.eval()
        device = next(self.parameters()).device

        if tokenizer.pad_token is None:
            # Roberta-подобные часто используют eos как pad
            tokenizer.pad_token = tokenizer.eos_token

        if prefix is None:
            if tokenizer.bos_token_id is not None:
                ids = torch.tensor([[tokenizer.bos_token_id]], device=device)
            else:
                # если BOS нет — начнём с пустого (добавим паддинг-маску)
                ids = torch.tensor([[tokenizer.pad_token_id]], device=device)
        else:
            ids = tokenizer(prefix, return_tensors="pt", truncation=True, max_length=MAX_LEN).input_ids.to(device)

        rna = rna.to(device)  # (B, d_rna) или (1, d_rna)

        for _ in range(max_new_tokens):
            # актуальная маска внимания (паддинг=0)
            attn = (ids != tokenizer.pad_token_id).long()

            logits = self.forward(ids, attn, rna.expand(ids.size(0), -1))
            next_logits = logits[:, -1, :]  # (B, V)

            # temperature
            if temperature and temperature > 0:
                next_logits = next_logits / temperature

            # top-k
            if top_k is not None and top_k > 0:
                k = min(top_k, next_logits.size(-1))
                v, _ = torch.topk(next_logits, k)
                thr = v[:, -1].unsqueeze(-1)
                next_logits = torch.where(next_logits < thr, torch.full_like(next_logits, -1e10), next_logits)

            # top-p (nucleus)
            if top_p is not None and 0 < top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(next_logits, descending=True, dim=-1)
                probs = torch.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(probs, dim=-1)
                # маска токенов, выходящих за предел вероятностной массы top_p
                cutoff = (cumprobs > top_p)
                # гарантируем, что хотя бы один токен остаётся
                cutoff[..., 0] = False
                sorted_logits[cutoff] = -1e10
                # возвращаем на исходные позиции
                unsorted = torch.full_like(next_logits, -1e10)
                unsorted.scatter_(1, sorted_idx, sorted_logits)
                next_logits = unsorted

            probs = torch.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # (B, 1)

            ids = torch.cat([ids, next_id], dim=1)

            # стоп по EOS
            if tokenizer.eos_token_id is not None and (next_id == tokenizer.eos_token_id).all():
                break

            # стоп, если упёрлись в MAX_LEN
            if ids.size(1) >= MAX_LEN:
                break

        return tokenizer.batch_decode(ids, skip_special_tokens=True)[0]


In [15]:
# ==== загрузка данных ====
csv_path = DATA_DIR / CSV_NAME
emb_path = DATA_DIR / EMB_NPY

df_full = pd.read_csv(csv_path)
assert "SMILES" in df_full.columns, "В CSV должна быть колонка 'SMILES'."
rna_embs_full = np.load(emb_path)

if rna_embs_full.shape[0] != len(df_full):
    raise ValueError(f"RNA-эмбеддингов: {rna_embs_full.shape[0]}, строк в CSV: {len(df_full)} — проверь соответствие!")

In [16]:
# ---- делим 80/20, а train дополнительно урезаем до 150 строк ----
df_train, df_val, embs_train, embs_val = train_test_split(
    df_full, rna_embs_full, test_size=0.2, random_state=SEED, shuffle=True, stratify=None
)
if len(df_train) > 150:
    df_train = df_train.iloc[:150].reset_index(drop=True)
    embs_train = embs_train[:150]

print(f"Train: {len(df_train)} | Val: {len(df_val)}")

Train: 150 | Val: 6089


In [17]:
# ---- токенизатор ----
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


In [32]:
from sklearn.model_selection import train_test_split
rna_embs= np.load(r"C:\Users\User\Desktop\ИТМО\Проект\работа с базой DT_Curated\embeddings.npy")
# --- Берём ровно 150 первых примеров (и синхронно режем эмбеддинги) ---
N = min(150, len(df), rna_embs.shape[0])
df_small   = df.iloc[:N].reset_index(drop=True)
rna_small  = rna_embs[:N]
smiles_all = df_small["SMILES"].astype(str)

# --- Индексы для разбиения 80/20 ---
idx = np.arange(N)
train_idx, val_idx = train_test_split(idx, test_size=0.2, random_state=42, shuffle=True)

# --- Формируем списки/массивы под твой SmilesRnaDataset ---
smiles_train = smiles_all.iloc[train_idx].tolist()
smiles_val   = smiles_all.iloc[val_idx].tolist()
rna_train    = rna_small[train_idx]
rna_val      = rna_small[val_idx]

# --- Датасеты ---
train_ds = SmilesRnaDataset(smiles_train, rna_train, tokenizer, MAX_LEN)
val_ds   = SmilesRnaDataset(smiles_val,   rna_val,   tokenizer, MAX_LEN)

# --- DataLoader-ы (Windows-safe: num_workers=0) ---
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0, pin_memory=True)
val_dl   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)

print(f"✅ Разбиение готово: train={len(train_ds)} | val={len(val_ds)}")


✅ Разбиение готово: train=120 | val=30


In [33]:
# ---- модель ----
d_rna = embs_train.shape[1]
model = ChemBERTaCrossAttentionLM(MODEL_NAME, rna_dim=d_rna, freeze_n_layers=FREEZE_N, num_heads=NUM_HEADS).to(DEVICE)

In [20]:
# ---- оптимайзер/шедулер/лоссы ----
params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(params, lr=LR, weight_decay=WEIGHT_DECAY)
total_steps = max(1, EPOCHS * len(train_dl))
warmup_steps = int(WARMUP_FR * total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
loss_f = nn.CrossEntropyLoss(ignore_index=-100)
scaler = torch.cuda.amp.GradScaler(enabled=AMP)

  scaler = torch.cuda.amp.GradScaler(enabled=AMP)


In [34]:
# === Обучение (1 эпоха на train_dl) ===
model.train()
running_loss = 0.0
print(f"\n=== Эпоха 1/1 (train {len(train_ds)} примеров) ===")
for step, batch in enumerate(train_dl, 1):
    input_ids = batch["input_ids"].to(DEVICE)
    attention = batch["attention_mask"].to(DEVICE)
    rna       = batch["rna"].to(DEVICE)
    labels    = batch["labels"].to(DEVICE)

    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(enabled=AMP):
        logits = model(input_ids, attention, rna)
        logits_shift = logits[:, :-1, :].contiguous()
        labels_shift = labels[:, 1:].contiguous()
        loss = loss_f(
            logits_shift.view(-1, logits_shift.size(-1)),
            labels_shift.view(-1)
        )

    scaler.scale(loss).backward()
    torch.nn.utils.clip_grad_norm_(params, GRAD_CLIP)
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
    running_loss += loss.item()

    if step % 10 == 0:
        print(f"[{step}/{len(train_dl)}] loss: {loss.item():.4f}")

print(f"📊 Средний train loss: {running_loss / max(1, len(train_dl)):.4f}")

# === Валидация (на val_dl) ===
model.eval()
total_loss = 0
total_tok  = 0
correct    = 0

with torch.no_grad():
    for batch in val_dl:
        input_ids = batch["input_ids"].to(DEVICE)
        attention = batch["attention_mask"].to(DEVICE)
        rna       = batch["rna"].to(DEVICE)
        labels    = batch["labels"].to(DEVICE)

        logits = model(input_ids, attention, rna)
        logits_shift = logits[:, :-1, :].contiguous()
        labels_shift = labels[:, 1:].contiguous()

        mask = labels_shift.ne(-100)
        loss = loss_f(
            logits_shift.view(-1, logits_shift.size(-1)),
            labels_shift.view(-1)
        )

        total_loss += loss.item() * mask.sum().item()
        total_tok  += mask.sum().item()
        preds = logits_shift.argmax(dim=-1)
        correct += ((preds == labels_shift) & mask).sum().item()

val_loss = total_loss / max(1, total_tok)
val_ppl  = float(np.exp(val_loss)) if total_tok > 0 else float("inf")
val_acc  = correct / max(1, total_tok)

print(f"✅ Валид. метрики: Loss={val_loss:.4f} | PPL={val_ppl:.2f} | Acc={val_acc*100:.2f}%")



=== Эпоха 1/1 (train 120 примеров) ===


  with torch.cuda.amp.autocast(enabled=AMP):


[10/15] loss: 6.8863
📊 Средний train loss: 6.8833
✅ Валид. метрики: Loss=6.9463 | PPL=1039.25 | Acc=0.00%


In [35]:
# ==== сохранение ====
out_dir = DATA_DIR / "chemberta_crossattn_rna"
out_dir.mkdir(exist_ok=True)
# сохраним веса и токенизатор в формате HF
torch.save(model.state_dict(), out_dir / "pytorch_model.bin")
tokenizer.save_pretrained(out_dir)
print(f"\n💾 Сохранено в: {out_dir}")



💾 Сохранено в: C:\Users\User\Desktop\ИТМО\Проект\работа с базой DT_Curated\chemberta_crossattn_rna


In [36]:
def generate(self,
             tokenizer,
             rna: torch.Tensor,
             max_new_tokens: int = 64,
             prefix: str = None,
             temperature: float = 1.0,
             top_k: int = None,
             top_p: float = None):
    self.eval()
    device = next(self.parameters()).device

    # BOS/префикс
    if prefix is None:
        bos_id = tokenizer.bos_token_id
        if bos_id is None:
            # для Roberta-подобных часто <s> как bos
            bos_id = tokenizer.convert_tokens_to_ids("<s>") if "<s>" in tokenizer.get_vocab() else tokenizer.cls_token_id
        ids = torch.tensor([[bos_id]], device=device, dtype=torch.long)
    else:
        enc = tokenizer(prefix, return_tensors="pt", truncation=True, max_length=MAX_LEN, add_special_tokens=True)
        ids = enc["input_ids"].to(device)

    rna = rna.to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            attn_mask = (ids != (tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0)).long()
            logits = self.forward(ids, attn_mask, rna.expand(ids.size(0), -1))
            next_logits = logits[:, -1, :] / max(1e-6, temperature)

            # top-k
            if top_k is not None and top_k > 0:
                k = min(top_k, next_logits.size(-1))
                v, _ = torch.topk(next_logits, k)
                cut = v[:, [-1]]
                next_logits[next_logits < cut] = -float("inf")

            # top-p (nucleus)
            if top_p is not None and 0 < top_p < 1:
                sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
                probs = torch.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(probs, dim=-1)
                # маскируем всё после порога p
                sorted_logits[cumprobs > top_p] = -float("inf")
                # возвращаем в исходный порядок
                next_logits = torch.zeros_like(next_logits).scatter(1, sorted_idx, sorted_logits)

            probs = torch.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # (B,1)

            ids = torch.cat([ids, next_id], dim=1)

            # остановка по EOS
            if tokenizer.eos_token_id is not None and (next_id == tokenizer.eos_token_id).all():
                break

    return tokenizer.batch_decode(ids, skip_special_tokens=True)[0]


In [37]:
# ==== валидация (метрики) ====
def evaluate(model, dataloader, loss_f):
    model.eval()
    total_loss = 0
    total_tokens = 0
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention = batch["attention_mask"].to(DEVICE)
            rna       = batch["rna"].to(DEVICE)
            labels    = batch["labels"].to(DEVICE)

            logits = model(input_ids, attention, rna)
            logits_shift = logits[:, :-1, :].contiguous()
            labels_shift = labels[:, 1:].contiguous()

            loss = loss_f(logits_shift.view(-1, logits_shift.size(-1)),
                          labels_shift.view(-1))

            mask = labels_shift != -100
            n_tok = mask.sum().item()
            total_loss += loss.item() * n_tok
            total_tokens += n_tok

            preds = logits_shift.argmax(dim=-1)
            correct += ((preds == labels_shift) & mask).sum().item()

    avg_loss = total_loss / max(1, total_tokens)
    ppl = math.exp(avg_loss) if avg_loss < 50 else float("inf")
    acc = correct / max(1, total_tokens)
    return avg_loss, ppl, acc

val_loss, val_ppl, val_acc = evaluate(model, val_dl, loss_f)
print(f"\n📈 Validation Loss: {val_loss:.4f}")
print(f"📉 Perplexity: {val_ppl:.4f}")
print(f"🎯 Token Accuracy: {val_acc*100:.2f}%")


📈 Validation Loss: 6.9463
📉 Perplexity: 1039.2538
🎯 Token Accuracy: 0.00%


In [38]:
# ==== быстрая генерация для проверки ====
model.eval()
rna_t = torch.tensor(embs_val[0], dtype=torch.float32, device=DEVICE).unsqueeze(0)
sample = model.generate(
    tokenizer,
    rna=rna_t,
    max_new_tokens=64,
    prefix=None,
    temperature=1.0,
    top_k=50,
    top_p=0.95
)
print("\n🧪 Пробная генерация:")
print(sample)


🧪 Пробная генерация:
CCCCCC~�ccncFCCCCNCcncscSHCCCCOcnncCCOCCOC��snSHSH�3CCCOCCCCOCCNC)/ncscCCOCCOCt�RCCONC�CCOcjCOP�OCCCnCSCCCNCOCCnCSCCOOCCCnTCSCCOC�CCOCCNCCCOcSHccccnRccoc�TNCCCCCC��OCCCNCccccn��43ClCCBrCCNCCCCCCCSCCOC�ccoc
