In [18]:
# part 1: 导入相关的 package
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass

from rdkit import Chem
from rdkit.Chem import BRICS
from rdkit.Chem import Draw
from rdkit.Chem import AllChem, DataStructs

import os
import copy
import random
import datetime
import numpy as np
import math
import ast

seed = 888

def set_random_seed(seed=88):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

set_random_seed(seed)

# 获取当前时间
current_time = datetime.datetime.now().strftime("%Y-%m-%d__%H-%M-%S")

if torch.cuda.is_available():
    print('use GPU')
    device = torch.device('cuda')
else:
    print('use CPU')
    device = torch.device('cpu')


use GPU


In [19]:
@dataclass
class GPTConfig:
    block_size: int = 64   # 这里其实应该是文本的最大长度（ max_seq_len）
    seq_len = block_size
    batch_size: int = 512
    n_layer: int = 6
    n_head: int = 16
    num_heads = n_head
    n_embd: int = 256  # n_embd 也叫 hidden_dim, hiden_size, 这里我同时设置了和 embed_dim 一样
    d_model = n_embd    
    v_head_dim = 32
    kv_lora_rank = 16
    q_lora_rank = 3 * kv_lora_rank
    rope_head_dim = 32
    nope_head_dim = 16
    dropout: float = 0.1
    # # tiktoken 使用的是 GPT-2 的词表，大约有 50257 个token
    vocab_size: int = 67

In [20]:
datasmi_mol_elem = []
datasmi_mol_2nd = []
datasmi_mol_total = []
data_mol_elem = []
data_mol_2nd = []
data_mol_total = []
data_1_path = './mol_in_chem_space.txt'
data_2_path = './mol_not_in_chem_space.txt'


elem_frag_smi_list = []
frag_smi_to_2nd_break_list = []
frag_smi_2nd_break_dict = {}


with open("./my_frags_3.txt", "r") as f:
    frag_smi_to_2nd_break_list = [line.split(':')[0].strip() for line in f.readlines()[:109] if ':' in line]

with open("./my_frags_3.txt", "r") as f:
    elem_frag_smi_list = [line.strip() for line in f.readlines()[:109] if ':' not in line]
    elem_frag_smi_list.append('xxxx')   # 65 个片段库  + EOS
    elem_frag_smi_list.append('EOS')

print(frag_smi_to_2nd_break_list)
print(len(frag_smi_to_2nd_break_list))
print(elem_frag_smi_list)
print(len(elem_frag_smi_list))

def read_txt_to_2nd_break_dict(txt_file):
    with open(txt_file, "r") as f:
        for line in f.readlines()[:109]:
            if ':' in line:
                frag_smi_2nd_break_dict[line.split(':')[0].strip()] = [ s.strip() for s in line.split(':')[-1].strip().split(',')]
    return frag_smi_2nd_break_dict

frag_smi_2nd_break_dict = read_txt_to_2nd_break_dict("./my_frags_3.txt")
print(frag_smi_2nd_break_dict)
print(len(frag_smi_2nd_break_dict))

# 创建分子片段到索引的映射
fragment_to_idx = {fragment: idx for idx, fragment in enumerate(elem_frag_smi_list)}
idx_to_fragment = {idx: fragment for fragment, idx in fragment_to_idx.items()}
idx_to_fragment.update({66: 'EOS'})

with open('./mol_in_chem_space.txt', 'r') as f:
    for line in f.readlines():
        datasmi_mol_elem.append(ast.literal_eval(line.split(':')[2].strip()))
        

with open('./mol_not_in_chem_space.txt', 'r') as f:
    for line in f.readlines():        
        # 将字符串按 ':' 分割为列表
        parts = line.split(':')
        # 去掉前两个字段，保留后面的部分，并重新组合为一个字符串
        result = ':'.join(parts[2:])        
        # word_freq_vector = word_freq_func(ast.literal_eval(result))
        # data_mol_2nd.append(word_freq_vector)
        list = []
        for i in result.strip()[1:-1].split(','):
            list.append(i.strip()[1:-1])
        datasmi_mol_2nd.append(list)

datasmi_mol_total = datasmi_mol_elem + datasmi_mol_2nd  

['CO', 'CCC', 'CC=O', 'O=CO', 'CN', 'Oc1ccccc1', 'CNC', 'CCC=O', 'CC(=O)O', 'CCO', 'CCCC', 'CCC(=O)O', 'Cc1ccccc1', 'CC(C)O', 'CC(N)C(=O)O', 'CCCCC', 'CN1CCNCC1', 'O=CCO', 'CCCC=O', 'CCN', 'NS(=O)(=O)c1ccccc1', 'Nc1ccccc1', 'CS', 'CCCO', 'Cc1cccc(C)c1', 'CCCC(=O)O', 'CCCCC(=O)O', 'CCCCCC', 'O=C(O)CCCC(=O)O', 'CC(O)CO', 'CN1CCCCC1', 'CC(N)C=O', 'O=CNO', 'Cc1c[nH]c(=O)[nH]c1=O', 'CC(F)(F)F', 'NCC=O', 'CC(C)CCC=O', 'Oc1cccc(O)c1', 'Nc1ccc([SH](=O)=O)cc1', 'CC(C)C=O', 'CCCCC=O', 'O=CCCC(=O)O', 'Cn1cccn1', 'CCC(C)C', 'NC(CCC=O)C(=O)O']
45
['N', 'c1ccccc1', 'O', 'C', 'CC', 'C=O', 'c1ccncc1', 'Clc1ccccc1', 'C1CCNCC1', 'Fc1ccccc1', 'O=P(O)(O)O', 'FC(F)F', 'CC(C)C', 'NC=O', 'C1CCNC1', 'OC1COCC1O', 'C1CNCCN1', 'C1CC1', 'S', 'C1CCCCC1', 'c1cncnc1', 'C1COCCN1', 'OC1COCC(O)C1O', 'c1cn[nH]c1', 'c1c[nH]cn1', 'c1ccc2[nH]ccc2c1', 'c1ccsc1', 'c1ccc2ccccc2c1', 'O=[SH](=O)c1ccccc1', 'Nc1ncnc2[nH]cnc12', 'Clc1cccc(Cl)c1', 'O=[N+]([O-])c1ccccc1', 'OC1CCOC1', 'C1CCCC1', 'O=P(O[3*])(O)OP(=O)(O)O[3*]', 'Fc1ccc

In [21]:
# 自定义 Dataset 类，支持多个分子片段序列
class MoleculeDataset(Dataset):
    def __init__(self, molecule_sequences, fragment_to_idx):
        """
        molecule_sequences: 分子片段序列的列表，每个元素是一个由多个分子片段组成的序列
        fragment_to_idx: 分子片段到索引的映射
        """
        self.molecule_sequences = molecule_sequences
        self.fragment_to_idx = fragment_to_idx
        self.block_size = GPTConfig().block_size
        self.encoded_sequences = []

        # 将分子片段序列转化为对应的索引序列
        for sequence in self.molecule_sequences:
            if len(sequence) == 0:
                continue
            else:
                encoded_sequence = []
                for frag in sequence:
                    if 'xxxx' in frag:
                        encoded_sequence.append(self.fragment_to_idx['xxxx'])
                    else:
                        encoded_sequence.append(self.fragment_to_idx[frag])
                encoded_sequence.append(66)  # EOS = 66
                self.encoded_sequences.extend(encoded_sequence)

        # 将超长文本分割成训练样本
        self.encoded_data = []
        for i in range(0, len(self.encoded_sequences)):
            chunk = self.encoded_sequences[i:i+self.block_size+1]
            # 如果长度不够，用 eos_token 填充
            if len(chunk) < self.block_size + 1:
                chunk = chunk + [66] * (self.block_size + 1 - len(chunk))  # EOS = 66
            self.encoded_data.append(chunk)


    def __len__(self):
        return len(self.encoded_data)
    
    def __getitem__(self, idx):
        chunk = self.encoded_data[idx]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y
        


# 创建 Dataset 和 DataLoader 实例
dataset = MoleculeDataset(datasmi_mol_total, fragment_to_idx)
print(len(dataset))

# split traindataset to train and val
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])

train_loader = DataLoader(train_dataset, batch_size=GPTConfig().batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=GPTConfig().batch_size, shuffle=False)

63714


### RMS norm, ROPE, MLA

In [22]:
from ohara.embedings_pos.rotatry import precompute_freqs_cis
from ohara.embedings_pos.rotatry import apply_rope
from ohara.modules.norm import RMSNorm

class MultiHeadLatentAttention(nn.Module):
    """
    Multi Head Latent Attention 
    paper: https://arxiv.org/pdf/2405.04434
    
    TLDR: 
    kv are low ranks, this verient of attention project q,k,v to low rank to save memory,
    replace linear with lora(ish) layers

    source: https://github.com/joey00072/Multi-Head-Latent-Attention-MLA-
    """
    def __init__(self, config):
        super().__init__()
        
        assert config.v_head_dim is not None , f"v_head_dim is not defined {config.v_head_dim=}"
        assert config.q_lora_rank is not None , f"q_lora_rank is not defined {config.q_lora_rank=}"
        assert config.kv_lora_rank is not None , f"kv_lora_rank is not defined {config.kv_lora_rank=}"
        assert config.rope_head_dim is not None , f"rope_head_dim is not defined {config.rope_head_dim=}"
        
        self.config = config
        
        self.dim = config.d_model
        self.num_heads = config.num_heads
        self.v_head_dim = config.v_head_dim
        
        self.nope_head_dim = config.nope_head_dim
        self.rope_head_dim = config.rope_head_dim
        
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank
        
        self.dropout = config.dropout
        
        # note: head dim of query and key if different from head dim of value
        
        # (attention_dim == num_head*head_dim) > d_model in deepseekv2
        # this is dim between wV and wQ
        self.value_dim = self.num_heads * self.v_head_dim
        
        # this is dims between wQ and wK
        self.nope_dim = self.num_heads * self.nope_head_dim
        self.rope_dim = self.num_heads * self.rope_head_dim  
        
        # query compression
        self.compress_q_linear = nn.Linear(self.dim, self.q_lora_rank, bias=False)  # W_DQ
        self.decompress_q_nope = nn.Linear(self.q_lora_rank, self.nope_dim, bias=False)
        self.decompress_q_rope = nn.Linear(self.q_lora_rank, self.rope_dim, bias=False)
        self.q_norm = RMSNorm(dim=self.q_lora_rank)
        
        
        # key and value compression
        self.compress_kv_linear = nn.Linear(self.dim, self.kv_lora_rank, bias=False)  # W_DKV
        self.decompress_k_nope = nn.Linear(self.kv_lora_rank, self.nope_dim, bias=False)
        self.decompress_v_linear = nn.Linear(self.kv_lora_rank, self.value_dim, bias=False)
        self.kv_norm = RMSNorm(dim=self.kv_lora_rank)
        
        
        self.k_rope_linear = nn.Linear(self.dim, self.rope_head_dim  , bias=False)
        # self.rope_norm = RMSNorm(self.rope_dim) # not in deepseekv2

        self.proj = nn.Linear(self.value_dim , self.dim, bias=False)
        self.res_dropout = nn.Dropout(p=config.dropout)
        self.freqs_cis = precompute_freqs_cis(config.rope_head_dim, config.seq_len)
        self.freqs_cis = (self.freqs_cis[0].to(device), self.freqs_cis[0].to(device))
        
        
        
    def forward(self, x: Tensor):        
        batch_size, seq_len, _ = x.shape
        # print(f'batch_size: {batch_size}')

        # 随机生成一个mask用于遮掩部分位置，避免计算某些位置的注意力
        mask = torch.tril(torch.ones(batch_size, seq_len, seq_len))  # 上三角为0，下三角为1，用于 causal mask


        compressed_q = self.compress_q_linear(x)
        norm_q = self.q_norm(compressed_q)
        query_nope: Tensor = self.decompress_q_nope(norm_q)
        query_rope: Tensor = self.decompress_q_rope(norm_q)

        compressed_kv = self.compress_kv_linear(x)
        norm_kv = self.kv_norm(compressed_kv)
        key_nope: Tensor = self.decompress_k_nope(norm_kv)
        value: Tensor = self.decompress_v_linear(norm_kv)

        key_rope: Tensor = self.k_rope_linear(x)

        query_nope = query_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1, 2)
        query_rope = query_rope.view(batch_size, seq_len, self.num_heads, self.rope_head_dim).transpose(1, 2)

        key_rope = key_rope.view(batch_size, seq_len, 1, self.rope_head_dim).transpose(1, 2)
        key_nope = key_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1, 2)

        value = value.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2)

        # *** the line that fixes MLA :) ***
        key_rope = key_rope / self.num_heads

        q_rope, k_rope = apply_rope(query_rope, key_rope, cis=self.freqs_cis)

        q_recombined = torch.empty((batch_size, self.num_heads, seq_len, self.rope_head_dim + self.nope_head_dim), device=x.device)
        k_recombined = torch.empty((batch_size, self.num_heads, seq_len, self.rope_head_dim + self.nope_head_dim), device=x.device)

        q_recombined[:, :, :, :self.nope_head_dim] = query_nope
        q_recombined[:, :, :, self.nope_head_dim:] = q_rope

        k_recombined[:, :, :, :self.nope_head_dim] = key_nope
        k_recombined[:, :, :, self.nope_head_dim:] = k_rope

        # Apply the mask here. Mask should be of shape (batch_size, 1, seq_len, seq_len) to match the attention logits
        if mask is not None:
            mask = mask.unsqueeze(1)  # Add the head dimension to the mask (batch_size, 1, seq_len, seq_len)
            # The mask should be applied to the attention logits (query_key matrix), with a large negative value (e.g., -1e9)
            # to prevent attention to those positions
            mask = mask.expand(-1, self.num_heads, -1, -1)
            mask = mask.to(x.device)
            # print(f'q_recombined: {q_recombined.shape}')
            # print(f'k_recombined: {k_recombined.shape}')
            # print(f'value: {value.shape}')
            # print(f'mask: {mask.shape}')
            output = F.scaled_dot_product_attention(
                q_recombined, k_recombined, value, attn_mask=mask, is_causal=True, dropout_p=self.dropout
            )
        else:
            # If no mask is provided, just perform the attention as usual
            output = F.scaled_dot_product_attention(
                q_recombined, k_recombined, value, is_causal=True, dropout_p=self.dropout
            )

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.v_head_dim)

        output = self.proj(output)
        output = self.res_dropout(output)
        return output


### 模型结构

In [23]:
class SingleHeadAttention(nn.Module):
    # 单头注意力机制
    def __init__(self, config):
        super().__init__()
        self.head_size = config.head_size
        self.key = nn.Linear(config.n_embd, config.head_size)
        self.value = nn.Linear(config.n_embd, config.head_size)
        self.query = nn.Linear(config.n_embd, config.head_size)

        # 尝试学习新的写法，attention_mask 通过 register_buffer 注册
        # 因为不用计算 梯度，所以节约内存和显存，速度也更快
        self.register_buffer(
            'attention_mask', 
            torch.tril(
                torch.ones(config.block_size, config.block_size)
            ))
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.size()
        k = self.key(x)
        v = self.value(x)
        q = self.query(x)
        weight = q @ k.transpose(-2, -1)   # @ 就是 torch.matmul 的简化写法
        weight = weight.masked_fill(
            self.attention_mask[:seq_len, :seq_len] == 0, 
            float('-inf')
        )
        weight = F.softmax(weight, dim=-1) / math.sqrt(self.head_size)
        weight = self.dropout(weight)
        out = weight @ v
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                SingleHeadAttention(config)
                for _ in range(config.n_head)
            ]
        )
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        output = torch.cat(
            [h(x) for h in self.heads], 
            dim=-1
        )
        output = self.proj(output)
        output = self.dropout(output)
        return output


class FeedForward(nn.Module):
    # 实际上就是 MLP
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout)
        )
    
    def forward(self, x):
        return self.net(x)


# 接下来就是一个完整的 Block

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        head_size = config.n_embd // config.n_head
        # self.att = MultiHeadAttention(config)
        self.att = MultiHeadLatentAttention(config)
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

    def forward(self, x):
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

from torch.utils.tensorboard import SummaryWriter
# 使用 TensorBoard 记录训练信息
writer = SummaryWriter(f'tensorboard_run_nn_models/frag_recomd_mini_QWEN_{current_time}')  # 创建 SummaryWriter 实例


# 以后会讲  MLA ,  MOE, DPO 完全手写
# 完整的  GPT model
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        # TODO position embedding --> ROPE
        # TODO layer norm --> RMS norm
        # TODO MLP --> swiglu
        # TODO MHA多头注意力 --> GPA
        self.block_size = GPTConfig().batch_size
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)
        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.n_layer)]
        )
        self.ln_final = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # linear (4 -> 8)； weight shape 是记上是 8 * 4，
        # 所以 embedding weight 和 lm_head weight 是共享的
        # 这里学习一下 tie weight。
        # 这是为了减少参数，加快训练；（现在 25的 SLM 很多都这样做了，注意⚠️）
        self.token_embedding_table.weight = self.lm_head.weight

        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # 这里使用的是正态分布初始化
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        # idx 是输入的 token ids
        batch, seq_len = idx.size()
        token_emb = self.token_embedding_table(idx)

        # seq 长度是这次输入的最大长度
        pos_emb = self.position_embedding_table(
            # 要确保 位置编码和输入的 idx 在同一个设备上
            torch.arange(seq_len, device=idx.device)
        )
        # 有一个经典题目：为什么 embedding 和 position 可以相加？
        x = token_emb + pos_emb   # shape is (batch, seq_len, n_embd)
        x = self.blocks(x)
        x = self.ln_final(x)
        logits = self.lm_head(x)   # shape is (batch, seq_len, vocab_size)
        
        if targets is None:
            loss = None
        else:
            batch, seq_len, vocab_size = logits.size()
            logits = logits.view(batch * seq_len, vocab_size)
            targets = targets.view(batch * seq_len)
            loss = F.cross_entropy(logits, targets)  # 这里是每个时间 T 步的平均loss, 不受batch大小影响
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # 如果序列太长，只取最后 block_size 个token
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # 获取预测
            logits, _ = self(idx_cond)
            # 只关注最后一个时间步的预测
            logits = logits[:, -1, :]  # becomes (B, vocab_size)
            # 应用softmax获取概率
            probs = F.softmax(logits, dim=-1)   # (B, vocab_size)
            # 采样下一个token
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # 附加到序列上
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx, probs

In [24]:
model = GPT(GPTConfig())
model = model.to(device)

In [25]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6} M")

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# 设置 cosine 学习率
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

Total parameters: 4.422784 M


In [26]:
# 训练循环
def train(model, optimizer, scheduler, train_loader, device):
    model.train()
    total_loss = 0
    for batch_idx, (x, y) in enumerate(train_loader):
        # 将数据移到设备上
        x, y = x.to(device), y.to(device)
        
        # 前向传播
        logits, loss = model(x, targets=y)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 调整学习率
        scheduler.step()
        
        total_loss += loss.item()
        
        
        if batch_idx % (len(train_loader)//10) == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.6f}')
    

    return total_loss

def eval(model, val_loader, device):
    # 验证
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits, loss = model(x, targets=y)
            val_loss += loss.item()
    return val_loss


for epoch in range(1000):
    train_loss = train(model, optimizer, scheduler, train_loader, device)
    writer.add_scalar('Average Loss/train', train_loss/len(train_loader), epoch)
    val_loss = eval(model, val_loader, device)
    writer.add_scalar('Average Loss/validate', val_loss/len(val_loader), epoch)
    print(f'Epoch: {epoch}, Average Train Loss: {train_loss/len(train_loader):.6f}, Average Val Loss: {val_loss/len(val_loader):.6f}')

    # 保存模型
    avg_val_loss = val_loss / len(val_loader)
    if epoch+1 >= 50 and (epoch + 1) % 20 == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': avg_val_loss,
        }
        # 保存每个epoch的模型
        os.makedirs(f'./model_save_pth/mini_QWEN_frag_recomd_{current_time}', exist_ok=True)
        torch.save(checkpoint, f'./model_save_pth/mini_QWEN_frag_recomd_{current_time}/model_epoch_{epoch+1}.pt')


    

Epoch: 0, Batch: 0, Loss: 4.277499
Epoch: 0, Batch: 11, Loss: 2.549685
Epoch: 0, Batch: 22, Loss: 2.406250
Epoch: 0, Batch: 33, Loss: 2.355066
Epoch: 0, Batch: 44, Loss: 2.351576


KeyboardInterrupt: 

## 重载 .pth 模型

In [62]:
import torch

# 定义模型、优化器、调度器等
model = GPT(GPTConfig())
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

current_time = '2025-01-27__00-12-04'
current_time = '2025-02-02__11-16-09'   
epoch = '1000'
epoch = '100'

# 设定要加载的模型路径
checkpoint_path = f'./model_save_pth/mini_QWEN_frag_recomd_{current_time}/model_epoch_{epoch}.pt'

# 加载模型
checkpoint = torch.load(checkpoint_path, weights_only=False)

# 加载状态字典
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

# 恢复训练时的epoch和val_loss
epoch = checkpoint['epoch']
val_loss = checkpoint['val_loss']

# 如果你想恢复训练过程，可以继续从这里开始。
x = torch.tensor([[1]]).to(device)
y, probs = model.generate(x, 1)   # (B, T+1)   (B, vocab_size)
print(f'y.shape: {y.shape}')
print(f'y: {y}')
y_list = y.squeeze(0).cpu().numpy().tolist()
y_frag_list = [idx_to_fragment[y] for y in y_list]
print(elem_frag_smi_list)
print(f'y_frag_list: {y_frag_list}')
print(f'probs.shape: {probs.shape}')
y = y[:, -1].squeeze(0)
probs = probs.squeeze(0)
print(f'probs.shape: {probs.shape}')
print(f'probs: {probs}')
sorted_prob_values, sorted_prob_indices = torch.sort(probs, descending=True)
sorted_prob_values = sorted_prob_values.tolist()
sorted_prob_indices = sorted_prob_indices.tolist()
print(f'sorted_prob_values: {sorted_prob_values}')
print(f'sorted_prob_indices: {sorted_prob_indices}')
result_smi_list = [idx_to_fragment[id] for id in sorted_prob_indices]
print(f'result_smi_list: {result_smi_list}')








y.shape: torch.Size([1, 2])
y: tensor([[ 1, 64]], device='cuda:0')
['N', 'c1ccccc1', 'O', 'C', 'CC', 'C=O', 'c1ccncc1', 'Clc1ccccc1', 'C1CCNCC1', 'Fc1ccccc1', 'O=P(O)(O)O', 'FC(F)F', 'CC(C)C', 'NC=O', 'C1CCNC1', 'OC1COCC1O', 'C1CNCCN1', 'C1CC1', 'S', 'C1CCCCC1', 'c1cncnc1', 'C1COCCN1', 'OC1COCC(O)C1O', 'c1cn[nH]c1', 'c1c[nH]cn1', 'c1ccc2[nH]ccc2c1', 'c1ccsc1', 'c1ccc2ccccc2c1', 'O=[SH](=O)c1ccccc1', 'Nc1ncnc2[nH]cnc12', 'Clc1cccc(Cl)c1', 'O=[N+]([O-])c1ccccc1', 'OC1CCOC1', 'C1CCCC1', 'O=P(O[3*])(O)OP(=O)(O)O[3*]', 'Fc1cccc(F)c1', 'c1cscn1', 'c1ccc2[nH]cnc2c1', 'Oc1ccccc1O', 'c1ccc2ncccc2c1', 'O=c1cc[nH]c(=O)[nH]1', 'C[SH](=O)=O', 'Brc1ccccc1', 'N#Cc1ccccc1', 'O=C1CCCN1', 'C=CC', 'c1ccoc1', 'N=C(N)c1ccccc1', 'c1ncc2nc[nH]c2n1', 'C[NH+](C)C', 'c1ccc2ncncc2c1', 'N=C(N)N', 'Nc1cc[nH]c(=O)n1', 'C1CCOC1', 'Clc1ccccc1Cl', 'c1ccc2c(c1)OCO2', 'CC1(C)CN2C(=O)CC2S1', 'c1ccc2[nH]ncc2c1', 'CC1=C(C(=O)O)N2C(=O)CC2SC1', 'C1CCOCC1', 'c1nc[nH]n1', 'NS(=O)(=O)O', 'OC1CCOCC1O', 'c1ccc2c(c1)Nc1ccccc1S2', 