# Adding the OAS Dataset: Customizing Dataset Object and Dataloader Functions
This tutorial is the third part of a series focused on adding a new dataset to BioNeMo using the [Observed Antibody Space (OAS)](https://opig.stats.ox.ac.uk/webapps/oas/) database. There are three steps to this task:


1. Preprocessing includes downloading the raw data and any additional preparation steps, such as extracting the files. It also includes dividing the data into train, validation, and test splits. The preprocessing step can make use of two BioNeMo base classes, `RemoteResource` and `ResourcePreprocessor`, from `bionemo.utils.remote` and `bionemo.data.preprocess`, respectively. Their use is optional, but they provide some basic functionality that can accelerate development. This step was covered in the first tutorial, Downloading and Preprocessing.

2. Development of the new dataset class. Here, the NeMo dataset class CSVMemMapDataset will be used. This step was covered in the last tutorial, Modifying the Dataset Class.

3. Modification of the dataloader classes. This tutorial will cover customizing DataLoader objects using the newly created OAS datasets. This will include specifics on instantiating actual Dataset classes, customizing the collate function, and instantiating a dataloader. We will also review how these steps are executed within the BioNeMo model classes.


## Setup and Assumptions
This tutorial assumes that a copy of the BioNeMo framework repo exists on a workstation or server and has been mounted inside the container at `/workspace/bionemo`, as described in the Code Development section of the Quickstart Guide. This path is referred to by the variable `BIONEMO_WORKSPACE` throughout the tutorial.

All commands should be executed inside the BioNeMo docker container.

In [1]:
BIONEMO_WORKSPACE = '/workspace/bionemo'

In [2]:
### Utility functions 

from IPython.display import Code
import re
import os
import shutil

BIONEMO_WORKSPACE = '/workspace/bionemo'
def stage_files(tag: str,
                source_directory: str = f'{BIONEMO_WORKSPACE}/examples/oas_dataset'):
    """Stage files for each step of the tutorial"""
    source_path = os.path.join(source_directory, tag)
    
    data_path = os.path.join(BIONEMO_WORKSPACE, 'bionemo/data/preprocess/protein')
    shutil.copyfile(os.path.join(source_path, 'oas_paired_subset_download.sh'), 
                    os.path.join(data_path, 'oas_paired_subset_download.sh'))
    
    preprocess_path = os.path.join(BIONEMO_WORKSPACE, 'bionemo/data/preprocess/protein')
    shutil.copyfile(os.path.join(source_path, 'oas_preprocess.py'), 
                    os.path.join(preprocess_path, 'oas_preprocess.py'))
    
    config_path = os.path.join(BIONEMO_WORKSPACE, 'examples/protein/esm1nv/conf')
    shutil.copyfile(os.path.join(source_path, 'pretrain_oas.yaml'), 
                    os.path.join(config_path, 'pretrain_oas.yaml'))
    
    pretrain_path = os.path.join(BIONEMO_WORKSPACE, 'examples/protein/esm1nv')
    shutil.copyfile(os.path.join(source_path, 'pretrain_oas.py'), 
                    os.path.join(pretrain_path, 'pretrain_oas.py'))

    collate_path = os.path.join(BIONEMO_WORKSPACE, 'bionemo/data/dataloader')
    shutil.copyfile(os.path.join(source_path, 'custom_protein_collate.py'), 
                    os.path.join(collate_path, 'custom_protein_collate.py'))

    model_path = os.path.join(BIONEMO_WORKSPACE, 'bionemo/model/protein/esm1nv')
    shutil.copyfile(os.path.join(source_path, 'custom_esm1nv_model.py'), 
                    os.path.join(model_path, 'custom_esm1nv_model.py'))

    
def show_code(filename: str,
              language: str,
              start_line = None,
              end_line = None,
              end_column = None):
    """Display syntax highlighted section of code"""
    
    with open(filename, 'r') as fh:
        code = fh.readlines()

    if end_line:
        code = code[:end_line]
        code.append('...\n')
    if start_line:
        code = code[start_line:]
        code.insert(0, '...\n')
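    if end_column:
        # Build a new list: rebinding the loop variable would leave `code` unchanged.
        # Only lines longer than end_column are truncated and marked with '...'.
        code = [line[:end_column] + '...\n' if len(line.rstrip('\n')) > end_column else line
                for line in code]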
        
    code = ''.join(code)
    return Code(data=code, language=language)


def filter_log(logfile_list, regex):
    """Return the log lines from the first regex match onward, joined by newlines"""

    reg = re.compile(regex)
    # Index of the first matching line, or None when nothing matches
    first_match = next((i for i, line in enumerate(logfile_list) if reg.search(line)), None)
    if first_match is None:
        return ''
    return '\n'.join(logfile_list[first_match:])

In [3]:
TUTORIAL_FILE_VERSION = 'step_999_final'
stage_files(TUTORIAL_FILE_VERSION, source_directory=f'{BIONEMO_WORKSPACE}/examples/oas_dataset')

## Customizing a collate function

In the last tutorial we saw how to modify the YAML file to use a different set of data with existing tooling. In some cases, this isn't enough. The `collate_fn` parameter of PyTorch DataLoaders is used for last-minute adjustments to batches, including masking, shuffling, batching, padding, and other slight modifications to the input data. In BioNeMo, we build our collate function on top of collators used for language modeling (`bionemo/data/dataloader/collate.py`).

The collate function is ultimately injected into the dataloader upon construction. To customize further, we can simply extend the existing `ProteinCollate` class with our own additional collation, followed by a call to the parent's method.
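The extend-then-delegate pattern can be sketched without any BioNeMo imports. In the minimal example below, `BaseCollate` is a hypothetical stand-in for a parent collator (it does not reproduce the real `ProteinCollate` API), and `UppercaseCollate` is an illustrative subclass that applies its own step before calling the parent's method.

```python
from typing import Dict, List


class BaseCollate:
    """Hypothetical stand-in for a parent collator; not the BioNeMo API."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def collate_fn(self, batch: List[str]) -> Dict[str, List]:
        # Toy "collation": truncate each sequence and record the lengths.
        truncated = [seq[: self.max_length] for seq in batch]
        return {"batch": truncated, "lengths": [len(seq) for seq in truncated]}


class UppercaseCollate(BaseCollate):
    """Adds one extra step, then defers to the parent's collate_fn."""

    def collate_fn(self, batch: List[str]) -> Dict[str, List]:
        # Custom step: normalize every sequence to upper case.
        modified = [seq.upper() for seq in batch]
        # Hand the modified batch to the parent for the remaining work.
        return super().collate_fn(modified)


# The bound method is what would be handed to a dataloader.
sketch_collate_fn = UppercaseCollate(max_length=4).collate_fn
sketch_result = sketch_collate_fn(["acdef", "gh"])
print(sketch_result)
```

Because the subclass only overrides `collate_fn`, everything the parent does after the custom step stays in one place and is inherited unchanged.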


In [4]:
filename = f'{BIONEMO_WORKSPACE}/bionemo/data/dataloader/custom_protein_collate.py'
show_code(filename=filename, language='python')

## Injecting a custom collate object into an existing model

The implemented collate function serves a single purpose: it replaces every character with the character 'A'. This is both easy to implement and simple to check for correctness. The modified batch is then passed to the parent collate function for padding and masking. Next, we will inject this into our esm1nv model so that it is applied to the dataset. You can see below that this occurs in the `build_pretraining_data_loader` method, which primarily operates on an already existing Dataset object.

In [5]:
filename = f'{BIONEMO_WORKSPACE}/bionemo/model/protein/esm1nv/custom_esm1nv_model.py'
show_code(filename=filename, language='python')

In [6]:
std_out = ! cd {BIONEMO_WORKSPACE}/examples/protein/esm1nv && python pretrain_oas.py ++trainer.max_steps=101
print('\n'.join(std_out))

[NeMo W 2023-08-25 18:46:43 experimental:27] Module <class 'nemo.collections.nlp.models.text_normalization_as_tagging.thutmose_tagger.ThutmoseTaggerModel'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-08-25 18:46:44 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.
    
    See https://hydra.cc/docs/next/upgrades/1.1_to_1.2/changes_to_job_working_dir/ for more information.
      ret = run_job(
    
[NeMo I 2023-08-25 18:46:44 pretrain_oas:12] 
    
    ************** Experiment configuration ***********
[NeMo I 2023-08-25 18:46:44 pretrain_oas:13] 
    name: esm1nv-oas
    do_training: true
    do_testing: false
    restore_from_path: null
    trainer:
      devices: 1
      num_nodes: 1
      accelerator: gpu
      precision: 16
      logger: false
      enable_checkpointing:

## Creating the Dataset object

Underneath the abstractions we provide, the dataset is ultimately constructed by invoking the relevant NeMo object, specified with `model.data.data_impl` in the config file. Additionally, we provide the requisite keyword arguments, specified in the `model.data.data_impl_kwargs` field. Look around in NeMo for additional dataset types, or implement your own!
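As a rough illustration of config-driven construction, the sketch below looks up a dataset class by name and forwards keyword arguments to it. `ToyCSVDataset`, `DATASET_REGISTRY`, the `"toy_csv"` key, and the dict-shaped config are all hypothetical stand-ins chosen for this example; they are not how NeMo or BioNeMo name or implement these pieces.

```python
import csv
import io
from typing import Dict, List, Type


class ToyCSVDataset:
    """Hypothetical in-memory dataset that keeps one column of CSV text."""

    def __init__(self, text: str, header_lines: int = 1,
                 data_sep: str = ",", data_col: int = 1):
        rows = list(csv.reader(io.StringIO(text), delimiter=data_sep))
        # Skip the header rows and keep only the requested column.
        self.items: List[str] = [row[data_col] for row in rows[header_lines:]]

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> str:
        return self.items[idx]


# Hypothetical name-to-class lookup, standing in for the framework's dispatch.
DATASET_REGISTRY: Dict[str, Type] = {"toy_csv": ToyCSVDataset}

# Dict standing in for the data_impl and data_impl_kwargs config fields.
toy_config = {
    "data_impl": "toy_csv",
    "data_impl_kwargs": {"header_lines": 1, "data_sep": ",", "data_col": 1},
}

raw_text = "id,sequence\n0,ACDE\n1,FGHI\n"
dataset_cls = DATASET_REGISTRY[toy_config["data_impl"]]
toy_dataset = dataset_cls(raw_text, **toy_config["data_impl_kwargs"])
print(len(toy_dataset), toy_dataset[0])
```

Keeping the class name and its keyword arguments in the config means a different dataset type can be swapped in without touching the construction code.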

We can do this manually as well!

In [7]:
dataset_paths = [ 
    '/data/OASpaired/processed/heavy/train/x000.csv' ,
    '/data/OASpaired/processed/heavy/train/x001.csv' ,
    '/data/OASpaired/processed/heavy/train/x002.csv' ,
]
# Check out NeMo for examples of other dataset types, or add your own!
from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import CSVMemMapDataset
# The kwargs here are taken from our yaml file.
dataset = CSVMemMapDataset(dataset_paths=dataset_paths, header_lines=1, newline_int=10, workers=1, sort_dataset_paths=True, data_sep=',', data_col=1)

for i, item in enumerate(iter(dataset)):
    if i > 10: break
    print(item)

[NeMo W 2023-08-25 18:47:09 experimental:27] Module <class 'nemo.collections.nlp.models.text_normalization_as_tagging.thutmose_tagger.ThutmoseTaggerModel'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-08-25 18:47:10 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.


[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:104] Building data files
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:343] Processing 3 data files using 1 workers
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:349] Time building 0 / 3 mem-mapped files: 0:00:00.051294
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:114] Loading data files
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x000.csv
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x001.csv
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x002.csv
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:117] Time loading 3 mem-mapped files: 0:00:00.005260
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:121] Computing global indices
GGGAGAGGAGGCCTGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTGTTTTAAAAGGTGTCCAGTGTGAAGTGATGCTGGTGGAGTCTG

## Testing our new collate function

Before we inject our collate function into a dataloader, let's first take a look at what it actually does. As we saw previously, it simply replaces every character with 'A', so the effect should be visually obvious. To do this, we must also include a tokenizer. This is required by the default language modeling collate function, which tokenizes the input, applies padding, and aligns it for distributed training. We call `collate_fn` with some dummy data and then watch the output get transformed to `AAAA..`
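To make the tokenize-and-pad step more concrete, here is a small pure-Python sketch of padding a batch so its length is divisible by 8, which is what the `pad_size_divisible_by_8` option suggests. The token ids (`BOS_ID`, `EOS_ID`, `PAD_ID`, and the id for 'A') and the helper names are assumptions made for this illustration; they are not read from the real tokenizer, and this is not the BioNeMo implementation.

```python
from typing import Dict, List

# Assumed special-token ids, for this sketch only.
BOS_ID, EOS_ID, PAD_ID = 1, 2, 3
# Assumed toy vocabulary containing a single residue.
TOY_VOCAB: Dict[str, int] = {"A": 6}


def toy_tokenize(sequence: str) -> List[int]:
    """Map characters to ids and wrap them in begin/end markers."""
    return [BOS_ID] + [TOY_VOCAB[ch] for ch in sequence] + [EOS_ID]


def pad_to_multiple(batch_ids: List[List[int]], multiple: int = 8) -> List[List[int]]:
    """Right-pad every row to the longest length, rounded up to `multiple`."""
    longest = max(len(ids) for ids in batch_ids)
    # Ceiling division, then scale back up to a multiple of `multiple`.
    target = -(-longest // multiple) * multiple
    return [ids + [PAD_ID] * (target - len(ids)) for ids in batch_ids]


padded = pad_to_multiple([toy_tokenize("AAAAA"), toy_tokenize("AAAAAAAA")])
# 1 marks a real token, 0 marks padding.
toy_padding_mask = [[int(tok != PAD_ID) for tok in row] for row in padded]
print(len(padded[0]), len(padded[1]))
```

In this sketch the longest row has 10 tokens, so both rows are padded to 16, the next multiple of 8.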


In [8]:
from bionemo.data.dataloader.custom_protein_collate import CustomProteinBertCollate

# Some magic to get our NeMo tokenizer, filled with arguments from our config file.
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
tokenizer = get_nmt_tokenizer(
            library='sentencepiece',
            tokenizer_model= '/tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model',
            vocab_file='/tokenizers/vocab/protein_sequence_sentencepiece.vocab',
            legacy=False,
)

# Extra kwargs are again taken from our config file.
collate_fn = CustomProteinBertCollate(
    tokenizer=tokenizer,
    seq_length=512,
    pad_size_divisible_by_8=True,
    modify_percent=.1,  # Fraction of tokens to mask or perturb
    perturb_percent=.5,  # Fraction of modified tokens to perturb, 1-perturb_percent is masking probability
).collate_fn
collate_fn(['ACTGT', 'ADFASDFA'])

[NeMo I 2023-08-25 18:47:10 tokenizer_utils:191] Getting SentencePiece with model: /tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model


{'text': tensor([[1, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3],
         [1, 6, 6, 4, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3]]),
 'types': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'is_random': tensor([0, 1]),
 'loss_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'labels': tensor([[1, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3],
         [1, 6, 6, 6, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3]]),
 'padding_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]),
 'batch': ['AAAAA', 'AAAAAAAA']}

## DataLoader
Lastly, we must construct a dataloader composed of our collate function and our dataset object. From here, we can iterate over the result and ensure the data was changed in the same way as when we called the collate function manually.
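Conceptually, an unshuffled dataloader over an indexable dataset groups consecutive samples and passes each group to the collate function. The pure-Python sketch below illustrates that flow only; `iterate_batches` and `all_a_collate` are hypothetical helpers written for this example, not PyTorch or BioNeMo functions, and the real `DataLoader` does considerably more (workers, samplers, pinning memory).

```python
from typing import Any, Callable, Iterator, List, Sequence


def iterate_batches(dataset: Sequence[Any],
                    batch_size: int,
                    collate_fn: Callable[[List[Any]], Any]) -> Iterator[Any]:
    """Yield collated batches in order, keeping a smaller final batch."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        samples = [dataset[i] for i in range(start, stop)]
        # The collate function sees the raw list of samples for this batch.
        yield collate_fn(samples)


def all_a_collate(samples: List[str]) -> List[str]:
    """Toy collate step: replace every character with 'A'."""
    return ["A" * len(sample) for sample in samples]


toy_sequences = ["ACTGT", "ADFASDFA", "GG"]
toy_batches = list(iterate_batches(toy_sequences, batch_size=2, collate_fn=all_a_collate))
print(toy_batches)
```

Since the collate function is the only place where samples are combined, swapping it changes what every batch looks like without touching the dataset.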



In [9]:
from torch.utils.data import DataLoader
print("Before:")
dl = DataLoader(dataset, batch_size=2, shuffle=False)
for i, item in enumerate(dl):
    if i > 10:
        break
    print(item)
print("\n\n\nAfter:")
dl = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)
for i, item in enumerate(dl):
    if i > 10:
        break
    print(item)


Before:
['GGGAGAGGAGGCCTGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTGTTTTAAAAGGTGTCCAGTGTGAAGTGATGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGGCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGACGGGGGAATGATGGTTACTACGAAGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'GAGCTCTGACAGAGGAGGCCAGTCCTGGAATTGATTCCCAGTTCCTCACGTTCAGTGATGAGCACTGAACACAGACACCTCACCATGAACTTTGGGCTCAGATTGATTTTCCTTGTCCTTACTTTAAAAGGTGTGAAGTGTGAAGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCGCTTTCAGTAGCTATGACATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCATACATTAGTAGTGGTGGTGGTATCACCTACTATCCAGACACTGTGAAGGGCCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAAGTCTGAGGA

## Conclusion and further reading

This concludes our tutorial on including custom data in the BioNeMo framework. Throughout these tutorials we described how to manually update a model with a new dataset, and how those changes propagate throughout the framework. Check out other Dataset classes and tokenizers in NeMo to learn about further customization.