In [3]:
!python.exe -m pip install --upgrade pip

Collecting pip
  Downloading pip-25.1.1-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-25.1.1-py3-none-any.whl (1.8 MB)
   ---------------------------------------- 0.0/1.8 MB ? eta -:--:--
   ---------------------------------------- 1.8/1.8 MB 12.5 MB/s eta 0:00:00
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.3.1
    Uninstalling pip-24.3.1:
      Successfully uninstalled pip-24.3.1
Successfully installed pip-25.1.1


In [4]:
!pip install torch-geometric torch torchvision pytorch-lightning requests




In [7]:
!pip install pytorch-lightning==2.5.1



In [28]:
import requests
data_url = "https://raw.githubusercontent.com/RyanWangZf/BioBridge/main/data/BindData/data_config.json"
response = requests.get(data_url)
if response.status_code == 200:
    print("Data from the URL:")
    print(response.text)
else:
    print(f"Failed to fetch data. HTTP Status Code: {response.status_code}")


Data from the URL:
{
    "node_type": {
        "biological_process": 0,
        "gene/protein": 1,
        "disease": 2,
        "effect/phenotype": 3,
        "anatomy": 4,
        "molecular_function": 5,
        "drug": 6,
        "cellular_component": 7,
        "pathway": 8,
        "exposure": 9
    },
    "relation_type": {
        "expression present": 0,
        "synergistic interaction": 1,
        "interacts with": 2,
        "ppi": 3,
        "phenotype present": 4,
        "parent-child": 5,
        "associated with": 6,
        "side effect": 7,
        "contraindication": 8,
        "expression absent": 9,
        "target": 10,
        "indication": 11,
        "enzyme": 12,
        "transporter": 13,
        "off-label use": 14,
        "linked to": 15,
        "phenotype absent": 16,
        "carrier": 17
    },
    "emb_dim": {
        "biological_process": 768,
        "cellular_component": 768,
        "disease": 768,
        "drug": 512,
        "molecular_functio

In [24]:
#!/usr/bin/env python3

'''
This module contains the BioBridgeDataModule class, which is a subclass of LightningDataModule.
It loads the BioBridge dataset (a subset of PrimeKG) and stores the data in a format that can 
be accessed outside of the Jupyter notebook.
'''

from typing import Any, Dict, Optional
import torch
import torch_geometric.data as geom_data
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
import os
import requests
import numpy as np
import zipfile
import shutil
import json
import gzip

class BioBridgeDataModule(LightningDataModule):
    """`LightningDataModule` for the BioBridge dataset.
    
    The BioBridge dataset is a subset of PrimeKG, enriched with multi-modal features
    and node embeddings, used for tasks such as biomedical entity prediction, cross-modal retrieval, 
    and multimodal question answering.
    """

    def __init__(
        self,
        data_dir: str = "/tmp/biobridge",  # Path to save the data
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ) -> None:
        """Initialize a `BioBridgeDataModule`.

        Args:
            data_dir: The data directory to download and store the dataset.
            batch_size: The batch size for data loading.
            num_workers: Number of workers for data loading.
            pin_memory: Whether to pin memory during data loading.
        """
        super().__init__()

        self.save_hyperparameters(logger=False)

        self.data = None
        self.batch_size_per_device = batch_size
        self.data_dir = data_dir
    # def prepare_data(self) -> None:
    #     """
    #     Download and prepare the BioBridge data.
    #     This will be called on a single process to download and preprocess the data.
    #     """
    #     if not os.path.exists(self.data_dir):
    #             os.makedirs(self.data_dir)

    #         # Download the JSON file
    #     data_url = "https://raw.githubusercontent.com/RyanWangZf/BioBridge/main/data/BindData/data_config.json"
    #     json_path = os.path.join(self.data_dir, "data_config.json")
    #     if not os.path.exists(json_path):
    #         print("Downloading BioBridge dataset...")
    #         with requests.get(data_url, stream=True) as r:
    #             with open(json_path, "wb") as f:
    #                 shutil.copyfileobj(r.raw, f)
    #         print(f"Dataset downloaded to {json_path}.")

    #     # Generate or download the .npy files
    #     node_features_path = os.path.join(self.data_dir, "node_features.npy")
    #     edge_index_path = os.path.join(self.data_dir, "edges.npy")
    #     if not os.path.exists(node_features_path) or not os.path.exists(edge_index_path):
    #         print("Generating or downloading .npy files...")
    #     # Add logic to generate or download the .npy files
    #     # Example: Parse the JSON and create .npy files
    #     raise NotImplementedError("Logic to generate .npy files is not implemented.")


    def prepare_data(self) -> None:
        """
        Download and prepare the BioBridge data.
        This will be called on a single process to download and preprocess the data.
        """
        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)

        # Download the JSON file
        data_url = "https://raw.githubusercontent.com/RyanWangZf/BioBridge/main/data/BindData/data_config.json"
        json_path = os.path.join(self.data_dir, "data_config.json")
        if not os.path.exists(json_path):
            print("Downloading BioBridge dataset...")
            with requests.get(data_url, stream=True) as r:
                with open(json_path, "wb") as f:
                    shutil.copyfileobj(r.raw, f)
            print(f"Dataset downloaded to {json_path}.")

        # Check if the file is compressed
        if json_path.endswith(".gz"):
            print("Decompressing the JSON file...")
            decompressed_path = json_path[:-3]  # Remove the .gz extension
            with gzip.open(json_path, "rb") as f_in:
                with open(decompressed_path, "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
            json_path = decompressed_path

        # Generate the .npy files
        node_features_path = os.path.join(self.data_dir, "node_features.npy")
        edge_index_path = os.path.join(self.data_dir, "edges.npy")
        if not os.path.exists(node_features_path) or not os.path.exists(edge_index_path):
            print("Generating .npy files from JSON...")
            with open(json_path, "r", encoding="utf-8") as f:
                data_config = json.load(f)

            # Example: Extract node features and edges from the JSON
            node_features = data_config["node_features"]  # Replace with the actual key
            edge_index = data_config["edges"]  # Replace with the actual key

            # Save as .npy files
            np.save(node_features_path, np.array(node_features))
            np.save(edge_index_path, np.array(edge_index))
            print(f"Generated .npy files: {node_features_path}, {edge_index_path}")
    def load_biobridge_data(self, data_path: str):
        """
        Load and process the BioBridge data (subset of PrimeKG) into PyTorch Geometric format.
        
        Args:
            data_path: Path where BioBridge data is stored.
        
        Returns:
            PyTorch Geometric data object.
        """
        # Example loading node features and edge list from .npy files
        node_features = np.load(os.path.join(data_path, "node_features.npy"))  # Node features
        edge_index = np.load(os.path.join(data_path, "edges.npy"))  # Edge list (edge_index)

        # Convert data to PyTorch Geometric Data format
        data = geom_data.Data(x=torch.tensor(node_features, dtype=torch.float),
                               edge_index=torch.tensor(edge_index, dtype=torch.long))

        return data

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Load data and set up training, validation, and test sets.
        This is called on every process in DDP (Distributed Data Parallel) training.
        """
        pass  # This function can be extended if needed for train/test splits.

    def train_dataloader(self) -> DataLoader[Any]:
        """
        Create and return the train dataloader.

        Args:
            None

        Returns:
            DataLoader: The train dataloader.
        """
        return geom_data.DataLoader(self.data,
                                    batch_size=self.batch_size_per_device,
                                    num_workers=self.hparams.num_workers,
                                    shuffle=True)

    def val_dataloader(self) -> DataLoader[Any]:
        """
        Create and return the validation dataloader.

        Args:
            None

        Returns:
            DataLoader: The validation dataloader.
        """
        return geom_data.DataLoader(self.data,
                                    batch_size=self.batch_size_per_device,
                                    num_workers=self.hparams.num_workers,
                                    shuffle=False)

    def test_dataloader(self) -> DataLoader[Any]:
        """
        Create and return the test dataloader.

        Args:
            None

        Returns:
            DataLoader: The test dataloader.
        """
        return geom_data.DataLoader(self.data,
                                    batch_size=self.batch_size_per_device,
                                    num_workers=self.hparams.num_workers,
                                    shuffle=False)

    def teardown(self, stage: Optional[str] = None) -> None:
        """
        Cleanup after training, validation, or testing.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """
        Called when saving a checkpoint. Implement to generate and save the datamodule state.

        Args:
            None

        Returns:
            Dict: A dictionary containing the datamodule state.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Called when loading a checkpoint. Implement to reload datamodule state.

        Args:
            state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass


In [25]:
# Load the BioBridge data
data_module = BioBridgeDataModule(data_dir="/tmp/biobridge")
data_module.prepare_data()  # Ensure the data is downloaded
data = data_module.load_biobridge_data(data_module.data_dir)

# Display the number of nodes and edges
print(f"Number of nodes: {data.x.size(0)}")
print(f"Number of edges: {data.edge_index.size(1)}")

Generating .npy files from JSON...


UnicodeDecodeError: 'utf-8' codec can't decode byte 0x8b in position 1: invalid start byte