In [8]:
import pandas as pd
import numpy as np
import random
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from rdkit import Chem
import sys
from deepchem.feat.smiles_tokenizer import SmilesTokenizer
from minGPT.pipeline import minGPT

## Data preprocessing

In [None]:
# Build the pipeline object that every later cell reuses
# (model loading, training, generation and evaluation all go through it).
pipeline = minGPT()

# Start from the pipeline's default data configuration and override
# only the fields this notebook cares about.
data_config = pipeline.get_default_data_config()
# Path is relative to the notebook's working directory.
data_config.file_path = "minGPT/htp_md.csv"
# NOTE(review): presumably the maximum sequence length in tokens — confirm
# against minGPT.pipeline. The model cell reads the block size back from the dataset.
data_config.block_size = 64

# Show the effective data configuration (defaults plus the overrides above).
print(data_config)
# Produces the train/test datasets consumed by the model-configuration cell.
train_dataset, test_dataset = pipeline.data_preprocessing(data_config)

In [10]:
## Model initialization

In [None]:
# Model: start from the pipeline's default model configuration.
model_config = pipeline.get_default_model_config()
model_config.model_type = 'gpt-nano'
# Take vocabulary and block size from the training dataset rather than
# hard-coding them, so the model always matches the preprocessed data.
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
# Hand the finished configuration to the pipeline; the return value is not used here.
pipeline.load_model(model_config)

## Training configuration

In [None]:
# Train: fetch the default training configuration and print it *before*
# applying the overrides below, so the output shows the untouched defaults.
train_config = pipeline.get_default_train_config()
print("--------Training configuration--------")
print(train_config)


# Show which device the default configuration selected.
print(train_config.device)
# Overrides applied after the printout above.
train_config.max_iters = 10000
# NOTE(review): presumably the directory checkpoints are written to — confirm
# against minGPT.pipeline.
train_config.ckpt_path = "./minGPT/ckpts/"
# Uncomment the following line if load from pre-trained model chkpts
# NOTE(review): this example path points at ./ckpts/, while ckpt_path above and
# the generation cell both use ./minGPT/ckpts/ — confirm which location is intended.
# train_config.pretrain = "./ckpts/10000.pt"

## Define callback function
def batch_end_callback(trainer):
    """Report timing and losses once every 100 training iterations.

    Args:
        trainer: Object exposing ``iter_num``, ``iter_dt``, ``loss`` and
            ``loss_val``. ``loss`` and ``loss_val`` must provide ``.item()``
            returning a number (presumably torch scalar tensors — confirm).

    Returns:
        None. When ``trainer.iter_num`` is a multiple of 100, a single
        progress line is printed to stdout; otherwise nothing happens.
    """
    step = trainer.iter_num
    # Only report on every 100th iteration to keep the cell output short.
    if step % 100 != 0:
        return

    # iter_dt is scaled by 1000 and labelled "ms" in the message.
    elapsed_ms = trainer.iter_dt * 1000
    train_loss = trainer.loss.item()
    val_loss = trainer.loss_val.item()
    print(f"iter_dt {elapsed_ms:.2f}ms; iter {step}: train loss {train_loss:.5f}, val loss {val_loss:.5f}")

# Attach the progress-reporting callback to the training configuration.
# NOTE(review): presumably the trainer invokes it at the end of every batch — confirm.
train_config.call_back = batch_end_callback
# Uncomment the following line to start training
# NOTE(review): while this stays commented out, a fresh "Run All" does no training,
# so the generation cell depends on a checkpoint file that already exists on disk.
# loss = pipeline.train(train_config)

## Generating with the model

In [None]:
# Start from the pipeline's default generation configuration.
generate_config = pipeline.get_default_generate_config()
# Checkpoint to generate from. The file must already exist; training in this
# notebook is disabled by default.
# NOTE(review): the file name suggests the checkpoint saved at iteration 10000 — confirm.
generate_config.ckpt_path = "./minGPT/ckpts/10000.pt"
# Guard against generating for a different task than the data was prepared for.
assert generate_config.task == data_config.task
print(generate_config)

# Keep the generated output in `results`; it is not used again in the visible cells.
results = pipeline.generate(generate_config)

## Evaluate model
Calculate the scores for
**uniqueness, novelty, validity, synthesizability, similarity, and diversity**, reported in that order.

In [None]:
# Print whatever pipeline.evaluate() returns.
# NOTE(review): evaluate() takes no arguments here, so it presumably scores the output
# of the preceding generate() call kept on the pipeline — confirm.
print(pipeline.evaluate())