# `Protein Workshop` Tutorial, Part 3 - Adding a New Dataset
![Datasets](../docs/source/_static/box_datasets.png)

## Add a custom dataset to the `Protein Workshop`

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to project sources are picked up without
# restarting the kernel.
# (Comments are kept on their own lines: magics treat the rest of the line as
# arguments.)
%load_ext autoreload
%autoreload 2
# Optional, disabled by default: load the `blackcellmagic` extension for
# in-notebook code formatting (requires the extension to be installed).
#%load_ext blackcellmagic

### Create a new subclass of the `ProteinDataModule` class

#### Reference the `CATHDataModule` below (i.e., `src/datasets/cath.py`) to fill out a custom `src/datasets/my_new_dataset.py`

In [None]:
"""
class CATHDataModule(ProteinDataModule):
    def __init__(
        self,
        path: str,
        batch_size: int,
        format: str = "mmtf",
        pdb_dir: Optional[str] = None,
        pin_memory: bool = True,
        in_memory: bool = False,
        num_workers: int = 16,
        dataset_fraction: float = 1.0,
        transforms: Optional[Iterable[Callable]] = None,
        ) -> None:
        super().__init__()

        self.data_dir = Path(path)
        self.raw_dir = self.data_dir / "raw"
        self.processed_dir = self.data_dir / "processed"
        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)

        if transforms is not None:
            self.transform = self.compose_transforms(
                omegaconf.OmegaConf.to_container(
                    transforms,
                    resolve=True
                    )
                )
        else:
            self.transform = None

        self.in_memory = in_memory

        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.num_workers = num_workers
        self.format = format
        self.pdb_dir = pdb_dir

        self.dataset_fraction = dataset_fraction
        self.excluded_chains: List[str] = self.exclude_pdbs()

    def download(self):
        self.download_chain_list()

    def parse_labels(self):
        pass

    def exclude_pdbs(self):
        return []

    def download_chain_list(self):  # sourcery skip: move-assign
        URL = "http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set_splits.json"
        if not os.path.exists(self.data_dir / "chain_set_splits.json"):
            logger.info("Downloading dataset index file...")
            wget.download(URL, str(self.data_dir / "chain_set_splits.json"))
        else:
            logger.info("Found existing dataset index")

    @functools.lru_cache
    def parse_dataset(self) -> Dict[str, List[str]]:
        fpath = self.data_dir / "chain_set_splits.json"

        with open(fpath, "r") as file:
            data = json.load(file)

        self.train_pdbs = data["train"]
        logger.info(f"Found {len(self.train_pdbs)} chains in training set")
        logger.info("Removing obsolete PDBs from training set")
        self.train_pdbs = [pdb for pdb in self.train_pdbs if pdb[:4] not in self.obsolete_pdbs.keys()]
        logger.info(f"{len(self.train_pdbs)} remaining training chains")

        logger.info(f"Sampling fraction {self.dataset_fraction} of training set")
        fraction = int(self.dataset_fraction * len(self.train_pdbs))
        self.train_pdbs = random.sample(self.train_pdbs, fraction)

        self.val_pdbs = data["validation"]
        logger.info(f"Found {len(self.val_pdbs)} chains in validation set")
        logger.info("Removing obsolete PDBs from validation set")
        self.val_pdbs = [pdb for pdb in self.val_pdbs if pdb[:4] not in self.obsolete_pdbs.keys()]
        logger.info(f"{len(self.val_pdbs)} remaining validation chains")

        self.test_pdbs = data["test"]
        logger.info(f"Found {len(self.test_pdbs)} chains in test set")
        logger.info("Removing obsolete PDBs from test set")
        self.test_pdbs = [pdb for pdb in self.test_pdbs if pdb[:4] not in self.obsolete_pdbs.keys()]
        logger.info(f"{len(self.test_pdbs)} remaining test chains")
        return data

    def train_dataset(self):
        if not hasattr(self, "train_pdbs"):
            self.parse_dataset()
        pdb_codes = [pdb.split(".")[0] for pdb in self.train_pdbs]
        chains = [pdb.split(".")[1] for pdb in self.train_pdbs]

        return ProteinDataset(
            root=str(self.data_dir),
            pdb_dir=self.pdb_dir,
            pdb_codes=pdb_codes,
            chains=chains,
            transform=self.transform,
            format=self.format,
            in_memory=self.in_memory
        )

    def val_dataset(self) -> ProteinDataset:
        if not hasattr(self, "val_pdbs"):
            self.parse_dataset()

        pdb_codes = [pdb.split(".")[0] for pdb in self.val_pdbs]
        chains = [pdb.split(".")[1] for pdb in self.val_pdbs]

        return ProteinDataset(
            root=str(self.data_dir),
            pdb_dir=self.pdb_dir,
            pdb_codes=pdb_codes,
            chains=chains,
            transform=self.transform,
            format=self.format,
            in_memory=self.in_memory
        )

    def test_dataset(self) -> ProteinDataset:
        if not hasattr(self, "test_pdbs"):
            self.parse_dataset()
        pdb_codes = [pdb.split(".")[0] for pdb in self.test_pdbs]
        chains = [pdb.split(".")[1] for pdb in self.test_pdbs]

        return ProteinDataset(
            root=str(self.data_dir),
            pdb_dir=self.pdb_dir,
            pdb_codes=pdb_codes,
            chains=chains,
            transform=self.transform,
            format=self.format,
            in_memory=self.in_memory
        )

    def train_dataloader(self) -> ProteinDataLoader:
        if not hasattr(self, "train_ds"):
            self.train_ds = self.train_dataset()
        return ProteinDataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> ProteinDataLoader:
        if not hasattr(self, "val_ds"):
            self.val_ds = self.val_dataset()
        return ProteinDataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def test_dataloader(self) -> ProteinDataLoader:
        if not hasattr(self, "test_ds"):
            self.test_ds = self.test_dataset()
        return ProteinDataLoader(
            self.test_ds,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
"""

### Create a new data config file to accompany the custom `MyNewDataModule`

#### Reference the `CATH` config below (i.e., `configs/dataset/cath.yaml`) to fill out a custom `configs/dataset/my_new_dataset.yaml`

In [None]:
"""
datamodule:
  _target_: "src.datasets.cath.CATHDataModule"
  path: ${env.paths.data}/cath/ # Directory where the dataset is stored
  pdb_dir: ${env.paths.data}/pdb/ # Directory where raw PDB/mmtf files are stored
  format: "mmtf" # Format of the raw PDB/MMTF files
  num_workers: 4 # Number of workers for dataloader
  pin_memory: True # Pin memory for dataloader
  batch_size: 32 # Batch size for dataloader
  dataset_fraction: 1.0 # Fraction of the dataset to use
  transforms: ${transforms} # Transforms to apply to dataset examples
num_classes: 23 # Number of classes
"""

### Use the new dataset as either a pre-training or fine-tuning corpus, with or without full-atom context

In [None]:
# Misc. tools
import os

# Hydra tools
import hydra

# `GlobalHydra` is defined in `hydra.core.global_hydra`; import it from there
# instead of relying on `hydra.compose` happening to import it internally.
from hydra.core.global_hydra import GlobalHydra
from hydra.core.hydra_config import HydraConfig

from src.constants import HYDRA_CONFIG_PATH
from src.utils.notebook import init_hydra_singleton

version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

# `hydra.initialize` takes a relative config path, so express the project's
# config directory relative to the current working directory.
path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")

# Clear any global Hydra state (e.g. from re-running this cell) before
# initializing again; initializing twice without clearing raises an error.
GlobalHydra.instance().clear()
hydra.initialize(rel_path, version_base=version_base)

# Compose the training config around the custom dataset.
# `return_hydra_config=True` keeps the `hydra` node in the result so that it
# can be edited and registered with `HydraConfig` below.
cfg = hydra.compose(
    config_name="train",
    overrides=[
        "encoder=schnet",
        "task=inverse_folding",
        "dataset=my_new_dataset",
        "features=ca_angles",
        "+aux_task=none",
    ],
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
cfg.hydra.job.num = 0
cfg.hydra.job.id = 0
cfg.hydra.hydra_help.hydra_help = False
cfg.hydra.runtime.output_dir = "outputs"

# Register the composed config as the active Hydra config for this session.
HydraConfig.instance().set_config(cfg)

### Load the custom dataset using the designed config

In [None]:
from src.configs import config

# Validate the composed config, then build only the datamodule it describes.
cfg = config.validate_config(cfg)
datamodule_cfg = cfg.dataset.datamodule
datamodule = hydra.utils.instantiate(datamodule_cfg)

# Prepare the datamodule and fetch its training dataloader.
datamodule.setup("train")
dl = datamodule.train_dataloader()

# Smoke test: print the first training batch, then stop iterating.
for batch in dl:
    print(batch)
    break

### Either pre-train or fine-tune a model using the custom dataset

In [None]:
# Entry points for running an experiment with the composed `cfg`.
from src.finetune import finetune
from src.train import train_model

# Both calls are left commented out so that running this cell does not start
# a job; uncomment the one you want to launch.
# train_model(cfg)  # Pre-train a model using the selected data
# finetune(cfg)  # Fine-tune a model using the selected data

### Reconfigure the custom dataset to use side-chain atom context

In [None]:
# Re-compose the training config, this time selecting the `ca_sc` feature set.
# Relies on the imports made in the earlier Hydra cell.
version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

# Config directory expressed relative to the current working directory.
path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")

# Reset global Hydra state before initializing again.
GlobalHydra.instance().clear()
hydra.initialize(config_path=rel_path, version_base=version_base)

overrides = [
    "encoder=schnet",
    "task=inverse_folding",
    "dataset=my_new_dataset",
    "features=ca_sc",
    "+aux_task=none",
]
cfg = hydra.compose(
    config_name="train",
    overrides=overrides,
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
hydra_cfg = cfg.hydra  # same node as `cfg.hydra`; edits apply to `cfg`
hydra_cfg.job.num = 0
hydra_cfg.job.id = 0
hydra_cfg.hydra_help.hydra_help = False
hydra_cfg.runtime.output_dir = "outputs"

# Register the composed config as the active Hydra config.
HydraConfig.instance().set_config(cfg)

### Verify that side-chain torsions are now available as feature inputs

In [None]:
from src.configs import config

cfg = config.validate_config(cfg)

# Instantiate only the datamodule node of the config. The datamodule spec
# (with its `_target_`) lives under `cfg.dataset.datamodule`; passing the
# whole composed `cfg` to `instantiate` does not produce a datamodule, so the
# `setup` / `train_dataloader` calls below would fail.
datamodule = hydra.utils.instantiate(cfg.dataset.datamodule)
datamodule.setup("train")
dl = datamodule.train_dataloader()

# Print the first training batch so its feature fields can be inspected, then
# stop iterating.
for batch in dl:
    print(batch)
    break