<a href="https://colab.research.google.com/github/andrewbowen19/mastersThesisData698/blob/main/Feature_Engineering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [40]:
pip install --q torch_geometric

In [41]:
import torch_geometric
import pandas as pd
from pandas import DataFrame
import numpy as np
import networkx as nx
import torch
from torch_geometric.data import HeteroData


import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

## Data Cleaning


In [42]:
class DataCleaner(object):
  def __init__(self):
    pass

  @staticmethod
  def snake_case_columns(df: DataFrame) -> DataFrame:
    """Convert all columns in a pandas DataFrame to `snake_case`"""
    new_cols = [c.lower().replace(" ", "_").replace("-", "_") for c in df.columns]

    df.columns = new_cols
    return df

  @staticmethod
  def trim_whitespace(df: DataFrame) -> DataFrame:
    """Trim all leading and trailing whitespace"""
    for c in df.columns:
      if isinstance(df.dtypes[c], np.dtypes.ObjectDType):
        df[c] = df[c].str.strip()
    return df

  @staticmethod
  def insert_century(date_str):
    """insert "19" into the date string; as all data is 20th century"""
    # Split the date string into components
    month, day, year = date_str.split('/')
    # Insert "19" at the beginning of the year component
    year = '19' + year
    # Join the components back into a date string
    return '/'.join([month, day, year])

  @staticmethod
  def date_parser(df: DataFrame, column_name: str, format: str = "%m/%d/%y") -> DataFrame:
    """Parse date-like columns in a dataframe"""
    df[column_name] = pd.to_datetime(df[column_name], format=format, infer_datetime_format=True)#, century_first=50)
    return df


  @staticmethod
  def replace_value(df: DataFrame, column_name: str, old:str, new:str) -> DataFrame:
    """Replace a value in a column. Will be used to map SRP -> SRS"""

    df[column_name] = df[column_name].str.replace(old, new)
    return df

  @staticmethod
  def assign_id_column(df: DataFrame, data_type: str) -> DataFrame:
    """Use row index (attributable to seq_no in CEDR datasets)"""

    df[f'{data_type}_id'] = df.reset_index().index.values
    return df


cleaner = DataCleaner()


In [43]:
site_mapping = {
    "Hanford Site": {"code": "HANF",
                     "filename_prefix": "Hanford",
                     "site_code": 91},
    "Los Alamos National Laboratory": {"code": "LANL",
                     "filename_prefix": "LANL", "site_code": 52},
    "Savannah River Site": {"code": "SRS",
                     "filename_prefix": "SRS", "site_code": 22},
    "Oak Ridge National Laboratory": {"code": "ORNL",
                     "filename_prefix": "ORNL", "site_code": 7},
}

data_types = ["Buildings", "ChemicalAgents", "IH"]


def construct_df(data_type: str = "Buildings"):
  dfs = []
  for site, info in site_mapping.items():
    prefix = info.get("filename_prefix")
    filename = f"/content/{prefix}-{data_type}.csv"

    dat = pd.read_csv(filename, na_values = "")
    dfs.append(dat)


  df = pd.concat(dfs)

  return df

# Read in datasets and combine across sites
buildings = construct_df("Buildings")
agents = construct_df("ChemicalAgents")
ih_data = construct_df("IH")

In [44]:
# Do some basic data cleaning/preprocessing

# snake_case columns
buildings = cleaner.snake_case_columns(buildings)
agents = cleaner.snake_case_columns(agents)
ih_data = cleaner.snake_case_columns(ih_data)


# Trim whitespace
buildings = cleaner.trim_whitespace(buildings)
agents = cleaner.trim_whitespace(agents)
ih_data = cleaner.trim_whitespace(ih_data)

# Cast dates: need to coerce dates to 1900s
ih_data['date'] = ih_data['date'].apply(cleaner.insert_century)
# ih_data = cleaner.date_parser(ih_data, "date")
ih_data.head()

Unnamed: 0,agent,date,dep_grp,descript,facility,jobtitle,location,quantity,referenc,room,sampleid,site,type,uom,seq_no
0,mercury,12/22/1981,100N,"06686:Mercury survey lift station; n=1, airbor...",HANF,,,0.01,HEX78_83,,122281_001,,BZ,mg/m3,1
1,benzene,4/24/1981,1100,06200:Personnel sampling of gas station attend...,HANF,623,,0.058,HEX78_83,,042481_001,,BZ,ppm,2
2,benzene,7/30/1981,1100,06200:Personnel sampling of gas station attend...,HANF,623,,0.125,HEX78_83,,073081_001,,BZ,ppm,3
3,benzene,7/31/1981,1100,06200:Personnel and breathing zone samples of ...,HANF,623,,0.03,HEX78_83,,073181_001,,BZ,ppm,4
4,formaldehyde,9/1/1981,1100,06476:Impinger samples collected between 1100 ...,HANF,HO5,,0.195,HEX78_83,,090181_001,,BZ,ppm,5


In [45]:
# Reading in death cause codes (ICD-8)
icd = pd.read_csv("/content/icd8.csv")
icd = cleaner.snake_case_columns(icd)
icd = cleaner.trim_whitespace(icd)

# Case control data from CEDR
case_control = pd.read_csv("/content/wing-case-control-DOE.csv")
case_control = cleaner.snake_case_columns(case_control)#.head()

# Lookup underlying cause of death
case_control = case_control.merge(icd[['code', 'disease']], left_on="ca81", right_on="code", how="left")
case_control.head()

Unnamed: 0,ca81,ca82,cancerhx,case,caseid,caseset,id,clas_chg,deny_job,dlo,...,sex,sitein,siteout,smoking,srp,vs,zia,seq_no,code,disease
0,,,0,0,83584,67,22524,9,0,01JUL1990,...,F,15MAY1945,15AUG1945,9,.,A,.,1,,
1,199.0,,0,0,66894,5,40324,9,0,01JUL1968,...,M,15APR1944,15SEP1945,9,.,DC,.,2,,
2,,,0,1,40577,38,40577,9,0,01JUL1971,...,M,15AUG1946,15OCT1948,1,.,DC,.,3,,
3,,,0,0,42315,29,42002,0,0,01JUL1970,...,M,15FEB1947,15APR1947,1,.,DC,.,4,,
4,,,0,1,42080,7,42080,0,1,01JUL1970,...,M,15JUN1964,15SEP1964,1,.,DC,.,5,,


### Sites

We have 4 sites in the *Wing et al., 2000* study. These are assigned facility codes within CEDR, which we'll preserve here as the index for our site nodes. Those codes are listed below:


- Los Alamos National Laboratory (LANL): 52
- Hanford Site (HANF): 91
- Savannah River Site (SRS): 22
- Oak Ridge National Laboratory (ORNL): 7



In [46]:
# Construct dataframe of sites in our study
sites_df = pd.DataFrame(site_mapping).T
sites_df = sites_df.reset_index()
sites_df['site_id'] = sites_df.index.to_numpy(dtype=int)

# Create tensor representing sites - using codes as our index
sites = torch.tensor(data=sites_df.index.unique().tolist(), dtype=torch.long)
sites_df.head()

Unnamed: 0,index,code,filename_prefix,site_code,site_id
0,Hanford Site,HANF,Hanford,91,0
1,Los Alamos National Laboratory,LANL,LANL,52,1
2,Savannah River Site,SRS,SRS,22,2
3,Oak Ridge National Laboratory,ORNL,ORNL,7,3


### Buidings

- Map `facility` to a site code

In [47]:
buildings['facility'] = buildings['facility'].str.replace("SRP", "SRS").str.replace("ZIA", "LANL")
buildings = cleaner.assign_id_column(buildings, "building")
buildings['date'] = pd.to_datetime(buildings['date']).astype(int)

# Lookup site number for each building
buildings = buildings.merge(sites_df[['code', "site_id"]],
                            how="left",
                            left_on="facility", right_on="code").drop("code", axis=1)
buildings.head()

  buildings['date'] = pd.to_datetime(buildings['date']).astype(int)


Unnamed: 0,begdate,date,facility,location,locinfo,referenc,seq_no,building_id,site_id
0,,189216000000000000,HANF,101-H,Deactivated 100-H Area bldg,H186,1,0,0
1,,189216000000000000,HANF,102-H,Deactivated 100-H Area bldg,H186,2,1,0
2,,189216000000000000,HANF,103-B,Fuel element storage,H186,3,2,0
3,,189216000000000000,HANF,103-D,Fuel element storage,H186,4,3,0
4,,189216000000000000,HANF,103-F,Storage (100-F),H186,5,4,0


## Industrial Hygiene data

This data helps us to identify the *presence* of given chemical agents within a site (and in some instances, the specific building/location) at that site.

In [48]:
# Replace old site/facility codes to match standard values
ih_data['facility'] = ih_data['facility'].str.replace("SRP", "SRS").str.replace("ZIA", "LANL")
ih_data['agent'] = ih_data['agent'].str.lower()

# Lookup site number for each agent
ih_data = ih_data.merge(sites_df[['code', "site_code", "site_id"]],
                            how="left",
                            left_on="facility", right_on="code").drop("code", axis=1)

ih_data.head()

Unnamed: 0,agent,date,dep_grp,descript,facility,jobtitle,location,quantity,referenc,room,sampleid,site,type,uom,seq_no,site_code,site_id
0,mercury,12/22/1981,100N,"06686:Mercury survey lift station; n=1, airbor...",HANF,,,0.01,HEX78_83,,122281_001,,BZ,mg/m3,1,91,0
1,benzene,4/24/1981,1100,06200:Personnel sampling of gas station attend...,HANF,623,,0.058,HEX78_83,,042481_001,,BZ,ppm,2,91,0
2,benzene,7/30/1981,1100,06200:Personnel sampling of gas station attend...,HANF,623,,0.125,HEX78_83,,073081_001,,BZ,ppm,3,91,0
3,benzene,7/31/1981,1100,06200:Personnel and breathing zone samples of ...,HANF,623,,0.03,HEX78_83,,073181_001,,BZ,ppm,4,91,0
4,formaldehyde,9/1/1981,1100,06476:Impinger samples collected between 1100 ...,HANF,HO5,,0.195,HEX78_83,,090181_001,,BZ,ppm,5,91,0


#### Chemical Agent Node Features

We need to encode our chemical agent features to numeric values for model training downstream. We can leverage `scikit-learn`'s ColumnTransformer with a `OneHotEncoder` for each feature column in our IH data chemical agent columns. This should result in an array of shape (`n_agnets`, `n_encoded_features`).

In [49]:
# First, construct chemical agent node features from IH data
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

agent_cols = ["agent",
              "date",
              "dep_grp",
              "descript",
              "facility",
              "jobtitle",
              "location",
              "quantity",
              "room",
              "type",
              "uom"]
ct = ColumnTransformer(
        [("ohe", OneHotEncoder(), agent_cols)])

chemical_features = ct.fit_transform(ih_data)
chemical_features = chemical_features.toarray()

# Print out dimensions of agent node features
chemical_features.shape

(3397, 6172)

### IH Data Edge Indices
There's three edge indices for chemical agents that we care about

1. Agent Presence at a **Site** (inferred for those with null `location` values)
2. Agent Presence at a **Building** (via `location` column)
3. \*Employee exposure to an agent (this is handled in **Exposure Records**)

#### Agent <> Site Edge Index

In [50]:
# Now construct Agent => Site Index based on columns
ih_data["seq_no"] = ih_data["seq_no"].astype(int)
ih_data["site_id"] = ih_data["site_id"].astype(int)
agent_site_edge_index = torch.tensor(ih_data[["seq_no", "site_id"]].values.transpose().tolist(), dtype=torch.long)

#### Agent <> Building Edge Index

This may be a bit more sparse, as precise location information is null/not present for most agents listed in out IH data

In [52]:
# Lookup building ID
ih_data = ih_data.merge(buildings[['location', 'building_id']], how='left', on="location")
ih_data.head()

In [57]:
# Let's also
agents_with_buildings = ih_data.loc[~pd.isna(ih_data.building_id)]

agent_building_list = agents_with_buildings[["seq_no", "building_id"]].values.transpose().tolist()
agent_building_edge_index = torch.tensor(agent_building_list, dtype=torch.long)
agent_building_edge_index

tensor([[  46,   59,  154,  ...,  742,  742,  742],
        [ 254,  263,  402,  ..., 1516, 1517, 1518]])

### Employee Data

Lastly, we'll need to deal with our *Employee* node types. Similar to above, we can encode our employee features into numeric values for our employee feature array passed to our graph



In [13]:
# Wrangle case control as needed
employee_df = case_control

# The employee IDs from CEDR need to be remapped from 1...N where N = num employees
employee_df['employee_id'] = employee_df.index.values

employee_df.head()

Unnamed: 0,ca81,ca82,cancerhx,case,caseid,caseset,id,clas_chg,deny_job,dlo,...,sitein,siteout,smoking,srp,vs,zia,seq_no,code,disease,employee_id
0,,,0,0,83584,67,22524,9,0,01JUL1990,...,15MAY1945,15AUG1945,9,.,A,.,1,,,0
1,199.0,,0,0,66894,5,40324,9,0,01JUL1968,...,15APR1944,15SEP1945,9,.,DC,.,2,,,1
2,,,0,1,40577,38,40577,9,0,01JUL1971,...,15AUG1946,15OCT1948,1,.,DC,.,3,,,2
3,,,0,0,42315,29,42002,0,0,01JUL1970,...,15FEB1947,15APR1947,1,.,DC,.,4,,,3
4,,,0,1,42080,7,42080,0,1,01JUL1970,...,15JUN1964,15SEP1964,1,.,DC,.,5,,,4


In [74]:
# Create array for employee outcomes
employee_df['disease'] = employee_df['disease'].fillna(0)

employee_outcomes = employee_df['disease'].values
employee_outcomes.shape

(489,)

In [14]:
# First, construct employee node features
employee_cols = ['ca81', 'ca82', 'cancerhx', 'case', 'caseid', 'caseset', 'id',
       'clas_chg', 'deny_job', 'dlo', 'dob', 'doefac', 'educ_max', 'endstudy',
       'fac1', 'fac2', 'fac3', 'fac_ind', 'farm', 'han', 'hfg_ind', 'hire',
       'icda8', 'ind_age', 'ind_date', 'lanl', 'military', 'nir_exp',
       'nuc_irad', 'nummmfac', 'occ_pnt', 'oh_rec', 'ornl', 'oth_irad',
       'pr_nfac1', 'pr_nfac2', 'pr_nucl', 'pr_rec', 'race', 'rad_rx',
       'sec_spec', 'sec_term', 'sex', 'sitein', 'siteout', 'smoking', 'srp',
       'vs', 'zia', 'code']
ct = ColumnTransformer(
        [("ohe", OneHotEncoder(), employee_cols)])

# Encode columns
employee_features = ct.fit_transform(employee_df)
employee_features = employee_features.toarray()

# Print out dimensions of agent node features
employee_features.shape

(489, 2453)

In [93]:
# Create an employee training mask
np.random.seed(12345)
N = employee_features.shape[0]
employee_training_mask = np.random.choice(a=[0, 1], size=(N))
employee_training_mask.shape

(489,)

#### Employee <> Site Edge Index
We'll need to construct an edge index between an employee and the given site that they worked at. This [CEDR dataset](https://oriseapps.orau.gov/cedr/DataFile.aspx?DataSet=MFMM98A1&DFile=MFMM98A1_4) provides us with info on when/which site given employees worked. For simplicity, we'll be ignoring times worked (though this could be a future edge property)

In [15]:
employment_dates = pd.read_csv("/content/employment-dates.csv")

# Clean up employment dates
employment_dates = cleaner.snake_case_columns(employment_dates)
employment_dates = cleaner.trim_whitespace(employment_dates)

# We don't need termination rows
employment_dates = employment_dates.loc[employment_dates.type != "T"]


employment_dates.head()

Unnamed: 0,id,facility,strtdate,type,seq_no
0,22524,7,15MAY1945,H,1
2,40324,91,15APR1944,H,3
4,40577,7,15AUG1946,H,5
6,42002,7,15FEB1947,H,7
8,42080,7,15JUN1964,H,9


In [16]:
# Lookup site index for employment dates
employment_dates['facility'] = employment_dates['facility'].astype(int)
employment_dates = employment_dates.merge(sites_df[['code', "site_code", "site_id"]],
                            how="left",
                            left_on="facility", right_on="site_code").drop("code", axis=1)

# Lookup updated Employee IDs (from index number above)
employment_dates = employment_dates.merge(employee_df[['id', 'employee_id']], how="left", on="id")

employment_dates.tail()

Unnamed: 0,id,facility,strtdate,type,seq_no,site_code,site_index,employee_id
843,407616,52,15MAY1946,H,1521,52,1,484
844,407666,52,15APR1946,H,1523,52,1,485
845,407819,52,15OCT1948,H,1525,52,1,486
846,409326,52,15DEC1946,H,1527,52,1,487
847,409353,52,15OCT1948,H,1529,52,1,488


In [17]:
# Create edge index linking employees to a given site
employee_site_edge_index = torch.tensor(employment_dates[['employee_id', 'site_id']].values.transpose().tolist(), dtype=torch.long)
employee_site_edge_index

tensor([[  0,   1,   2,  ..., 486, 487, 488],
        [  3,   0,   3,  ...,   1,   1,   1]])

#### Employee <> Agent Site Index

###### Exposure Records
[This CEDR Dataset](https://oriseapps.orau.gov/cedr/VariableFile.aspx?intCell=3&intText=AROM_C) from *Wing et al., 2000* indicates qualitative assessments of exposure. This can be thought of as an adjacency matrix between our employee nodes and our chemical agent nodes. As a simplifying assumption, we'll be ignoring the temporal aspect of this dataset, and assume the presence of chemical agents at a given site *could* result in the exposure to an employee.

In [18]:
exposure = pd.read_csv("/content/wing-exposure-records.csv")


# Clean up exposure data
exposure = cleaner.snake_case_columns(exposure)
exposure = cleaner.trim_whitespace(exposure)
exposure['id'] = exposure.id.astype(int)

# Clean up column names: we only need yes/no exposures
cols_to_drop = [c for c in exposure.columns if "_c" in c]
exposure = exposure.drop(cols_to_drop, axis=1)
exposure.columns = [c.replace("_j", "") for c in exposure.columns]

exposure.head()

Unnamed: 0,id,anymtl,arom,asb,be,cd,elf,hal,hg,micro,ni,othmtl,pb,stat,ur,weld,seq_no
0,22524,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
1,40324,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,2
2,40577,1,1,0,1,1,0,1,0,0,1,1,1,0,1,1,3
3,42002,1,1,1,1,1,1,1,1,0,1,1,1,0,0,0,4
4,42080,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,5


In [19]:
# Read in exposure dataset data dictionary
exposure_data = pd.read_csv("/content/exposure_data_dict.csv")
exposure_data = cleaner.snake_case_columns(exposure_data)
exposure_data = cleaner.trim_whitespace(exposure_data)


# Use only judgement columns, which qualitatively tell whether someone was exposed
exposure_data = exposure_data.loc[exposure_data.varid.str.contains("_J")]
exposure_data['varid'] = exposure_data['varid'].str.lower().str.replace("_j", "")
exposure_data['agent'] = exposure_data['variablename'].str.lower().str.replace(" judgement", "")
exposure_agent_map = exposure_data[['agent', 'varid']]

exposure_agent_map.head()

Unnamed: 0,agent,varid
1,any metal,anymtl
3,aromatic hydrocarbons,arom
5,asbestos,asb
7,beryllium,be
9,cadmium,cd


In [20]:
# lookup agent IDs per encounter
agents['agent_id'] = agents.index.values
exposure_agent_map = exposure_agent_map.merge(agents[["agent", 'agent_id']],
                                              how="left",
                                              left_on="agent", right_on="agent")
# exposure_agent_map['agent_id'] = exposure_agent_map['agent_id'].astype(int)
exposure_agent_map.head()

Unnamed: 0,agent,varid,agent_id
0,any metal,anymtl,
1,aromatic hydrocarbons,arom,
2,asbestos,asb,25.0
3,asbestos,asb,44.0
4,asbestos,asb,204.0


In [21]:
# TODO(abowen) - pivot exposure data
exposure_list = exposure.drop("seq_no", axis=1).melt("id")

exposure_list = exposure_list.merge(exposure_agent_map,
                                    how="inner",
                                    left_on="variable",
                                    right_on="varid")

exposure_list = exposure_list.loc[~exposure_list['agent_id'].isna()]


# Lookup employeed ID (from case control)
exposure_list = exposure_list.merge(employee_df[['id', 'employee_id']], how="left", on="id")

exposure_list.head()

Unnamed: 0,id,variable,value,agent,varid,agent_id,employee_id
0,22524,asb,0,asbestos,asb,25.0,0
1,22524,asb,0,asbestos,asb,44.0,0
2,22524,asb,0,asbestos,asb,204.0,0
3,40324,asb,1,asbestos,asb,25.0,1
4,40324,asb,1,asbestos,asb,44.0,1


In [22]:
# Generate edge index between employees and chemicals
employee_agent_edge_index = torch.tensor(exposure_list[['employee_id', 'agent_id']].values.transpose().tolist(), dtype=torch.long)
employee_agent_edge_index

tensor([[  0,   0,   0,  ..., 488, 488, 488],
        [ 25,  44, 204,  ..., 381,  61, 111]])

# Constructing a Graph Dataset

In [111]:
data = HeteroData()

#### Node Features

We'll add note feature arrays for each node type in our graph:

- Site
- Building
- Chemical Agent
- Employee

In [112]:
# Print out number of nodes and edges
print(f"Number of buildings: {len(buildings)}")
print(f"Number of unique chemical agents: {len(ih_data)}")
print(f"Number of Unique Employees in dataset: {len(case_control['id'].unique())}")

Number of buildings: 1562
Number of unique chemical agents: 4375
Number of Unique Employees in dataset: 487


In [113]:
# Add note features for sites
data['site'].x = sites

# Node features for buildings: TODO(abowen): encode more builidng features maybe?
building_node_features = torch.tensor(buildings[['building_id', 'date']].values.tolist(), dtype=torch.long)
data['building'].x = building_node_features

# Node features for Agents
data['chemical'].x = torch.tensor(chemical_features, dtype=torch.long)

# Node Features for employees
data['employee'].x = torch.tensor(employee_features)
data['employee'].y = torch.tensor(employee_outcomes)
data['employee'].train_mask = torch.tensor(employee_training_mask)


#### Edge Indexes

Now we'll add edge indexes (from the foreign key relationships in our source tables):


- Building <> Site
- Chemical Agent <> Site
- Chemical Agent <> Building
- Employee <> Site

In [114]:
# Now construct the edge index for each type of link

# Builidng <> Site
building_site_edge_index  = torch.tensor(buildings[['building_id', 'site_id']].values.transpose().tolist(), dtype=torch.long)
data['building', 'located_at', 'site'].edge_index = building_site_edge_index


# Chemical Agent <> Site
data['chemical', 'located_at', 'site'].edge_index = agent_site_edge_index

# Chemical Agent <> Building
# This one may be more sparse as we don't know individual buildings for all agents (no `location` provided)
data['chemical', 'contained_in', 'building'].edge_index = agent_building_edge_index

# Employee <> Site
data['employee', 'employed_at', 'site'].edge_index = employee_site_edge_index

# Employee <> Chemical Agent
data['employee', 'exposed_to', 'chemical'].edge_index = employee_agent_edge_index


In [None]:
data

Convert our graph dataset to undirected type.

In [None]:
# Convert to undirected, this adds reverse edges
data = T.ToUndirected()(data)
data

In [115]:
# Ensure our graph dataset is formatted correctly
data.validate()

True

In [116]:

'''
graph = to_networkx(data, to_undirected=False)

node_colors = {'building': ('blue',  5),
               'site': ('green', 20),
               'agent': ("red", 5),
               "employee": ("yellow", 5)}
f, ax = plt.subplots(figsize=(10, 10))
# Draw nodes of buildings and sites
for node_type, node_setting in node_colors.items():
    nodes = [node for node, data in graph.nodes(data=True) if data.get('type') == node_type]
    nx.draw_networkx_nodes(graph, pos=nx.spring_layout(graph),
                           nodelist=nodes, node_color=node_setting[0],
                       node_size=node_setting[1], label=None)
# Draw edges relating bulidings to their site
edges = data['building', 'located_at', 'site'].edge_index.t().tolist()

for building, site in edges:
  graph.add_edge(building, site)
# graph.add_edges_from(edges, type=('building', 'located_at', 'site'))
# nx.draw_networkx_edges(graph, pos=nx.spring_layout(graph))

plt.show()

'''

'\ngraph = to_networkx(data, to_undirected=False)\n\nnode_colors = {\'building\': (\'blue\',  5),\n               \'site\': (\'green\', 20),\n               \'agent\': ("red", 5),\n               "employee": ("yellow", 5)}\nf, ax = plt.subplots(figsize=(10, 10))\n# Draw nodes of buildings and sites\nfor node_type, node_setting in node_colors.items():\n    nodes = [node for node, data in graph.nodes(data=True) if data.get(\'type\') == node_type]\n    nx.draw_networkx_nodes(graph, pos=nx.spring_layout(graph),\n                           nodelist=nodes, node_color=node_setting[0],\n                       node_size=node_setting[1], label=None)\n# Draw edges relating bulidings to their site\nedges = data[\'building\', \'located_at\', \'site\'].edge_index.t().tolist()\n\nfor building, site in edges:\n  graph.add_edge(building, site)\n# graph.add_edges_from(edges, type=(\'building\', \'located_at\', \'site\'))\n# nx.draw_networkx_edges(graph, pos=nx.spring_layout(graph))\n\nplt.show()\n\n'

In [117]:
data

HeteroData(
  site={ x=[4] },
  building={ x=[1562, 2] },
  chemical={ x=[3397, 6172] },
  employee={
    x=[489, 2453],
    y=[489],
    train_mask=[489],
  },
  (building, located_at, site)={ edge_index=[2, 1562] },
  (chemical, located_at, site)={ edge_index=[2, 3397] },
  (chemical, contained_in, building)={ edge_index=[2, 2335] },
  (employee, employed_at, site)={ edge_index=[2, 848] },
  (employee, exposed_to, chemical)={ edge_index=[2, 32274] }
)

# Saving our Graph Dataset

We'll write our output dataset to be downloaded and used in our GNN training module. PyTorch provides a [nice interface](https://pytorch.org/docs/stable/generated/torch.save.html) in which we can do this

In [118]:
# Save graph dataset so it can be downloaded
torch.save(data, "/content/graph-dataset-DOE.pt")

## Model Training

We're now ready to [train a heteregeneous graph neural network against](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/heterogeneous.html#trainfunc) this data

In [122]:

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero




class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


model = GNN(hidden_channels=64, out_channels=64)
#with torch.no_grad():
#model = to_hetero(model, data.metadata(), aggr='sum')

In [123]:
data

HeteroData(
  site={ x=[4] },
  building={ x=[1562, 2] },
  chemical={ x=[3397, 6172] },
  employee={
    x=[489, 2453],
    y=[489],
    train_mask=[489],
  },
  (building, located_at, site)={ edge_index=[2, 1562] },
  (chemical, located_at, site)={ edge_index=[2, 3397] },
  (chemical, contained_in, building)={ edge_index=[2, 2335] },
  (employee, employed_at, site)={ edge_index=[2, 848] },
  (employee, exposed_to, chemical)={ edge_index=[2, 32274] }
)

In [124]:

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)

    ## TODO(abowen): add training mask for employee nodes
    mask = data['employee'].train_mask
    loss = F.cross_entropy(out['employee'][mask], data['employee'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)

In [125]:
# Train model with protocol above
losses = []
for epoch in range(1, 10):
    loss = train()
    losses.append(loss)

print(losses)

ValueError: `MessagePassing.propagate` only supports integer tensors of shape `[2, num_messages]`, `torch_sparse.SparseTensor` or `torch.sparse.Tensor` for argument `edge_index`.

In [None]:
data