In [1]:
import sys 
import torch
sys.path.insert(1, 'MolT5/baselines')

import numpy as np
import dataloader
import json

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.cluster import MiniBatchKMeans
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:


dummy_emb = torch.rand(1000, 20, 20, 64).to('cuda')
pairs = dummy_emb

In [3]:
pairs = torch.load('../burov/unimol_embeddings_project/geom_train_pair_embeddings.pt', weights_only=True)

In [5]:
N = len(pairs)

total = N
n_chunks = 10
chunk_size = total // n_chunks       # → 43647
last_chunk_size = total - chunk_size * (n_chunks - 1)  # → остаток в последнем чанке


In [9]:
chunk_size = 43647
for i in range(0, N, chunk_size):
    chunk = pairs[i:i+chunk_size]
    torch.save(chunk, f"/home/user12/ebeddings/geom_train_pair_embeddings{i//chunk_size:02d}.pt")

In [4]:
import torch
from torch.utils.data import Dataset
import glob

class ChunkedDataset(Dataset):
    def __init__(self, chunk_paths, sizes):
        self.chunk_paths = chunk_paths
        self.sizes = sizes
        self.cumulative = [0]
        for s in sizes:
            self.cumulative.append(self.cumulative[-1] + s)

        self.current_chunk = None
        self.current_chunk_idx = -1

    def __len__(self):
        return self.cumulative[-1]

    def __getitem__(self, idx):
        # Определяем, к какому чанку принадлежит idx
        for i in range(len(self.sizes)):
            if self.cumulative[i] <= idx < self.cumulative[i+1]:
                local_idx = idx - self.cumulative[i]
                if self.current_chunk_idx != i:
                    self.current_chunk = torch.load(self.chunk_paths[i])
                    self.current_chunk_idx = i
                return self.current_chunk[local_idx]

        raise IndexError("Index out of bounds")

# Пример инициализации
chunk_paths = sorted(glob.glob("/home/user12/ebeddings/*.pt"))
sizes = [43647] * 9 + [43648]
dataset = ChunkedDataset(chunk_paths, sizes)


In [9]:
dataset[4]

tensor([[[  1.8952,   4.1560,   6.0471,  ...,  -0.2641,  -0.8994, -11.4943],
         [  0.6613,   2.8653,  -1.2060,  ...,   2.6446,  -2.3022,  -2.1334],
         [  1.1383,   3.5325,  -0.7498,  ...,   3.1662,  -1.1985,  -2.5246],
         ...,
         [  1.0921,   4.1979,   2.5114,  ...,   4.0741,   1.4308,  -4.2720],
         [  1.1059,   3.2943,   4.2010,  ...,   3.0813,   1.7077,  -6.5720],
         [  0.5109,   1.9762,  -1.9728,  ...,   3.1906,  -3.1862,  -1.2246]],

        [[  2.8785,   4.1167,   3.8345,  ...,   0.5797,  -0.9223, -10.1936],
         [  3.2859,  -1.0462,   5.7178,  ...,  -4.2926, -10.8039, -13.7466],
         [  3.7274,  -4.3310,   5.0607,  ...,  -5.4310,  -4.8540, -16.0476],
         ...,
         [  1.5383,   0.1308,   3.3827,  ...,   3.9154,   0.6600,  -6.5256],
         [  1.3351,   3.0836,   4.1798,  ...,   2.4447,   2.9840,  -7.9191],
         [  2.8457,  -0.5740,   4.2678,  ...,   2.6725,  -0.9782,  -5.0900]],

        [[  2.4847,   2.8307,   4.9040,  ...

In [None]:
torch.save('~/geom_train_pair_embeddings.pt')

In [7]:
geom = torch.load('./datasets/geom_train.pt', weights_only=True)

In [3]:
geom = torch.load('./datasets/geom_cut.pt', weights_only=True)

In [None]:


k         = 128
batchsize = 10_000           # сколько 64-векторов обрабатываем за раз
max_pairs = 10_000_000          # всего примеров, на которых «доведём» центры

kmeans = MiniBatchKMeans(n_clusters=k,
                         batch_size=batchsize,
                         init_size=k*3,        # можно побольше для устойчивости
                         verbose=0,
                         random_state=42)

seen = 0
for mol in tqdm(pairs):                   # pair_list: list[T(N,N,64)]
    vecs = mol.reshape(-1, 64)                # (N²,64)   – в gpu/cpu памяти молекулы
    # --- случайно берём не больше batchsize векторов ---
    if vecs.size(0) > batchsize:
        idx = torch.randperm(vecs.size(0))[:batchsize]
        vecs = vecs[idx]

    kmeans.partial_fit(vecs.cpu().numpy())    # учим на CPU, по кусочкам
    seen += vecs.size(0)
    if seen >= max_pairs:                     # хватит примеров – выходим
        break

centers = torch.tensor(kmeans.cluster_centers_)   # (128,64)  • готово


def vec2tok(v: torch.Tensor) -> str:
    idx = torch.cdist(v[None], centers).argmin().item()
    return f"<p{idx:03d}>"


  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 1000/1000 [00:02<00:00, 409.94it/s]


In [8]:
def vec2tok(v: torch.Tensor) -> str | list[str]:
    """
    v : Tensor (64,)       → '<p042>'
        Tensor (M,64)      → ['<p042>', '<p118>', ...] длиной M
    """
    v = v.to(centers)                       # убедимся, что на том же девайсе

    if v.ndim == 1:                         # одиночный вектор
        idx = torch.cdist(v[None], centers).argmin().item()
        return f"<p{idx:03d}>"

    # батч M×64  → M индексов
    idx = torch.cdist(v, centers).argmin(dim=1).tolist()   # List[int] длиной M
    return [f"<p{i:03d}>" for i in idx]                    # List[str]


In [6]:
vecs = torch.randn(40, 64)
print(vec2tok(vecs)[:4])
# ['<p042>', '<p118>', '<p031>', '<p007>']


['<p061>', '<p047>', '<p094>', '<p102>']


In [None]:
model_name = "Qwen/Qwen1.5-1.8B-Chat"   # или другая
tokenizer  = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
model      = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# 128 новых «pair-токенов»
new_tokens = [f"<p{i:03d}>" for i in range(128)]

num_added = tokenizer.add_tokens(new_tokens, special_tokens=False)
print("Добавили:", num_added)      # должно быть 128
model.resize_token_embeddings(len(tokenizer))   # 🔑 расширяем эмбеддинг



Добавили: 128


Embedding(151774, 2048)

In [12]:
tokenizer.save_pretrained("tokenizer_pair128")

('tokenizer_pair128/tokenizer_config.json',
 'tokenizer_pair128/special_tokens_map.json',
 'tokenizer_pair128/chat_template.jinja',
 'tokenizer_pair128/vocab.json',
 'tokenizer_pair128/merges.txt',
 'tokenizer_pair128/added_tokens.json',
 'tokenizer_pair128/tokenizer.json')

In [9]:
len(geom[0]['positions'])
pairs = []
for mol in tqdm(geom):
    pair = torch.rand(len(mol['positions']), len(mol['positions']), 64)
    pair.to('cuda')
    pairs.append(pair)

  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:02<00:00, 3379.26it/s]


In [10]:
import json
from tqdm import tqdm

SYSTEM = ("You are a chemist. For each atom pair within 2 Å classify the "
          "bond type. Labels: 0 no-bond, 1 single, 2 double, 3 triple, 4 aromatic. "
          "Return ONLY JSON list [{\"pair\":[i,j],\"label\":n}].")

ORDER2LABEL = {0:1, 1:2, 2:3, 3:4, 4:1, 5:1, 6:1, 7:1, 8:1, 9:1}
ALLOWED = ['C', 'O', 'N', 'F', 'S', 'Cl', 'Br', 'I', 'P']

dev = 'cuda'
centers_gpu = centers.to(dev)
with open("bond_prompts.jsonl", "w") as fout:
    
    for mol, pair in tqdm(zip(geom, pairs), total=1000):
        xyz   = mol["positions"].to(dev, non_blocking=True)      # (N,3)
        types = mol["one_hot"].argmax(-1)                        # остаётся на CPU
        pair  = pair.to(dev, non_blocking=True)                  # (N,N,64)

        dmat  = torch.cdist(xyz, xyz)                            # GPU, fp32
        i_idx, j_idx = (dmat<=2.0).nonzero(as_tuple=True)        # тоже GPU
        
        if not len(i_idx):  continue

        vecs   = pair[i_idx, j_idx]                              # (M,64)_gpu
        idx    = torch.cdist(vecs, centers_gpu).argmin(dim=1)        # (M,)_gpu
        toks   = [f"<p{i:03d}>" for i in idx.tolist()]                            # List[str]  (исправленный!)

        lines = [f"[{i},{j}]={tok}"
                for (i,j),tok in zip(zip(i_idx.tolist(), j_idx.tolist()), toks)]

        # ---------- метки ------------------------------------------------------
        edge2order = {(int(u),int(v)): bo.nonzero(as_tuple=True)[0].item()
                    for (u,v), bo in zip(mol["edge_index"], mol["bond_orders"])}
        edge2order |= {(v,u):o for (u,v),o in edge2order.items()}

        labels = [0 if (i,j) not in edge2order else ORDER2LABEL[edge2order[(i,j)]]
                for i,j in zip(i_idx.tolist(), j_idx.tolist())]

        assistant = json.dumps([{"pair":[i,j],"label":l}
                                for (i,j),l in zip(zip(i_idx.tolist(), j_idx.tolist()),
                                                labels)],
                            separators=(",",":"))

        atom_line = "Atoms (index→type): " + \
                    ", ".join(f"{idx}:{ALLOWED[types[idx]]}"
                            for idx in range(len(types)))

        prompt = {"messages":[
            {"role":"system","content":SYSTEM},
            {"role":"user",
            "content":"Pairs within 2 Å:\n" + "\n".join(lines) + "\n\n" + atom_line},
            {"role":"assistant","content":assistant}
        ]}
        fout.write(json.dumps(prompt)+"\n")



2535it [01:17, 32.73it/s]                         


KeyboardInterrupt: 

In [50]:
print('{"messages": [{"role": "system", "content": "You are a chemist. For each atom pair within 2 \u00c5 classify the bond type. Labels: 0 no-bond, 1 single, 2 double, 3 triple, 4 aromatic. Return ONLY JSON list [{\"pair\":[i,j],\"label\":n}]."}, {"role": "user", "content": "Pairs within 2 \u00c5:\n[0,1]=<p055>\n[1,0]=<p103>\n[1,2]=<p072>\n[2,1]=<p010>\n[2,3]=<p010>\n[3,2]=<p072>\n[3,4]=<p102>\n[3,5]=<p084>\n[3,8]=<p080>\n[4,3]=<p102>\n[4,5]=<p102>\n[4,7]=<p076>\n[5,3]=<p084>\n[5,4]=<p102>\n[5,6]=<p016>\n[6,5]=<p040>\n[6,7]=<p040>\n[7,4]=<p076>\n[7,6]=<p016>\n[7,8]=<p010>\n[8,3]=<p080>\n[8,7]=<p010>"}, {"role": "assistant", "content": "[{\"pair\":[0,1],\"label\":0},{\"pair\":[1,0],\"label\":0},{\"pair\":[1,2],\"label\":0},{\"pair\":[2,1],\"label\":0},{\"pair\":[2,3],\"label\":0},{\"pair\":[3,2],\"label\":0},{\"pair\":[3,4],\"label\":0},{\"pair\":[3,5],\"label\":0},{\"pair\":[3,8],\"label\":0},{\"pair\":[4,3],\"label\":0},{\"pair\":[4,5],\"label\":0},{\"pair\":[4,7],\"label\":0},{\"pair\":[5,3],\"label\":0},{\"pair\":[5,4],\"label\":0},{\"pair\":[5,6],\"label\":0},{\"pair\":[6,5],\"label\":0},{\"pair\":[6,7],\"label\":0},{\"pair\":[7,4],\"label\":0},{\"pair\":[7,6],\"label\":0},{\"pair\":[7,8],\"label\":0},{\"pair\":[8,3],\"label\":0},{\"pair\":[8,7],\"label\":0}]"}]}')

{"messages": [{"role": "system", "content": "You are a chemist. For each atom pair within 2 Å classify the bond type. Labels: 0 no-bond, 1 single, 2 double, 3 triple, 4 aromatic. Return ONLY JSON list [{"pair":[i,j],"label":n}]."}, {"role": "user", "content": "Pairs within 2 Å:
[0,1]=<p055>
[1,0]=<p103>
[1,2]=<p072>
[2,1]=<p010>
[2,3]=<p010>
[3,2]=<p072>
[3,4]=<p102>
[3,5]=<p084>
[3,8]=<p080>
[4,3]=<p102>
[4,5]=<p102>
[4,7]=<p076>
[5,3]=<p084>
[5,4]=<p102>
[5,6]=<p016>
[6,5]=<p040>
[6,7]=<p040>
[7,4]=<p076>
[7,6]=<p016>
[7,8]=<p010>
[8,3]=<p080>
[8,7]=<p010>"}, {"role": "assistant", "content": "[{"pair":[0,1],"label":0},{"pair":[1,0],"label":0},{"pair":[1,2],"label":0},{"pair":[2,1],"label":0},{"pair":[2,3],"label":0},{"pair":[3,2],"label":0},{"pair":[3,4],"label":0},{"pair":[3,5],"label":0},{"pair":[3,8],"label":0},{"pair":[4,3],"label":0},{"pair":[4,5],"label":0},{"pair":[4,7],"label":0},{"pair":[5,3],"label":0},{"pair":[5,4],"label":0},{"pair":[5,6],"label":0},{"pair":[6,5],"label":

In [29]:
# убедитесь, что vec2tok() возвращает СТРОКУ целиком
def vec2tok(v):
    idx = torch.cdist(v.unsqueeze(0), centers).argmin().item()
    return f"<p{idx:03d}>"           # → '<p042>'

# формируем lines без лишних \n и join
lines = [f"[{int(i)},{int(j)}]={tok}"      # tok уже готов '<p042>'
         for (i, j), tok in zip(zip(i_idx, j_idx), toks)]

user_text = "Pairs within 2 Å:\n" + "\n".join(lines)
user_text

'Pairs within 2 Å:\n[0,1]=<\n[1,0]=p\n[1,2]=2\n[1,5]=2\n[2,1]=9\n[2,3]=2\n[3,2]=>'

In [None]:
pairs_ij, vecs = ... # ≤2 Å выборка\n
lines = [f"[{i},{j}]={emb2tok(vecs[k])}"
for k,(i,j) in enumerate(pairs_ij)]
user = "Pairs within 2 Å:\\n" + "\\n".join(lines)

In [2]:
import random

line = random.choice(open("bond_prompts.jsonl").readlines())

In [5]:
json.loads(line)

{'messages': [{'role': 'system',
   'content': 'You are a chemist. For each atom pair within 2 Å classify the bond type. Labels: 0 no-bond, 1 single, 2 double, 3 triple, 4 aromatic. Return ONLY JSON list [{"pair":[i,j],"label":n}].'},
  {'role': 'user',
   'content': 'Pairs within 2 Å:\n[0,0]=<p079>\n[0,1]=<p043>\n[1,0]=<p036>\n[1,1]=<p117>\n[1,2]=<p010>\n[1,3]=<p045>\n[1,8]=<p106>\n[2,1]=<p010>\n[2,2]=<p117>\n[2,3]=<p010>\n[2,4]=<p010>\n[2,5]=<p045>\n[3,1]=<p045>\n[3,2]=<p010>\n[3,3]=<p117>\n[4,2]=<p010>\n[4,4]=<p117>\n[5,2]=<p045>\n[5,5]=<p025>\n[5,6]=<p045>\n[6,5]=<p028>\n[6,6]=<p117>\n[6,7]=<p124>\n[7,6]=<p124>\n[7,7]=<p117>\n[8,1]=<p106>\n[8,8]=<p117>\n\nAtoms (index→type): 0:C, 1:C, 2:C, 3:C, 4:O, 5:C, 6:C, 7:C, 8:C'},
  {'role': 'assistant',
   'content': '[{"pair":[0,0],"label":0},{"pair":[0,1],"label":1},{"pair":[1,0],"label":1},{"pair":[1,1],"label":0},{"pair":[1,2],"label":1},{"pair":[1,3],"label":1},{"pair":[1,8],"label":1},{"pair":[2,1],"label":1},{"pair":[2,2],"label":0},