In [1]:
import sys

# Make the TopoBenchmarkX checkout importable from this notebook.
# Set the TOPOBENCHMARKX_ROOT environment variable to point at another
# checkout; the original developer path is kept as the default.
import os

root_path = os.environ.get("TOPOBENCHMARKX_ROOT", "/home/lev/projects/TopoBenchmarkX")
if root_path not in sys.path:
    sys.path.append(root_path)

import os.path as osp
from collections.abc import Callable
from typing import Optional

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs

from topobenchmarkx.io.load.download_utils import download_file_from_drive


class CornelDataset(InMemoryDataset):
    r"""Single-graph dataset fetched from Google Drive and cached on disk.

    Only ``"US-county-demos"`` is currently registered in :attr:`URLS`. Raw
    files live in ``<root>/<name>/raw`` and the processed graph is written to
    ``<root>/<name>/processed/data.pt``.

    Args:
        root (str): Directory under which ``<name>/raw`` and
            ``<name>/processed`` are created.
        name (str): Dataset name. Underscores are replaced by hyphens, so
            ``"US_county_demos"`` and ``"US-county-demos"`` are equivalent.
        parameters (dict, optional): Extra settings (e.g. for data splits).
            Stored on ``self.parameters``; not otherwise used by this class.
        transform (Callable, optional): Forwarded to ``InMemoryDataset``.
        pre_transform (Callable, optional): Forwarded to ``InMemoryDataset``
            and applied to the graph in :meth:`process` before saving.
        pre_filter (Callable, optional): Forwarded to ``InMemoryDataset``.
            NOTE(review): :meth:`process` does not apply it.
        force_reload (bool): Forwarded to ``InMemoryDataset``. The default is
            ``True`` here. NOTE(review): this presumably re-processes the
            dataset on every instantiation -- confirm this is intended.
        use_node_attr (bool): Currently unused.
        use_edge_attr (bool): Currently unused.
    """

    URLS = {
        # 'contact-high-school': 'https://drive.google.com/open?id=1VA2P62awVYgluOIh1W4NZQQgkQCBk-Eu',
        "US-county-demos": "https://drive.google.com/file/d/1FNF_LbByhYNICPNdT6tMaJI9FxuSvvLK/view?usp=sharing",
    }

    FILE_FORMAT = {
        # 'contact-high-school': 'tar.gz',
        "US-county-demos": "zip",
    }

    # Files that must be present in ``raw_dir`` for the download to be skipped.
    # These are the files read by ``load_us_county_demos`` with its default year.
    RAW_FILE_NAMES = {
        "US-county-demos": ["county_graph.csv", "county_stats_2012.csv"],
    }

    def __init__(
        self,
        root: str,
        name: str,
        parameters: Optional[dict] = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = True,
        use_node_attr: bool = False,
        use_edge_attr: bool = False,
    ) -> None:
        # ``name`` must be set before ``super().__init__`` because the
        # directory properties below depend on it.
        self.name = name.replace("_", "-")
        # Keep the user settings around (e.g. for a future splitting step).
        self.parameters = parameters

        # Triggers download() / process() when the raw / processed files are
        # missing (or when ``force_reload`` is set).
        super().__init__(
            root, transform, pre_transform, pre_filter, force_reload=force_reload
        )

        # Load the processed graph written by process(); ``load`` is the
        # counterpart of the ``save`` call used there and sets ``self.data``
        # (and ``self.slices``).
        self.load(self.processed_paths[0])

        # TODO: create the splits and select the desired fold, e.g.
        # split_idx = random_splitting(data.y, parameters=self.parameters)

    @property
    def raw_dir(self) -> str:
        """Directory holding the downloaded raw files: ``<root>/<name>/raw``."""
        return osp.join(self.root, self.name, "raw")

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed file: ``<root>/<name>/processed``."""
        return osp.join(self.root, self.name, "processed")

    @property
    def raw_file_names(self) -> list[str]:
        """Raw files whose presence in ``raw_dir`` makes the download unnecessary.

        Raises:
            KeyError: If ``self.name`` is not registered in ``RAW_FILE_NAMES``.
        """
        return list(self.RAW_FILE_NAMES[self.name])

    @property
    def processed_file_names(self) -> str:
        """Name of the single processed file inside ``processed_dir``."""
        return "data.pt"

    def download(self) -> None:
        """
        Downloads the dataset archive from Google Drive and unpacks it into the raw directory.

        The archive is extracted, the files of the extracted ``<name>`` folder
        are moved directly into ``raw_dir``, and both the now-empty folder and
        the archive are removed.

        Raises:
            KeyError: If ``self.name`` is not registered in ``URLS`` / ``FILE_FORMAT``.
        """
        self.url = self.URLS[self.name]
        self.file_format = self.FILE_FORMAT[self.name]

        download_file_from_drive(
            file_link=self.url,
            path_to_save=self.raw_dir,
            dataset_name=self.name,
            file_format=self.file_format,
        )

        archive_path = osp.join(self.raw_dir, f"{self.name}.{self.file_format}")
        extracted_dir = osp.join(self.raw_dir, self.name)

        # Extract the downloaded archive next to itself.
        fs.cp(archive_path, self.raw_dir, extract=True)

        # Move the extracted files up into <root>/<name>/raw/.
        for filename in fs.ls(extracted_dir):
            fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename)))
        fs.rm(extracted_dir)

        # The archive itself is no longer needed.
        fs.rm(archive_path)

    def process(self) -> None:
        """
        Process the data for the dataset.

        Loads the US county demographics graph from ``raw_dir``, applies
        ``pre_transform`` if one was given, and saves the result to
        ``processed_paths[0]``.

        Returns:
            None
        """
        data = load_us_county_demos(self.raw_dir, self.name)

        if self.pre_transform is not None:
            data = self.pre_transform(data)
        self.save([data], self.processed_paths[0])

    def __repr__(self) -> str:
        return f"{self.name}({len(self)})"

In [2]:
import numpy as np
import pandas as pd
import torch
import torch_geometric


def load_us_county_demos(path, dataset_name, year=2012):
    """Build a graph of US counties from the raw county CSV files.

    Reads ``county_graph.csv`` (columns ``SRC`` / ``DST`` holding county FIPS
    codes) and ``county_stats_<year>.csv`` from ``path``. Counties with missing
    statistics, self-loops and isolated counties are removed, the graph is made
    undirected, and FIPS codes are relabelled to ``0 .. num_nodes - 1`` in the
    row order of the statistics file.

    Args:
        path (str): Directory containing the two CSV files.
        dataset_name (str): Currently unused; kept for call compatibility.
        year (int): Which ``county_stats_<year>.csv`` file to read.

    Returns:
        torch_geometric.data.Data: Graph with
            ``x`` -- NumPy array of the node features, columns in file order:
            MedianIncome, MigraRate, BirthRate, DeathRate, BachelorRate,
            UnemploymentRate;
            ``y`` -- NumPy array with ``Election = (DEM - GOP) / (DEM + GOP)``;
            ``edge_index`` -- tensor with one column per directed edge.
    """
    edges_df = pd.read_csv(f"{path}/county_graph.csv")
    stat = pd.read_csv(f"{path}/county_stats_{year}.csv", encoding="ISO-8859-1")

    keep_cols = [
        "FIPS",
        "DEM",
        "GOP",
        "MedianIncome",
        "MigraRate",
        "BirthRate",
        "DeathRate",
        "BachelorRate",
        "UnemploymentRate",
    ]
    # Drop counties with any missing statistic.
    stat = stat[keep_cols].dropna()

    # Keep only edges whose two endpoints both have statistics.
    known_fips = stat["FIPS"].unique()
    has_stats = edges_df["SRC"].isin(known_fips) & edges_df["DST"].isin(known_fips)
    edges_df = edges_df[has_stats]

    # Keep only counties that occur in at least one remaining edge. A county
    # listed on just one side of the edge list must be kept too, otherwise its
    # edges would map to NaN node ids below.
    in_graph = stat["FIPS"].isin(edges_df["SRC"]) | stat["FIPS"].isin(edges_df["DST"])
    stat = stat[in_graph].reset_index(drop=True)

    # Remove self-loops (SRC == DST).
    edges_df = edges_df[edges_df["SRC"] != edges_df["DST"]]

    # Symmetrise the edge list while it is still expressed in FIPS codes.
    edge_index = torch.tensor(
        np.stack([edges_df["SRC"].to_numpy(), edges_df["DST"].to_numpy()])
    )
    edge_index = torch_geometric.utils.to_undirected(edge_index)
    edges_df = pd.DataFrame(edge_index.numpy().T, columns=["SRC", "DST"])

    # Relabel FIPS codes to the row position of the county in ``stat``.
    fips_map = {fips: i for i, fips in enumerate(stat["FIPS"].unique())}
    edge_index = torch.tensor(
        np.stack(
            [
                edges_df["SRC"].map(fips_map).to_numpy(),
                edges_df["DST"].map(fips_map).to_numpy(),
            ]
        )
    )

    # Drop counties left without neighbours (e.g. those that only had a
    # self-loop). The remaining nodes are relabelled to 0 .. num_nodes - 1,
    # and ``num_nodes`` makes the mask line up with the rows of ``stat``.
    edge_index, _, mask = torch_geometric.utils.remove_isolated_nodes(
        edge_index, num_nodes=len(stat)
    )
    stat = stat.iloc[mask.numpy()].reset_index(drop=True)

    # Target: normalised vote margin between the two parties.
    stat["Election"] = (stat["DEM"] - stat["GOP"]) / (stat["DEM"] + stat["GOP"])

    # Raw vote counts and the identifier are not used as features.
    stat = stat.drop(columns=["DEM", "GOP", "FIPS"])

    # Prediction col
    y_col = "Election"  # TODO: Define through config file
    # Use the DataFrame column order so the feature layout is the same on
    # every run (a set has no stable iteration order for strings).
    x_col = [col for col in stat.columns if col != y_col]

    # MedianIncome uses thousands separators ("51,281"); casting to str first
    # also copes with a file where the column is already numeric.
    stat["MedianIncome"] = (
        stat["MedianIncome"].astype(str).str.replace(",", "", regex=False).astype(float)
    )

    x = stat[x_col].to_numpy()
    y = stat[y_col].to_numpy()

    # NOTE(review): x and y are stored as NumPy arrays, not torch tensors;
    # kept as-is so existing consumers see the same types.
    data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)

    return data

In [3]:
# Download (if needed) and process the dataset under <project root>/datasets/graph.
# `root_path` is the project root defined in the first cell.
a = CornelDataset(
    root=osp.join(root_path, "datasets", "graph"), name="US-county-demos"
)

Download complete.


Processing...
Done!


In [9]:
# Split settings; presumably meant for the dataset's `parameters` argument
# (no visible cell consumes them yet). Paths are derived from `root_path`,
# the project root defined in the first cell.
parameters = {
    "data_seed": 0,
    "data_split_dir": osp.join(root_path, "datasets", "data_splits", "US-county-demos"),
    "train_prop": 0.5,
}

In [5]:
# Shape of the feature matrix of the first (only) graph: (num_nodes, num_features).
a[0].x.shape

(3107, 6)

In [35]:
# Demonstration: dict entries are not reachable via attribute access.
# The error is caught and printed so the notebook still survives "Run All".
# NOTE: this rebinds `a`, which referred to the dataset in earlier cells.
a = {"k": 1}
try:
    a.k
except AttributeError as err:
    print(f"AttributeError: {err}")

AttributeError: 'dict' object has no attribute 'k'

In [13]:
# Peek at the raw 2016 county statistics file (path relative to `root_path`,
# the project root defined in the first cell).
stat = pd.read_csv(
    osp.join(
        root_path,
        "datasets/graph/US-county-demos-2012/raw/US-county-demos/county_stats_2016.csv",
    ),
    encoding="ISO-8859-1",
)

In [14]:
# List the columns available in the raw statistics file loaded above.
stat.columns

Index(['FIPS', 'County', 'DEM', 'GOP', 'MedianIncome', 'MigraRate',
       'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate'],
      dtype='object')

In [15]:
# Scratch cell: the derived "Election" column plus the six county statistics
# columns. The bare tuple is not assigned anywhere; it is only echoed as the
# cell output.
(
    "Election",
    "MedianIncome",
    "MigraRate",
    "BirthRate",
    "DeathRate",
    "BachelorRate",
    "UnemploymentRate",
)

('Election',
 'MedianIncome',
 'MigraRate',
 'BirthRate',
 'DeathRate',
 'BachelorRate',
 'UnemploymentRate')