<a href="https://colab.research.google.com/github/Lily-2002/Retrosynthetic_Planning/blob/main/train_mT5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
!pip install rdkit
!pip install sentencepiece



In [14]:
from rdkit.Chem import AllChem
import numpy as np
from rdkit import Chem
from torch.utils.data import Dataset, DataLoader
import torch
import pandas as pd
from tqdm import tqdm,trange
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, MT5ForConditionalGeneration,MT5ForSequenceClassification,MT5Config

In [15]:
from google.colab import drive
drive.mount('/content/gdrive')


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [16]:
def top_k_acc(preds, gt,k=1):
    # preds = preds.to(torch.device('cpu'))
    probs, idx = torch.topk(preds, k=k)
    idx = idx.cpu().numpy().tolist()# idx前k个最大的值
    gt = gt.cpu().numpy().tolist()
    num = preds.size(0)
    correct = 0
    for i in range(num):
        for id in idx[i]:
            if id == gt[i]:
                correct += 1
    return correct, num

In [17]:
def tokenizers(X):
  from transformers import MT5Tokenizer, MT5ForSequenceClassification
  tokenizer = MT5Tokenizer.from_pretrained("/content/gdrive/MyDrive/mT5-small")
  inputs = tokenizer(X,padding=True,truncation=True,max_length=256,return_tensors="pt")
  return inputs

In [18]:
class OnestepDataset(Dataset):
    def __init__(self, X, y):
        super(OnestepDataset, self).__init__()
        self.X = X
        self.y = y
        # self.fp_dim = fp_dim
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return self.X[idx],self.y[idx]

In [19]:
def dataset_iterator(X,y,
          batch_size=1024,
          shuffle=True
          ):
    dataset = OnestepDataset(X,y)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    import torch
    import torch.nn.functional as F

    def collate_fn(batch):
        X, y = zip(*batch)
        from transformers import MT5Tokenizer, MT5ForSequenceClassification
        tokenizer = MT5Tokenizer.from_pretrained("/content/gdrive/MyDrive/mT5-small")
        inputs = tokenizer(X,padding=True,truncation=True,max_length=256,return_tensors="pt")
        modified_y = [label - 1 for label in y]
        return inputs, torch.tensor(modified_y)


    return DataLoader(train_dataset,
              batch_size=batch_size,
              shuffle=shuffle,
              collate_fn = collate_fn
              ),DataLoader(val_dataset,
                    batch_size = batch_size,
                    shuffle= shuffle,
                    collate_fn = collate_fn)

In [20]:
def load_csv(path):
    X, y = [], []
    df = pd.read_csv(path)
    num = len(df)
    rnx_smiles = list(df['reactions'])
    tnx_class = list(df['class'])
    del df
    for i in tqdm(range(num)):
        rxn = rnx_smiles[i]
        product = rxn.strip().split('>')[-1]
        X.append(product)
        y.append(tnx_class[i])
    return X, y

In [21]:
def train_one_epoch(model, train_loader,
          optimizer,
          device,
          loss_fn,
          it):
    losses = []
    model.train()
    # print(train_loader)
    for X_batch, y_batch in tqdm(train_loader):
        y_batch = y_batch.to(device).long()
        optimizer.zero_grad()
        outputs = model(**X_batch, labels=y_batch)
        print(y_batch)
        loss = outputs.loss
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
        optimizer.step()
        losses.append(loss.item())
        it.set_postfix(loss=np.mean(losses[-10:]) if losses else None)
    return losses

In [22]:
def eval_one_epoch(model, val_loader,device):
    model.eval()
    eval_top1_correct, eval_top1_num = 0, 0
    eval_top10_correct, eval_top10_num = 0, 0
    eval_top50_correct, eval_top50_num = 0, 0
    loss = 0.0
    for X_batch, y_batch in tqdm(val_loader):
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        print(y_batch)
        with torch.no_grad():
            y_hat = model(**X_batch)
            loss += F.cross_entropy(y_hat,y_batch).item()
            top_1_correct, num1 = top_k_acc(y_hat, y_batch, k=1)
            top_3_correct, num10 = top_k_acc(y_hat, y_batch, k=3)
            top_5_correct, num50 = top_k_acc(y_hat, y_batch, k=5)
            eval_top1_correct += top_1_correct
            eval_top1_num += num1
            eval_top10_correct += top_10_correct
            eval_top10_num += num10
            eval_top50_correct += top_50_correct
            eval_top50_num += num50
    val_1 = eval_top1_correct/eval_top1_num
    val_10 = eval_top10_correct/eval_top10_num
    val_50 = eval_top50_correct/eval_top50_num
    loss = loss / (len(val_loader.dataset))
    return val_1, val_10, val_50, loss

In [23]:
def train_mT5(model,data,
          loss_fn = nn.CrossEntropyLoss(),
          lr = 1e-4,
          batch_size=16,
          epochs=5,
          wd=0,
          saved_model='../model/saved_states'):
    it = trange(epochs)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = optim.Adam(model.parameters(),lr=lr,weight_decay=wd)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=1e-6)
    X,y = data
    train_loader,val_loader  = dataset_iterator(X,y,batch_size=batch_size)
    best = -1
    for e in it:
        # Iterate batches
        train_one_epoch(model,train_loader,optimizer,device,loss_fn,it)
        ## Do validation after one epoch training.
        val_1,val_10, val_50, loss= eval_one_epoch(model,val_loader,device)
        scheduler.step(loss)
        if best < val_1:
            best = val_1
            state = model.state_dict()
            torch.save(state,saved_model)
        print("\nTop 1: {}  ==> Top 10: {} ==> Top 50: {}, validation loss ==> {}".format(val_1, val_10, val_50, loss))

In [None]:
if __name__ == '__main__':
    X_train,y_train = load_csv("/content/gdrive/MyDrive/USPTO_50K.csv")
    config = MT5Config.from_pretrained("/content/gdrive/MyDrive/mT5-small")
    config.problem_type = "single_label_classification"  # 设置 problem_type
    config.num_labels = 10
    # 使用配置创建 MT5ForSequenceClassification 的实例
    model = MT5ForSequenceClassification(config)
    data = (X_train,y_train)
    train_mT5(model, data, lr=1e-4, batch_size=16, epochs=5, wd=0, saved_model='/content/gdrive/MyDrive/sft_mT5')

100%|██████████| 50016/50016 [00:00<00:00, 1056029.18it/s]
  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/2501 [00:00<?, ?it/s][A

tensor([1, 8, 5, 2, 0, 5, 1, 5, 1, 6, 1, 8, 0, 1, 0, 0])


  0%|          | 0/5 [00:23<?, ?it/s, loss=2.52]
  0%|          | 1/2501 [00:23<16:21:09, 23.55s/it][A

tensor([5, 1, 6, 1, 1, 1, 0, 1, 5, 1, 0, 0, 0, 8, 1, 5])


  0%|          | 0/5 [00:39<?, ?it/s, loss=2.22]
  0%|          | 2/2501 [00:39<13:08:52, 18.94s/it][A

tensor([1, 2, 0, 5, 0, 2, 4, 1, 0, 1, 1, 6, 0, 6, 0, 2])


  0%|          | 0/5 [01:04<?, ?it/s, loss=2.21]
  0%|          | 3/2501 [01:04<15:06:00, 21.76s/it][A

tensor([5, 1, 1, 5, 1, 2, 1, 0, 2, 1, 2, 7, 1, 4, 0, 2])


  0%|          | 0/5 [01:20<?, ?it/s, loss=2.23]
  0%|          | 4/2501 [01:20<13:36:13, 19.61s/it][A

tensor([2, 2, 0, 5, 0, 2, 0, 1, 5, 6, 2, 8, 2, 5, 1, 6])


  0%|          | 0/5 [01:41<?, ?it/s, loss=2.16]
  0%|          | 5/2501 [01:41<13:54:17, 20.05s/it][A

tensor([5, 8, 2, 0, 7, 1, 0, 2, 2, 0, 2, 6, 2, 0, 5, 5])


  0%|          | 0/5 [01:50<?, ?it/s, loss=2.15]
  0%|          | 6/2501 [01:50<11:15:29, 16.24s/it][A

tensor([1, 1, 0, 5, 6, 5, 8, 5, 0, 5, 1, 5, 1, 8, 0, 0])


  0%|          | 0/5 [02:02<?, ?it/s, loss=2.25]
  0%|          | 7/2501 [02:02<10:22:20, 14.97s/it][A

tensor([5, 1, 0, 1, 1, 1, 1, 5, 1, 0, 6, 0, 5, 8, 5, 0])


  0%|          | 0/5 [02:14<?, ?it/s, loss=2.3] 
  0%|          | 8/2501 [02:14<9:38:48, 13.93s/it] [A

tensor([1, 2, 5, 6, 0, 0, 0, 6, 0, 5, 2, 5, 2, 5, 0, 5])


  0%|          | 0/5 [02:27<?, ?it/s, loss=2.22]
  0%|          | 9/2501 [02:27<9:25:30, 13.62s/it][A

tensor([1, 0, 0, 2, 2, 0, 1, 7, 1, 8, 1, 6, 1, 4, 1, 8])


  0%|          | 0/5 [02:39<?, ?it/s, loss=2.25]
  0%|          | 10/2501 [02:38<8:59:45, 13.00s/it][A

tensor([2, 5, 0, 9, 0, 7, 6, 1, 6, 5, 8, 5, 6, 5, 1, 7])


  0%|          | 0/5 [02:49<?, ?it/s, loss=2.25]
  0%|          | 11/2501 [02:49<8:28:21, 12.25s/it][A

tensor([2, 0, 1, 0, 6, 1, 8, 8, 7, 1, 0, 8, 5, 5, 0, 1])


  0%|          | 0/5 [02:59<?, ?it/s, loss=2.24]
  0%|          | 12/2501 [02:59<8:03:14, 11.65s/it][A

tensor([0, 5, 7, 5, 4, 0, 1, 1, 0, 1, 1, 2, 0, 0, 5, 0])


  0%|          | 0/5 [03:11<?, ?it/s, loss=2.2] 
  1%|          | 13/2501 [03:11<8:03:22, 11.66s/it][A

tensor([2, 0, 0, 6, 7, 1, 2, 0, 1, 5, 3, 2, 0, 1, 0, 5])


  0%|          | 0/5 [03:23<?, ?it/s, loss=2.19]
  1%|          | 14/2501 [03:23<8:13:35, 11.91s/it][A

tensor([1, 1, 5, 5, 0, 6, 5, 0, 0, 0, 4, 6, 0, 2, 1, 0])


  0%|          | 0/5 [03:35<?, ?it/s, loss=2.2] 
  1%|          | 15/2501 [03:35<8:14:28, 11.93s/it][A

tensor([5, 2, 0, 5, 6, 0, 0, 1, 0, 5, 0, 0, 5, 8, 1, 6])


  0%|          | 0/5 [03:48<?, ?it/s, loss=2.16]
  1%|          | 16/2501 [03:48<8:19:56, 12.07s/it][A

tensor([2, 0, 0, 0, 4, 1, 0, 5, 5, 5, 5, 2, 0, 0, 0, 0])


  0%|          | 0/5 [04:01<?, ?it/s, loss=2.04]
  1%|          | 17/2501 [04:01<8:35:09, 12.44s/it][A

tensor([0, 5, 1, 1, 2, 0, 5, 0, 5, 5, 0, 0, 0, 2, 5, 5])


  0%|          | 0/5 [04:12<?, ?it/s, loss=1.93]
  1%|          | 18/2501 [04:12<8:13:17, 11.92s/it][A

tensor([5, 0, 5, 5, 0, 1, 1, 6, 1, 1, 0, 5, 6, 5, 8, 2])


  0%|          | 0/5 [04:23<?, ?it/s, loss=1.97]
  1%|          | 19/2501 [04:23<8:01:31, 11.64s/it][A

tensor([5, 5, 6, 0, 4, 7, 5, 1, 0, 5, 1, 1, 1, 5, 1, 0])


  0%|          | 0/5 [04:35<?, ?it/s, loss=1.91]
  1%|          | 20/2501 [04:35<8:13:09, 11.93s/it][A

tensor([1, 8, 1, 0, 2, 1, 1, 0, 0, 6, 0, 0, 5, 2, 2, 0])


  0%|          | 0/5 [04:46<?, ?it/s, loss=1.86]
  1%|          | 21/2501 [04:46<8:01:40, 11.65s/it][A

tensor([1, 1, 0, 0, 6, 6, 0, 4, 6, 1, 2, 2, 4, 1, 2, 0])


  0%|          | 0/5 [04:58<?, ?it/s, loss=1.91]
  1%|          | 22/2501 [04:58<8:03:23, 11.70s/it][A

tensor([0, 0, 8, 0, 2, 0, 1, 0, 8, 2, 0, 0, 6, 5, 2, 6])


  0%|          | 0/5 [05:08<?, ?it/s, loss=1.98]
  1%|          | 23/2501 [05:08<7:40:15, 11.14s/it][A

tensor([1, 0, 0, 1, 0, 0, 6, 5, 1, 0, 0, 5, 5, 1, 8, 0])


  0%|          | 0/5 [05:19<?, ?it/s, loss=1.94]
  1%|          | 24/2501 [05:19<7:31:01, 10.92s/it][A

tensor([1, 1, 0, 1, 0, 5, 0, 0, 1, 0, 1, 2, 0, 2, 1, 2])


  0%|          | 0/5 [05:30<?, ?it/s, loss=1.89]
  1%|          | 25/2501 [05:30<7:40:35, 11.16s/it][A

tensor([5, 0, 0, 1, 5, 1, 0, 0, 1, 0, 0, 1, 1, 5, 4, 5])


  0%|          | 0/5 [05:42<?, ?it/s, loss=1.87]
  1%|          | 26/2501 [05:42<7:42:56, 11.22s/it][A

tensor([6, 1, 5, 1, 5, 5, 5, 5, 0, 5, 5, 8, 4, 4, 5, 0])


  0%|          | 0/5 [05:51<?, ?it/s, loss=1.98]
  1%|          | 27/2501 [05:51<7:24:57, 10.79s/it][A

tensor([2, 2, 2, 1, 1, 6, 5, 5, 5, 6, 2, 2, 0, 0, 6, 0])


  0%|          | 0/5 [06:04<?, ?it/s, loss=2.06]
  1%|          | 28/2501 [06:04<7:44:34, 11.27s/it][A

tensor([6, 2, 5, 0, 5, 2, 2, 0, 5, 8, 0, 5, 2, 0, 2, 0])


  0%|          | 0/5 [06:16<?, ?it/s, loss=2.07]
  1%|          | 29/2501 [06:16<8:01:12, 11.68s/it][A

tensor([0, 8, 0, 0, 0, 0, 5, 1, 0, 2, 5, 0, 6, 1, 6, 5])


  0%|          | 0/5 [06:29<?, ?it/s, loss=2.03]
  1%|          | 30/2501 [06:29<8:13:45, 11.99s/it][A

tensor([1, 5, 6, 5, 1, 0, 2, 4, 0, 5, 6, 1, 5, 0, 0, 5])


  0%|          | 0/5 [06:44<?, ?it/s, loss=2.02]
  1%|          | 31/2501 [06:44<8:54:30, 12.98s/it][A

tensor([6, 1, 0, 7, 1, 7, 1, 2, 5, 0, 8, 5, 1, 6, 1, 8])


  0%|          | 0/5 [06:56<?, ?it/s, loss=2.01]
  1%|▏         | 32/2501 [06:56<8:38:59, 12.61s/it][A

tensor([5, 6, 1, 7, 0, 0, 1, 6, 5, 5, 5, 0, 0, 0, 1, 1])


  0%|          | 0/5 [07:09<?, ?it/s, loss=1.93]
  1%|▏         | 33/2501 [07:09<8:38:29, 12.61s/it][A

tensor([5, 1, 3, 0, 1, 0, 2, 2, 2, 2, 0, 6, 0, 1, 1, 1])


  0%|          | 0/5 [07:20<?, ?it/s, loss=1.96]
  1%|▏         | 34/2501 [07:20<8:17:28, 12.10s/it][A

tensor([0, 8, 5, 5, 2, 5, 5, 6, 3, 6, 5, 2, 6, 0, 5, 1])


  0%|          | 0/5 [07:29<?, ?it/s, loss=2]   
  1%|▏         | 35/2501 [07:29<7:48:45, 11.41s/it][A

tensor([5, 0, 5, 0, 6, 1, 1, 1, 6, 1, 1, 2, 1, 1, 6, 3])


  0%|          | 0/5 [07:42<?, ?it/s, loss=2.04]
  1%|▏         | 36/2501 [07:42<8:04:15, 11.79s/it][A

tensor([0, 1, 6, 0, 5, 5, 8, 0, 0, 2, 4, 0, 6, 7, 5, 0])


  0%|          | 0/5 [07:53<?, ?it/s, loss=2]   
  1%|▏         | 37/2501 [07:53<7:56:44, 11.61s/it][A

tensor([0, 1, 1, 1, 5, 0, 5, 8, 0, 0, 0, 2, 0, 0, 0, 0])


  0%|          | 0/5 [08:07<?, ?it/s, loss=1.97]
  2%|▏         | 38/2501 [08:07<8:16:58, 12.11s/it][A

tensor([0, 1, 1, 5, 0, 5, 0, 0, 0, 1, 0, 0, 5, 0, 2, 1])
