In [1]:
import sys

%pip install polars
import os

# Location of the COATI checkout that contains the `coati` package.
# Set the COATI_PARENT_DIR environment variable to point somewhere else;
# when it is unset the original hard-coded path is used unchanged.
parent_dir = os.environ.get(
    "COATI_PARENT_DIR",
    '/Users/stefanhangler/Documents/Uni/Msc_AI/3_Semester/Seminar_Practical Work/Code.nosync/COATI',
)

# Add the parent directory to sys.path so `import coati...` resolves.
# Guarded so that re-running this cell does not append duplicate entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import os
import pickle # Stefan
from torch.utils.data.datapipes.iter import IterableWrapper # Stefan

from coati.common.util import dir_or_file_exists

Note: you may need to restart the kernel to use updated packages.


In [2]:
class COATI_dataset:
    """Serve a locally cached, pre-split pickle of rows as a torch datapipe.

    Rows are loaded from ``train_valid_test_guacamol.pkl`` inside
    ``cache_dir`` and are assigned to a split by their ``'partition'`` entry.
    The earlier pipeline (S3 download, ``FileLister``, ``ur_batcher`` and a
    split derived from ``row["mod_molecule"] % 100``) has been replaced by
    this pickle-based loading.
    """

    def __init__(
        self,
        cache_dir,
        fields=("smiles", "atoms", "coords"),
        test_split_mode="row",
        test_frac=0.02,  # fraction, not percent: 0.02 means 2 %.
        valid_frac=0.02,  # fraction, not percent: 0.02 means 2 %.
    ):
        """Store the configuration and range-check the split fractions.

        Args:
            cache_dir: Directory expected to contain the pickle file.
            fields: Names of the row fields of interest. Copied into a new
                list so instances never share (or mutate) a default object.
            test_split_mode: Stored as-is; not read by the methods below.
            test_frac: Test fraction. Stored and validated only; the split
                itself comes from each row's ``'partition'`` entry.
            valid_frac: Validation fraction; same remarks as ``test_frac``.

        Raises:
            AssertionError: If a fraction is outside the accepted range or
                the two fractions together reach 50 %.
        """
        # Private copy: avoids the shared-mutable-default pitfall and
        # aliasing with a list owned by the caller.
        fields = list(fields)

        self.cache_dir = cache_dir
        self.summary = {"dataset_type": "coati", "fields": fields}
        self.test_frac = test_frac
        self.fields = fields
        self.valid_frac = valid_frac
        assert 0 <= int(test_frac * 100) <= 50, "test_frac must be within [0, 0.5]"
        assert 0 <= int(valid_frac * 100) <= 50, "valid_frac must be within [0, 0.5]"
        assert (
            int(valid_frac * 100 + test_frac * 100) < 50
        ), "test_frac + valid_frac must stay below 0.5"
        self.test_split_mode = test_split_mode

    def partition_routine(self, row):
        """Return the split label of ``row`` wrapped in a one-element list.

        The label is read from ``row['partition']``; rows without that key
        are labelled ``'raw'``.
        """
        # Stefan --->
        partition = row.get('partition', 'raw')
        return [partition]

    def _rows_in_partition(self, rows, partition):
        """Select the rows that belong to ``partition``.

        ``"raw"`` denotes the unsplit dataset, so it selects every row. Any
        other value selects the rows whose label, as reported by
        ``partition_routine``, equals ``partition``.
        """
        if partition == "raw":
            # Previously the default "raw" was compared against each row's
            # label, which yields an empty pipe for a file whose rows are all
            # labelled train/valid/test.
            return list(rows)
        return [row for row in rows if partition in self.partition_routine(row)]

    def get_data_pipe(
        self,
        rebuild=False,
        batch_size=32,
        partition: str = "raw",
        required_fields=(),
        distributed_rankmod_total=None,
        distributed_rankmod_rank=1,
        xform_routine=lambda X: X,
    ):
        """Build a shuffled, batched datapipe over one partition of the cache.

        Args:
            rebuild: Accepted for signature compatibility; not read here.
            batch_size: Number of rows per emitted batch.
            partition: ``"raw"`` for every row, otherwise the label to keep
                (matched against each row's ``'partition'`` entry).
            required_fields: Accepted for signature compatibility; not read.
            distributed_rankmod_total: Accepted for signature compatibility;
                not read. NOTE(review): no per-rank sharding happens here, so
                every worker iterates the same rows - confirm this is wanted.
            distributed_rankmod_rank: Accepted for signature compatibility;
                not read.
            xform_routine: Callable applied to each batch.

        Returns:
            A datapipe yielding ``xform_routine(batch)`` for each batch.

        Raises:
            FileNotFoundError: If the pickle file is not found in
                ``cache_dir``.
        """
        print(f"trying to open a {partition} datapipe for...")

        # Path to the preprocessed pickle file
        pickle_file = os.path.join(self.cache_dir, "train_valid_test_guacamol.pkl")

        if not dir_or_file_exists(pickle_file):
            raise FileNotFoundError(f"{pickle_file} does not exist. Please ensure the file is in the correct location.")

        # SECURITY: unpickling can execute arbitrary code. Only load files
        # that were produced by a trusted preprocessing step.
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)

        # Keep only the rows of the requested partition
        filtered_data = self._rows_in_partition(data, partition)

        # Wrap, then shuffle, batch, and transform the data
        pipe = IterableWrapper(filtered_data)
        pipe = pipe.shuffle().batch(batch_size).map(xform_routine)

        return pipe


In [3]:
import torch.multiprocessing as mp
%pip install boto3
%pip install rdkit
from coati.training.train_coati import train_autoencoder, do_args
import os
import inspect

from coati.data.dataset import COATI_dataset
import functools

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [4]:
# Start from the defaults returned by do_args(), then apply the overrides
# below in the listed order.
args = do_args()

overrides = {
    # --- node layout ---
    "nodes": 1,  # total num nodes.
    "nr": 0,  # rank of this node.
    # note args.gpus will default to the # gpus on this node.
    "data_parallel": True,
    # --- split fractions ---
    "test_frac": 0.02,
    "valid_frac": 0.0,
    # --- model shape ---
    "n_layer_e3gnn": 5,
    "n_hidden_e3nn": 256,
    "msg_cutoff_e3nn": 12.0,
    "n_hidden_xformer": 256,
    "n_embd_common": 256,
    "n_layer_xformer": 16,
    "n_head": 16,
    "max_n_seq": 250,  # max the model can forward
    # "n_seq": 90,  # max allowed in training.
    "n_seq": 80,  # max allowed in training.
    "biases": True,
    "torch_emb": False,
    "norm_clips": True,
    "norm_embed": False,
    "token_mlp": True,
    # --- tokenizer and sampling probabilities ---
    "tokenizer_vocab": "mar",
    "p_dataset": 0.2,
    "p_formula": 0.0,
    "p_fim": 0.0,
    "p_graph": 0.0,
    "p_clip": 0.9,
    "p_clip_emb_smi": 0.5,
    "p_randsmiles": 0.3,
    "batch_size": 160,
    # --- optimisation ---
    "online": False,  # Possible offline training of an end-to-end clip
    "lr": 5.0e-4,
    "weight_decay": 0.1,
    "dtype": "float",
    "n_epochs": 25,
    "clip_grad": 10,
    "test_interval": 2,
    "debug": False,
    "resume_optimizer": False,
    # resume from checkpoint file
    # "resume_document": '',
    "ngrad_to_save": 2e6,
    # --- locations ---
    "output_dir": "./logs/",  # output logs
    "model_dir": "./model_ckpts/",  # where to save model checkpoints
    "data_dir": "./",  # where to save dataset cache
    "model_filename": "coati_grande",
}
for option, setting in overrides.items():
    setattr(args, option, setting)

# NOTE(review): an earlier cell re-imports COATI_dataset from
# coati.data.dataset, which rebinds the name of the class defined in this
# notebook - confirm which implementation this call is meant to exercise.
# Kept as the final expression so the notebook displays the resulting pipe.
COATI_dataset(cache_dir=args.data_dir).get_data_pipe()

new shuffle csv pipeline


UrBatcher