In [1]:
import gc
import os
import time
from itertools import zip_longest

import numpy as np
import pandas as pd
import psutil
import torch
from chemprop import data, featurizers
from sklearn.preprocessing import StandardScaler
from torch.utils.data import IterableDataset

# **Introduction**

This notebook shows how to use `torch.utils.data.IterableDataset` to load and featurize a dataset sequentially, one small chunk at a time, instead of all at once.

**Context:** I want to train a ChemProp model on a dataset of 1 million compounds. While this is not an excessively large dataset, my MacBook M1 with 8GB of RAM struggles to convert the entire CSV file into `MoleculeDatapoint` objects. Loading the CSV file itself works fine; it is the `MoleculeDatapoint` objects that the system struggles with. I was therefore looking for an alternative approach that loads small subsets of the CSV file sequentially, generates the `MoleculeDatapoint` objects for each subset, wraps them in a Dataset and DataLoader, and finally trains the model. One of the challenges is ensuring that the data is reshuffled after each training epoch. `torch.utils.data.IterableDataset` turned out to be a useful class for these needs.

I started by creating some useful functions to prepare the Chemprop dataset, as outlined in the tutorial.

In [2]:
def datapoint_preparator(df,smiles_column,target_column):
    """Build one chemprop datapoint per row of ``df``.

    Parameters
    ----------
    df (pd.DataFrame): Table holding the molecules and their target values.
    smiles_column (str): Name of the column containing the SMILES strings.
    target_column (str): Name of the column containing the target values.

    Returns
    -------
    list: The objects returned by ``data.MoleculeDatapoint.from_smi``, one per
        row and in row order.
    """
    smiles_values = df.loc[:, smiles_column].values
    # Selecting the target with a list keeps the array two-dimensional, so every
    # molecule receives a length-1 row of targets rather than a bare scalar.
    target_rows = df.loc[:, [target_column]].values

    return list(map(data.MoleculeDatapoint.from_smi, smiles_values, target_rows))


def dataset_preparator(df, smiles_column, target_column, featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()):
    """Turn the rows of ``df`` into a chemprop ``MoleculeDataset``.

    Parameters
    ----------
    df (pd.DataFrame): Table holding the molecules and their target values.
    smiles_column (str): Name of the column containing the SMILES strings.
    target_column (str): Name of the column containing the target values.
    featurizer: Featurizer handed to ``data.MoleculeDataset``. The default
        instance is created once, when this function is defined, and is reused
        by every call that does not pass its own featurizer.

    Returns
    -------
    The ``data.MoleculeDataset`` wrapping the datapoints built by
    ``datapoint_preparator``.
    """
    return data.MoleculeDataset(
        datapoint_preparator(df, smiles_column, target_column),
        featurizer=featurizer,
    )
    

# **MAIN PART: IterableMolDatapoints**

In [3]:
class IterableMolDatapoints(IterableDataset):
    '''Stream featurized molecules from a DataFrame, one small chunk at a time.

    Only ``size_at_time`` rows are converted into a molecule dataset at once, so
    the featurized form of the whole table never has to be held in memory.
    Iterating yields the items of those per-chunk datasets one by one (described
    by the original notebook as chemprop.data.datasets.Datum objects), ready to
    be batched by a DataLoader.

    When the DataLoader uses several workers, all workers derive the same row
    order for the epoch and each of them processes a disjoint subset of the
    chunks, so every row is produced exactly once per epoch.
    '''

    def __init__(self, df, smiles_column, target_column, scaler = None, size_at_time=100, shuffle=True):
        '''Parameters:
        ----------
        df (pd.DataFrame): A pandas dataframe containing the data.
        smiles_column (str): The column name containing SMILES strings.
        target_column (str): The column name containing the target values.
        scaler (StandardScaler): An already fitted scaler passed to normalize_targets
            of every chunk dataset; None leaves the targets untouched.
        size_at_time (int): The number of rows to transform into a molecule dataset at a time.
        shuffle (boolean): Whether the rows are reshuffled at the start of every iteration.

        Raises:
        ------
        ValueError: If size_at_time is smaller than 1.'''

        super().__init__()
        # A zero step would make range() fail during iteration and a negative one
        # would silently yield nothing, so reject both up front.
        if size_at_time < 1:
            raise ValueError(f'size_at_time must be a positive integer, got {size_at_time}')
        self.df = df
        self.smiles_column = smiles_column
        self.target_column = target_column
        self.size_at_time = size_at_time
        self.shuffle = shuffle
        self.scaler = scaler
        # Number of iterations this copy of the dataset has started inside a
        # DataLoader worker; mixed into the shuffle seed (see __iter__).
        self._epochs_started = 0

    def __len__(self):
        '''Return the number of rows in the underlying dataframe.'''
        return len(self.df)

    def __iter__(self):
        '''Yield the featurized molecules chunk by chunk.

        The dataframe is shuffled (if requested), cut into chunks of
        ``size_at_time`` rows, and each chunk is converted with
        ``dataset_preparator`` just before its items are yielded. Inside a
        DataLoader worker only every ``num_workers``-th chunk is processed.
        '''
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator owns every chunk and draws
            # its permutation from the global NumPy random state.
            num_workers, worker_id, shuffle_seed = 1, 0, None
        else:
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
            # PyTorch seeds each worker with base_seed + worker_id, so removing
            # the id gives a value shared by all workers of this epoch. Without
            # a shared seed every worker would shuffle differently and the
            # sharding below would duplicate some rows and drop others.
            # The local counter keeps the order changing when the same worker
            # process is reused for several epochs.
            shuffle_seed = (worker_info.seed - worker_id + self._epochs_started) % 2**32
            self._epochs_started += 1

        if self.shuffle:
            df_epoch = self.df.sample(frac=1, random_state=shuffle_seed).reset_index(drop=True)
        else:
            df_epoch = self.df.copy()

        # Transform the dataframe chunk by chunk to balance memory against speed.
        # Each worker takes every num_workers-th chunk, so no chunk is featurized
        # by more than one worker.
        chunk_starts = range(0, len(df_epoch), self.size_at_time)
        for start in chunk_starts[worker_id::num_workers]:
            df_at_time = df_epoch.iloc[start:start + self.size_at_time]
            chunk_dataset = dataset_preparator(
                df=df_at_time,
                smiles_column=self.smiles_column,
                target_column=self.target_column,
            )

            if self.scaler is not None:
                chunk_dataset.normalize_targets(self.scaler)

            yield from chunk_dataset

# **Test 1: Memory usage**

In [4]:
# Prepare data
data_path = 'on_the_fly_data.csv'
smiles_column = 'smiles'
target_column = 'docking_score'

# Benchmark on a fixed-size random subset of the CSV file.
# - The size is capped at the number of available rows, because asking for more
#   rows than exist (without replacement) raises a ValueError.
# - The fixed random_state makes "Restart & Run All" select the same rows.
df = pd.read_csv(data_path)
df = df.sample(n=min(100_000, len(df)), random_state=0)

# Fit the scaler on the whole subset so it can later be applied chunk by chunk.
scaler = StandardScaler().fit(df[[target_column]])

# Function to record memory
def memory_record():
    """Return the resident set size (RSS) of the current process in MB.

    The value is read through ``psutil`` for the process id reported by
    ``os.getpid()`` and converted from bytes with 1 MB = 1024 ** 2 bytes.
    """
    bytes_per_mb = 1024 * 1024
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_mb


In [5]:
# Measure the cost of *constructing* the iterable dataset. The constructor only
# stores the dataframe and the settings; nothing is featurized until the dataset
# is iterated, so memory and time are expected to be close to zero here.
gc.collect()
# perf_counter() is a monotonic, high-resolution clock; time.time() can be too
# coarse (and can jump) for an interval this short.
start_time = time.perf_counter()
memory_before = memory_record()
iterable_dataset = IterableMolDatapoints(
    df=df,
    smiles_column=smiles_column,
    target_column=target_column,
    size_at_time=100, scaler=None, shuffle=True
)
memory_after = memory_record()
end_time = time.perf_counter()
gc.collect()

print(f'Memory usage to load iterable dataset: {memory_after-memory_before} MB ')
print(f'Time to load iterable dataset: {end_time-start_time} s ')

Memory usage to load iterable dataset: 0.0 MB 
Time to load iterable dataset: 0.0009050369262695312 s 


In [6]:
# Measure the cost of building the map-style dataset, which creates the
# datapoints for every row of the dataframe up front.
gc.collect()
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed intervals; time.time() follows the wall clock and can jump.
start_time = time.perf_counter()
memory_before = memory_record()
dataset = dataset_preparator(
    df=df,
    smiles_column=smiles_column,
    target_column=target_column
)
memory_after = memory_record()
end_time = time.perf_counter()
gc.collect()

print(f'Memory usage to load map dataset: {memory_after-memory_before} MB ')
print(f'Time to load map dataset: {end_time-start_time} s ')



Memory usage to load map dataset: 793.34375 MB 
Time to load map dataset: 9.01982593536377 s 


# **Test 2: Similarity to chemprop data loader**

In this test, I aim to demonstrate that a data loader built on `IterableMolDatapoints` yields the same batches as a data loader built on a regular (map-style) Chemprop dataset. A scaler can also be applied if necessary; however, it must be fitted on the entire dataset (the pandas DataFrame) beforehand, because the iterable dataset applies it to one small chunk at a time.

In [7]:
# Prepare Data
smiles_column = 'smiles'
target_column = 'docking_score'

# Use a small random subset for the comparison.
# - The size is capped at the number of available rows, because asking for more
#   rows than exist (without replacement) raises a ValueError.
# - The fixed random_state makes "Restart & Run All" select the same rows.
df_train = pd.read_csv('on_the_fly_data.csv')
df_train = df_train.sample(n=min(1000, len(df_train)), random_state=0)

# Fit the scaler on the whole subset, before any chunking takes place.
scaler = StandardScaler().fit(df_train[[target_column]])

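**Note:** Turn off shuffling (`shuffle=False`) in both loaders for this comparison; otherwise the two loaders return the rows in different orders and the batches cannot match.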

In [8]:
# Map-style reference: every row of df_train is converted up front.
map_dataset = dataset_preparator(
    df=df_train,
    smiles_column=smiles_column,
    target_column=target_column,
)
# Optional: uncomment to compare scaled targets (also enable the scaler below).
#map_dataset.normalize_targets(scaler)
map_loader = data.build_dataloader(map_dataset, batch_size=5, shuffle=False)

# Streaming counterpart: the same rows, converted five at a time and left in
# their original order so that both loaders can be compared batch by batch.
iterable_dataset = IterableMolDatapoints(
    df=df_train,
    smiles_column=smiles_column,
    target_column=target_column,
    size_at_time=5,
    shuffle=False,
    #scaler=scaler
)
iterable_loader = data.build_dataloader(iterable_dataset, batch_size=5, shuffle=False)

In [9]:
def compare_loader(loader1, loader2, verbose=True):
    """
    Check if two data loaders produce the same data in the same order.

    Batches are compared pairwise on the node features (``bmg.V``), the edge
    features (``bmg.E``) and the targets (``Y``) using ``np.array_equal``. The
    comparison stops at the first difference.

    Parameters:
    - loader1, loader2: DataLoader instances to compare
    - verbose: If True (default), print a confirmation line for every matching
      batch. Differences are always reported.

    Returns:
    - bool: True if loaders produce identical data
    """

    # Cheap early exit when the reported lengths already disagree.
    if len(loader1) != len(loader2):
        print(f"Loaders have different lengths: {len(loader1)} vs {len(loader2)}")
        return False

    # The reported length need not match the number of batches a loader actually
    # yields, so pad the shorter one with a sentinel instead of letting zip()
    # silently drop the surplus batches of the longer one.
    missing = object()
    paired_batches = zip_longest(loader1, loader2, fillvalue=missing)

    # Compare each batch attribute
    for i, (batch1, batch2) in enumerate(paired_batches):
        if batch1 is missing or batch2 is missing:
            print(f"Loaders yielded a different number of batches: one ran out at batch {i}")
            return False

        # Compare MolGraph objects
        same_nodes = np.array_equal(batch1.bmg.V, batch2.bmg.V)
        same_edges = np.array_equal(batch1.bmg.E, batch2.bmg.E)
        if not (same_nodes and same_edges):
            print(f"MolGraphs are different in batch {i}")
            return False
        if verbose:
            print(f"MolGraphs are identical in batch {i}")

        # Compare targets
        if not np.array_equal(batch1.Y, batch2.Y):
            print(f"Targets are different in batch {i}")
            return False
        if verbose:
            print(f"Targets are identical in batch {i}")

        # Compare more attributes if needed

    return True


# Rebuild both loaders with an identical batch size and no loader-level
# shuffling, then compare them batch by batch.
iterable_loader = data.build_dataloader(
    iterable_dataset,
    batch_size=2,
    shuffle=False,
)
map_loader = data.build_dataloader(
    map_dataset,
    batch_size=2,
    shuffle=False,
)

loaders_match = compare_loader(iterable_loader, map_loader)
if not loaders_match:
    print("The data loaders differ")
else:
    print("The data loaders contain the same data in the same order")

MolGraphs are identical in batch 0
Targets are identical in batch 0
MolGraphs are identical in batch 1
Targets are identical in batch 1
MolGraphs are identical in batch 2
Targets are identical in batch 2
MolGraphs are identical in batch 3
Targets are identical in batch 3
MolGraphs are identical in batch 4
Targets are identical in batch 4
MolGraphs are identical in batch 5
Targets are identical in batch 5
MolGraphs are identical in batch 6
Targets are identical in batch 6
MolGraphs are identical in batch 7
Targets are identical in batch 7
MolGraphs are identical in batch 8
Targets are identical in batch 8
MolGraphs are identical in batch 9
Targets are identical in batch 9
MolGraphs are identical in batch 10
Targets are identical in batch 10
MolGraphs are identical in batch 11
Targets are identical in batch 11
MolGraphs are identical in batch 12
Targets are identical in batch 12
MolGraphs are identical in batch 13
Targets are identical in batch 13
MolGraphs are identical in batch 14
Targ

# **Test 3: Shuffle**

With `shuffle=True`, the rows are reshuffled every time the dataset is iterated, so the content of each batch differs between epochs.

In [10]:
# Prepare Data
smiles_column = 'smiles'
target_column = 'docking_score'

# Pick a handful of rows for the demonstration.
# - The size is capped at the number of available rows, because asking for more
#   rows than exist (without replacement) raises a ValueError.
# - The fixed random_state makes "Restart & Run All" select the same rows; the
#   order in which they are served still changes on every pass, because the
#   dataset reshuffles itself each time it is iterated.
df_train = pd.read_csv('on_the_fly_data.csv')
df_train_10 = df_train.sample(n=min(10, len(df_train)), random_state=0)


# Create iterable data
iterable_dataset = IterableMolDatapoints(
    df=df_train_10,
    smiles_column=smiles_column,
    target_column=target_column,
    size_at_time=5, scaler=None, shuffle=True
)

# Shuffling is done by the dataset itself, so the loader must not shuffle.
iterable_train_loader = data.build_dataloader(
    iterable_dataset,
    batch_size=5, shuffle=False)

# Iterate over the loader twice to show that the batches differ between epochs.
print('Data batches with Unscaled target values:')
for epoch in range(2):
    print(f'Epoch {epoch+1}')
    for i, batch in enumerate(iterable_train_loader):
        print(f'Batch {i+1}')
        print(batch.Y)
    print('-'*40)

Data batches with Unscaled target values:
Epoch 1
Batch 1
tensor([[-4.6800],
        [-6.0033],
        [-6.3420],
        [-5.6532],
        [-5.9586]])
Batch 2
tensor([[-6.5905],
        [-5.7284],
        [-8.3186],
        [-7.4662],
        [-7.6936]])
----------------------------------------
Epoch 2
Batch 1
tensor([[-8.3186],
        [-5.9586],
        [-5.7284],
        [-7.6936],
        [-7.4662]])
Batch 2
tensor([[-4.6800],
        [-6.0033],
        [-6.3420],
        [-6.5905],
        [-5.6532]])
----------------------------------------
