### **Step 1: create two subgraphs (opioid users & recovered users) & joined graph - no nutrition tags yet**

In [28]:
import torch
from torch_geometric.data import HeteroData
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from tqdm import tqdm
from utils import *
from transformers import BertTokenizer, BertModel
import random
import pickle
import math
import warnings

warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
base_path = '../processed_data'

In [2]:
# Configuration flags
PCA_FLAG = False
LESS_IMBALANCE_FLAG = False
POSITIVE_RATIO = 1
ADD_MEDICAL = False
LLM_FLAG = False
SAMPLING_FLAG = False
SAMPLING_SIZE = 100

In [3]:
from transformers import BertTokenizer, BertModel, LlamaTokenizer, LlamaModel
    # Tokenize input and get output from BERT model
if not LLM_FLAG:
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
else:
    tokenizer = LlamaTokenizer.from_pretrained('../llama-2-7b')
    model = LlamaModel.from_pretrained('../llama-2-7b')

In [4]:
def get_bert_embedding(sentence):
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=512, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

# TODO: get_llama_embeddings(sentence)

In [5]:
food_ingredients_df = pd.read_csv('../processed_data/food_ingredients.csv', dtype=str)
user_food_df = pd.read_csv('../processed_data/user_food.csv', dtype=str)
df_demo = pd.read_csv('../processed_data/main_table.csv', dtype=str)
user_habit_df = pd.read_csv('../processed_data/habit.csv', dtype=str)
user_medicine_df = pd.read_csv('../processed_data/user_prescription_medicine.csv', dtype=str)

In [6]:
df_demo = df_demo.fillna(0)
df_demo['label'] = df_demo['label'].astype(float).astype(int).astype(str)

In [7]:
counts = df_demo['label'].value_counts()
print(counts)

label
0    92723
1     2728
2      421
Name: count, dtype: int64


In [7]:
random.seed(42)
if LESS_IMBALANCE_FLAG:
    user_food_df_pos = user_food_df.loc[user_food_df['SEQN'].isin(df_demo.loc[df_demo['label'] == '1']['SEQN'].unique())]
    if SAMPLING_FLAG:
        unique_pos_SEQN = random.sample(user_food_df_pos['SEQN'].unique().tolist(), k=SAMPLING_SIZE)
        user_food_df_pos = user_food_df.loc[user_food_df['SEQN'].isin(unique_pos_SEQN)]

    df_demo_pos = df_demo.loc[df_demo['SEQN'].isin(user_food_df_pos['SEQN'].unique())]

    user_food_df_neg = user_food_df.loc[user_food_df['SEQN'].isin(df_demo.loc[df_demo['label'] == '0']['SEQN'].unique())]
    unique_neg_SEQN = random.sample(user_food_df_neg['SEQN'].unique().tolist(), k=POSITIVE_RATIO*len(df_demo_pos))

    user_food_df_neg = user_food_df.loc[user_food_df['SEQN'].isin(unique_neg_SEQN)]
    user_food_df = pd.concat([user_food_df_pos, user_food_df_neg])
    df_demo = df_demo.loc[df_demo['SEQN'].isin(user_food_df['SEQN'].unique())]

In [8]:
user_habit_df = user_habit_df.rename(columns={'habitID': 'habit_id', 'habitDesc': 'habit_desc'})
user_habit_df = user_habit_df.loc[user_habit_df['SEQN'].isin(user_food_df['SEQN'].unique())]
user_medicine_df = user_medicine_df.loc[user_medicine_df['SEQN'].isin(user_food_df['SEQN'].unique())]

food_ingredients_df = food_ingredients_df.loc[food_ingredients_df['food_id'].isin(user_food_df['food_id'].unique())]

In [9]:
# ID formatting
food_ingredients_df['WWEIA_id'] = food_ingredients_df['WWEIA_id'].str.zfill(4)
food_ingredients_df['ingredient_id'] = food_ingredients_df['ingredient_id'].str.zfill(8)
food_ingredients_df['food_id'] = food_ingredients_df['food_id'].str.zfill(10)
user_food_df['food_id'] = user_food_df['food_id'].str.zfill(10)
user_habit_df['habit_id'] = user_habit_df['habit_id'].str.zfill(2)

In [10]:
def create_graph_for_group(df_demo, user_food_df, food_ingredients_df, user_habit_df, user_medicine_df, group_label):
    # Filter data for the specific group
    df_demo_group = df_demo[df_demo['label'] == str(group_label)]
    
    # Filter users who have food records
    users_with_food = user_food_df['SEQN'].unique()
    df_demo_group = df_demo_group[df_demo_group['SEQN'].isin(users_with_food)]
    
    user_food_df_group = user_food_df[user_food_df['SEQN'].isin(df_demo_group['SEQN'])]
    user_habit_df_group = user_habit_df[user_habit_df['SEQN'].isin(df_demo_group['SEQN'])]
    user_medicine_df_group = user_medicine_df[user_medicine_df['SEQN'].isin(df_demo_group['SEQN'])]

    # Create graph
    graph = HeteroData()

    # Add nodes
    unique_user_ids = df_demo_group['SEQN'].unique()
    
    # Ensure food_ids exist in both user_food_df and food_ingredients_df
    unique_food_ids = np.array(list(set(user_food_df_group['food_id']) & set(food_ingredients_df['food_id'])))
    
    # Filter food_ingredients_df to only include the foods present in unique_food_ids
    food_ingredients_df = food_ingredients_df[food_ingredients_df['food_id'].isin(unique_food_ids)]
    
    unique_ingredient_ids = food_ingredients_df['ingredient_id'].unique()
    unique_wweia_ids = food_ingredients_df['WWEIA_id'].unique()
    unique_habit_ids = user_habit_df_group['habit_id'].unique()
    unique_medicine_ids = user_medicine_df_group['RXDDRGID'].unique()

    graph['user']['node_id'] = unique_user_ids
    graph['food']['node_id'] = unique_food_ids
    graph['ingredient']['node_id'] = unique_ingredient_ids
    graph['category']['node_id'] = unique_wweia_ids
    graph['habit']['node_id'] = unique_habit_ids
    graph['medicine']['node_id'] = unique_medicine_ids

    # Create mappings
    user_to_int = {user_id: i for i, user_id in enumerate(unique_user_ids)}
    food_to_int = {food_id: i for i, food_id in enumerate(unique_food_ids)}
    ingredient_to_int = {ingredient_id: i for i, ingredient_id in enumerate(unique_ingredient_ids)}
    wweia_to_int = {wweia_id: i for i, wweia_id in enumerate(unique_wweia_ids)}
    habit_to_int = {habit_id: i for i, habit_id in enumerate(unique_habit_ids)}
    medicine_to_int = {medicine_id: i for i, medicine_id in enumerate(unique_medicine_ids)}

    # Create edges
    graph['food', 'contains', 'ingredient'].edge_index = torch.tensor(
        [[food_to_int[food_id], ingredient_to_int[ingredient_id]] for food_id, ingredient_id in zip(food_ingredients_df['food_id'], food_ingredients_df['ingredient_id'])],
        dtype=torch.long
    ).t().contiguous()

    graph['food', 'belongs_to', 'category'].edge_index = torch.tensor(
        [[food_to_int[food_id], wweia_to_int[category_id]] for food_id, category_id in zip(food_ingredients_df['food_id'], food_ingredients_df['WWEIA_id'])],
        dtype=torch.long
    ).t().contiguous()

    graph['user', 'eats', 'food'].edge_index = torch.tensor(
        [[user_to_int[user_id], food_to_int[food_id]] for user_id, food_id in zip(user_food_df_group['SEQN'], user_food_df_group['food_id']) if food_id in food_to_int],
        dtype=torch.long
    ).t().contiguous()

    graph['user', 'has', 'habit'].edge_index = torch.tensor(
        [[user_to_int[user_id], habit_to_int[habit_id]] for user_id, habit_id in zip(user_habit_df_group['SEQN'], user_habit_df_group['habit_id'])],
        dtype=torch.long
    ).t().contiguous()

    graph['user', 'takes', 'medicine'].edge_index = torch.tensor(
        [[user_to_int[user_id], medicine_to_int[medicine_id]] for user_id, medicine_id in zip(user_medicine_df_group['SEQN'], user_medicine_df_group['RXDDRGID'])],
        dtype=torch.long
    ).t().contiguous()

    # Add node features
    # Preprocess the data to ensure all features are numeric
    feature_columns = df_demo_group.columns.drop(['SEQN', 'label'])
    
    # Convert categorical variables to one-hot encoding
    df_features = pd.get_dummies(df_demo_group[feature_columns], columns=feature_columns)
    
    # Ensure all columns are numeric
    df_features = df_features.apply(pd.to_numeric, errors='coerce')
    
    # Fill NaN values with 0 (or another appropriate value)
    df_features = df_features.fillna(0)

    graph['user'].x = torch.tensor(df_features.values, dtype=torch.float)

    # Add features for other node types (using BERT embeddings)
    for node_type, id_col, desc_col, df in [
        ('food', 'food_id', 'food_desc', food_ingredients_df),
        ('ingredient', 'ingredient_id', 'ingredient_desc', food_ingredients_df),
        ('category', 'WWEIA_id', 'WWEIA_desc', food_ingredients_df),
        ('habit', 'habit_id', 'habit_desc', user_habit_df),
        ('medicine', 'RXDDRGID', 'RXDDRUG', user_medicine_df)
    ]:
        features = []
        for node_id in graph[node_type]['node_id']:
            matching_rows = df[df[id_col] == node_id]
            if not matching_rows.empty:
                desc = matching_rows[desc_col].iloc[0]
                features.append(get_bert_embedding(desc))
            else:
                print(f"Warning: No description found for {node_type} with {id_col}={node_id}. Skipping this node.")
        
        if features:
            graph[node_type].x = torch.tensor(np.array(features), dtype=torch.float)
        else:
            print(f"Warning: No features found for {node_type}. This node type will not have features.")

    return graph, user_to_int, food_to_int, ingredient_to_int, wweia_to_int, habit_to_int, medicine_to_int

In [22]:
import math

def join_graphs(graph_active, graph_recovered, mappings_active, mappings_recovered):
    joined_graph = HeteroData()

    # 1. Combine node IDs and features for each node type
    for node_type in graph_active.node_types:
        active_node_id = graph_active[node_type]['node_id']
        recovered_node_id = graph_recovered[node_type]['node_id']
        
        # Convert node IDs to list if needed
        if isinstance(active_node_id, torch.Tensor):
            active_node_id = active_node_id.tolist()
        if isinstance(recovered_node_id, torch.Tensor):
            recovered_node_id = recovered_node_id.tolist()

        # Filter out NaN or invalid node IDs before conversion
        active_node_id = [node_id for node_id in active_node_id if not (isinstance(node_id, float) and math.isnan(node_id))]
        recovered_node_id = [node_id for node_id in recovered_node_id if not (isinstance(node_id, float) and math.isnan(node_id))]

        # Convert node IDs to integers if they are valid
        try:
            active_node_id = [int(node_id) for node_id in active_node_id]
            recovered_node_id = [int(node_id) for node_id in recovered_node_id]
        except ValueError:
            raise TypeError(f"Node IDs must be convertible to integers, but found non-integer values in {node_type}")
        
        # Add an offset to recovered_node_id to avoid overlaps
        offset = max(active_node_id) + 1 if active_node_id else 0
        recovered_node_id = [node_id + offset for node_id in recovered_node_id]

        # Combine the node IDs
        joined_graph[node_type]['node_id'] = torch.tensor(active_node_id + recovered_node_id)
        
        # If the node type has 'x' features, concatenate them
        if 'x' in graph_active[node_type]:
            active_x = graph_active[node_type].x
            recovered_x = graph_recovered[node_type].x
            
            # Ensure the feature dimensions match, padding if necessary
            if active_x.size(1) != recovered_x.size(1):
                max_dim = max(active_x.size(1), recovered_x.size(1))
                active_x = pad_features(active_x, max_dim)
                recovered_x = pad_features(recovered_x, max_dim)
            
            # Concatenate the features
            joined_graph[node_type].x = torch.cat([active_x, recovered_x])

    # 2. Combine edge indices for each edge type
    for edge_type in graph_active.edge_types:
        src, rel, dst = edge_type
        
        # Get the number of nodes in the source and destination node types (for offsetting)
        offset = {node_type: len(graph_active[node_type]['node_id']) for node_type in graph_active.node_types}
        
        # Extract edge indices
        edge_index_active = graph_active[edge_type].edge_index
        edge_index_recovered = graph_recovered[edge_type].edge_index
        
        # Adjust the edge indices for the recovered graph by adding offsets
        edge_index_recovered[0] += offset[src]  # Offset source node indices
        edge_index_recovered[1] += offset[dst]  # Offset destination node indices if needed
        
        # Concatenate the edges
        joined_graph[edge_type].edge_index = torch.cat([edge_index_active, edge_index_recovered], dim=1)

    # 3. Combine mappings
    joined_mappings = {}
    for node_type in mappings_active.keys():
        max_id = max(mappings_active[node_type].values()) if mappings_active[node_type] else 0  # Get the max ID from the active mappings
        # Shift recovered mappings by the max active ID to avoid ID conflicts
        joined_mappings[node_type] = {**mappings_active[node_type],
                                      **{k: v + max_id + 1 for k, v in mappings_recovered[node_type].items()}}

    return joined_graph, joined_mappings

def pad_features(features, max_dim):
    # pad feature vectors to the same dimension
    padded = torch.zeros((features.size(0), max_dim))
    padded[:, :features.size(1)] = features
    return padded


In [18]:
graph_active, user_to_int_active, food_to_int_active, ingredient_to_int_active, wweia_to_int_active, habit_to_int_active, medicine_to_int_active = create_graph_for_group(df_demo, user_food_df, food_ingredients_df, user_habit_df, user_medicine_df, group_label=1)



In [19]:
mappings_active = {
    'user': user_to_int_active,
    'food': food_to_int_active,
    'ingredient': ingredient_to_int_active,
    'wweia': wweia_to_int_active,
    'habit': habit_to_int_active,
    'medicine': medicine_to_int_active
}

In [20]:
print(graph_active)

HeteroData(
  user={
    node_id=[2413],
    x=[2413, 4022],
  },
  food={
    node_id=[3810],
    x=[3810, 768],
  },
  ingredient={
    node_id=[2357],
    x=[2357, 768],
  },
  category={
    node_id=[164],
    x=[164, 768],
  },
  habit={
    node_id=[54],
    x=[54, 768],
  },
  medicine={
    node_id=[668],
    x=[667, 768],
  },
  (food, contains, ingredient)={ edge_index=[2, 14258] },
  (food, belongs_to, category)={ edge_index=[2, 14258] },
  (user, eats, food)={ edge_index=[2, 63955] },
  (user, has, habit)={ edge_index=[2, 25155] },
  (user, takes, medicine)={ edge_index=[2, 14836] }
)


In [21]:
print(graph_active.metadata())

(['user', 'food', 'ingredient', 'category', 'habit', 'medicine'], [('food', 'contains', 'ingredient'), ('food', 'belongs_to', 'category'), ('user', 'eats', 'food'), ('user', 'has', 'habit'), ('user', 'takes', 'medicine')])


In [22]:
graph_recovered, user_to_int_recovered, food_to_int_recovered, ingredient_to_int_recovered, wweia_to_int_recovered, habit_to_int_recovered, medicine_to_int_recovered = create_graph_for_group(df_demo, user_food_df, food_ingredients_df, user_habit_df, user_medicine_df, group_label=2)



In [28]:
mappings_recovered = {
    'user': user_to_int_recovered,
    'food': food_to_int_recovered,
    'ingredient': ingredient_to_int_recovered,
    'wweia': wweia_to_int_recovered,
    'habit': habit_to_int_recovered,
    'medicine': medicine_to_int_recovered
}

In [29]:
print(graph_recovered)

HeteroData(
  user={
    node_id=[406],
    x=[406, 875],
  },
  food={
    node_id=[1957],
    x=[1957, 768],
  },
  ingredient={
    node_id=[1641],
    x=[1641, 768],
  },
  category={
    node_id=[161],
    x=[161, 768],
  },
  habit={
    node_id=[51],
    x=[51, 768],
  },
  medicine={
    node_id=[239],
    x=[238, 768],
  },
  (food, contains, ingredient)={ edge_index=[2, 6659] },
  (food, belongs_to, category)={ edge_index=[2, 6659] },
  (user, eats, food)={ edge_index=[2, 11334] },
  (user, has, habit)={ edge_index=[2, 4393] },
  (user, takes, medicine)={ edge_index=[2, 830] }
)


In [30]:
print(graph_recovered.metadata())

(['user', 'food', 'ingredient', 'category', 'habit', 'medicine'], [('food', 'contains', 'ingredient'), ('food', 'belongs_to', 'category'), ('user', 'eats', 'food'), ('user', 'has', 'habit'), ('user', 'takes', 'medicine')])


In [26]:
# Save the two graphs
torch.save(graph_active, '../processed_data/graph_active_users.pt')
torch.save(graph_recovered, '../processed_data/graph_recovered_users.pt')

In [27]:
#  Load the saved graphs
graph_active = torch.load('../processed_data/graph_active_users.pt')
graph_recovered = torch.load('../processed_data/graph_recovered_users.pt')

In [10]:
# Saving graphs and mappings
def save_graphs_and_mappings(graph_active, graph_recovered, mappings_active, mappings_recovered, base_path):
    # Save graphs
    torch.save(graph_active, f'{base_path}/active_users_graph.pt')
    torch.save(graph_recovered, f'{base_path}/recovered_users_graph.pt')
    
    # Save mappings
    with open(f'{base_path}/active_users_mappings.pkl', 'wb') as f:
        pickle.dump(mappings_active, f)
    with open(f'{base_path}/recovered_users_mappings.pkl', 'wb') as f:
        pickle.dump(mappings_recovered, f)

# Loading graphs and mappings
def load_graphs_and_mappings(base_path):
    # Load graphs
    graph_active = torch.load(f'{base_path}/active_users_graph.pt')
    graph_recovered = torch.load(f'{base_path}/recovered_users_graph.pt')
    
    # Load mappings
    with open(f'{base_path}/active_users_mappings.pkl', 'rb') as f:
        mappings_active = pickle.load(f)
    with open(f'{base_path}/recovered_users_mappings.pkl', 'rb') as f:
        mappings_recovered = pickle.load(f)
    
    return graph_active, graph_recovered, mappings_active, mappings_recovered



In [None]:
save_graphs_and_mappings(graph_active, graph_recovered, mappings_active, mappings_recovered, base_path)

In [12]:
graph_active, graph_recovered, mappings_active, mappings_recovered = load_graphs_and_mappings(base_path)

In [16]:
import torch
from torch_geometric.data import HeteroData

def is_numeric(node_id):
    """Check if a node_id can be converted to an integer."""
    try:
        int(node_id)
        return True
    except ValueError:
        return False

def pad_features(features, max_dim):
    """Pad feature vectors to the same dimension."""
    padded = torch.zeros((features.size(0), max_dim))
    padded[:, :features.size(1)] = features
    return padded

def create_reverse_edges(graph, node_type1, rel, node_type2):
    """Create reverse edges for a specific edge type."""
    edge_index = graph[node_type1, rel, node_type2].edge_index
    # Reverse the edge direction
    reversed_edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)
    # Add the reverse edge to the graph with consistent naming using '_by'
    graph[node_type2, f'{rel}_by', node_type1].edge_index = reversed_edge_index
    print(f"Created reverse edge '{node_type2} -> {rel}_by -> {node_type1}'")

def create_metapaths(graph, node_type1, intermediate_node_type, node_type2, forward_rel, reverse_rel):
    """
    Create metapaths in the form: node_type1 -> intermediate_node_type -> node_type2
    For example: user -> food -> user
    """
    # Check if the edge types exist
    if (node_type1, forward_rel, intermediate_node_type) not in graph.edge_types:
        raise ValueError(f"Edge type '{node_type1} -> {forward_rel} -> {intermediate_node_type}' not found in the graph.")
    if (intermediate_node_type, reverse_rel, node_type2) not in graph.edge_types:
        raise ValueError(f"Edge type '{intermediate_node_type} -> {reverse_rel} -> {node_type2}' not found in the graph.")
    
    # Find the bipartite edges between node_type1 -> intermediate_node_type
    edge_index_1 = graph[node_type1, forward_rel, intermediate_node_type].edge_index
    # Find the bipartite edges between intermediate_node_type -> node_type2
    edge_index_2 = graph[intermediate_node_type, reverse_rel, node_type2].edge_index
    
    # Create an adjacency matrix for the intermediate node
    adj_1 = torch.sparse_coo_tensor(edge_index_1, torch.ones(edge_index_1.size(1)), 
                                    (graph[node_type1].x.size(0), graph[intermediate_node_type].x.size(0)))
    adj_2 = torch.sparse_coo_tensor(edge_index_2, torch.ones(edge_index_2.size(1)),
                                    (graph[intermediate_node_type].x.size(0), graph[node_type2].x.size(0)))

    # Multiply adjacency matrices to find metapath edges
    metapath_adj = adj_1.matmul(adj_2)
    
    # Convert metapath adjacency matrix back to edge index format
    metapath_edge_index = metapath_adj.coalesce().indices()

    # Add metapath to graph
    graph[node_type1, f'metapath_{intermediate_node_type}', node_type2].edge_index = metapath_edge_index

def join_graphs_with_metapaths(graph_active, graph_recovered, mappings_active, mappings_recovered):
    """
    Join the active and recovered graphs, create reverse edges, and create metapaths to connect users from both groups
    if they share the same food or habit.
    """
    joined_graph = HeteroData()

    # 1. Combine node IDs and features for each node type
    for node_type in graph_active.node_types:
        active_node_id = graph_active[node_type]['node_id']
        recovered_node_id = graph_recovered[node_type]['node_id']
        
        # Convert node IDs to list if needed
        if isinstance(active_node_id, torch.Tensor):
            active_node_id = active_node_id.tolist()
        if isinstance(recovered_node_id, torch.Tensor):
            recovered_node_id = recovered_node_id.tolist()

        # Filter out invalid or non-numeric node IDs
        active_node_id = [node_id for node_id in active_node_id if is_numeric(node_id)]
        recovered_node_id = [node_id for node_id in recovered_node_id if is_numeric(node_id)]
        
        # Convert node IDs to integers if valid
        active_node_id = [int(node_id) for node_id in active_node_id]
        recovered_node_id = [int(node_id) for node_id in recovered_node_id]

        # Add an offset to recovered_node_id to avoid overlaps
        offset = max(active_node_id) + 1 if active_node_id else 0
        recovered_node_id = [node_id + offset for node_id in recovered_node_id]

        # Combine the node IDs
        joined_graph[node_type]['node_id'] = torch.tensor(active_node_id + recovered_node_id)
        
        # If the node type has 'x' features, concatenate them
        if 'x' in graph_active[node_type]:
            active_x = graph_active[node_type].x
            recovered_x = graph_recovered[node_type].x
            
            # Ensure the feature dimensions match, padding if necessary
            if active_x.size(1) != recovered_x.size(1):
                max_dim = max(active_x.size(1), recovered_x.size(1))
                active_x = pad_features(active_x, max_dim)
                recovered_x = pad_features(recovered_x, max_dim)
            
            # Concatenate the features
            joined_graph[node_type].x = torch.cat([active_x, recovered_x])

    # 2. Combine edge indices for each edge type
    for edge_type in graph_active.edge_types:
        src, rel, dst = edge_type
        
        # Get the number of nodes in the source and destination node types (for offsetting)
        offset = {node_type: len(graph_active[node_type]['node_id']) for node_type in graph_active.node_types}
        
        # Extract edge indices
        edge_index_active = graph_active[edge_type].edge_index
        edge_index_recovered = graph_recovered[edge_type].edge_index
        
        # Adjust the edge indices for the recovered graph by adding offsets
        edge_index_recovered[0] += offset[src]  # Offset source node indices
        edge_index_recovered[1] += offset[dst]  # Offset destination node indices if needed
        
        # Concatenate the edges
        joined_graph[edge_type].edge_index = torch.cat([edge_index_active, edge_index_recovered], dim=1)

    # 3. Combine mappings (if required)
    joined_mappings = {}
    for node_type in mappings_active.keys():
        max_id = max(mappings_active[node_type].values()) if mappings_active[node_type] else 0  # Get the max ID from the active mappings
        # Shift recovered mappings by the max active ID to avoid ID conflicts
        joined_mappings[node_type] = {**mappings_active[node_type],
                                      **{k: v + max_id + 1 for k, v in mappings_recovered[node_type].items()}}

    # 4. Create reverse edges for metapath creation
    create_reverse_edges(joined_graph, 'user', 'eats', 'food')
    create_reverse_edges(joined_graph, 'user', 'has', 'habit')

    # 5. Create metapaths
    # Metapath 1: user -> food -> user (shared food)
    create_metapaths(joined_graph, 'user', 'food', 'user', 'eats', 'eats_by')

    # Metapath 2: user -> habit -> user (shared habit)
    create_metapaths(joined_graph, 'user', 'habit', 'user', 'has', 'has_by')

    return joined_graph, joined_mappings

joined_graph, joined_mappings = join_graphs_with_metapaths(graph_active, graph_recovered, mappings_active, mappings_recovered)


Created reverse edge 'food -> eats_by -> user'
Created reverse edge 'habit -> has_by -> user'


In [18]:
print(joined_mappings)

{'user': {'21118': 0, '21165': 1, '21178': 2, '21224': 3, '21252': 4, '21270': 5, '21305': 6, '21306': 7, '21329': 8, '21331': 9, '21345': 10, '21388': 11, '21404': 12, '21500': 13, '21566': 14, '21790': 15, '21860': 16, '21886': 17, '21910': 18, '21978': 19, '22090': 20, '22107': 21, '22194': 22, '22202': 23, '22215': 24, '22219': 25, '22296': 26, '22346': 27, '22387': 28, '22474': 29, '22475': 30, '22567': 31, '22629': 32, '22690': 33, '22698': 34, '22781': 35, '22801': 36, '22813': 37, '22819': 38, '22872': 39, '22897': 40, '22913': 41, '22929': 42, '22943': 43, '22961': 44, '23006': 45, '23011': 46, '23024': 47, '23172': 48, '23276': 49, '23293': 50, '23326': 51, '23343': 52, '23375': 53, '23413': 54, '23427': 55, '23521': 56, '23748': 57, '23816': 58, '23908': 59, '23910': 60, '23998': 61, '24005': 62, '24126': 63, '24243': 64, '24248': 65, '24326': 66, '24515': 67, '24583': 68, '24586': 69, '24604': 70, '24647': 71, '24668': 72, '24702': 73, '24757': 74, '24811': 75, '24812': 76,

In [19]:
print(joined_graph.metadata)

<bound method HeteroData.metadata of HeteroData(
  user={
    node_id=[2819],
    x=[2819, 4022],
  },
  food={
    node_id=[5767],
    x=[5767, 768],
  },
  ingredient={
    node_id=[3998],
    x=[3998, 768],
  },
  category={
    node_id=[325],
    x=[325, 768],
  },
  habit={
    node_id=[105],
    x=[105, 768],
  },
  medicine={
    node_id=[0],
    x=[905, 768],
  },
  (food, contains, ingredient)={ edge_index=[2, 20917] },
  (food, belongs_to, category)={ edge_index=[2, 20917] },
  (user, eats, food)={ edge_index=[2, 75289] },
  (user, has, habit)={ edge_index=[2, 29548] },
  (user, takes, medicine)={ edge_index=[2, 15666] },
  (food, eats_by, user)={ edge_index=[2, 75289] },
  (habit, has_by, user)={ edge_index=[2, 29548] },
  (user, metapath_food, user)={ edge_index=[2, 4608442] },
  (user, metapath_habit, user)={ edge_index=[2, 5821999] }
)>


In [17]:
print(joined_graph)

HeteroData(
  user={
    node_id=[2819],
    x=[2819, 4022],
  },
  food={
    node_id=[5767],
    x=[5767, 768],
  },
  ingredient={
    node_id=[3998],
    x=[3998, 768],
  },
  category={
    node_id=[325],
    x=[325, 768],
  },
  habit={
    node_id=[105],
    x=[105, 768],
  },
  medicine={
    node_id=[0],
    x=[905, 768],
  },
  (food, contains, ingredient)={ edge_index=[2, 20917] },
  (food, belongs_to, category)={ edge_index=[2, 20917] },
  (user, eats, food)={ edge_index=[2, 75289] },
  (user, has, habit)={ edge_index=[2, 29548] },
  (user, takes, medicine)={ edge_index=[2, 15666] },
  (food, eats_by, user)={ edge_index=[2, 75289] },
  (habit, has_by, user)={ edge_index=[2, 29548] },
  (user, metapath_food, user)={ edge_index=[2, 4608442] },
  (user, metapath_habit, user)={ edge_index=[2, 5821999] }
)


In [31]:
print(graph_active)

HeteroData(
  user={
    node_id=[2413],
    x=[2413, 4022],
  },
  food={
    node_id=[3810],
    x=[3810, 768],
  },
  ingredient={
    node_id=[2357],
    x=[2357, 768],
  },
  category={
    node_id=[164],
    x=[164, 768],
  },
  habit={
    node_id=[54],
    x=[54, 768],
  },
  medicine={
    node_id=[668],
    x=[667, 768],
  },
  (food, contains, ingredient)={ edge_index=[2, 14258] },
  (food, belongs_to, category)={ edge_index=[2, 14258] },
  (user, eats, food)={ edge_index=[2, 63955] },
  (user, has, habit)={ edge_index=[2, 25155] },
  (user, takes, medicine)={ edge_index=[2, 14836] }
)


In [32]:
print(graph_recovered)

HeteroData(
  user={
    node_id=[406],
    x=[406, 875],
  },
  food={
    node_id=[1957],
    x=[1957, 768],
  },
  ingredient={
    node_id=[1641],
    x=[1641, 768],
  },
  category={
    node_id=[161],
    x=[161, 768],
  },
  habit={
    node_id=[51],
    x=[51, 768],
  },
  medicine={
    node_id=[239],
    x=[238, 768],
  },
  (food, contains, ingredient)={ edge_index=[2, 6659] },
  (food, belongs_to, category)={ edge_index=[2, 6659] },
  (user, eats, food)={ edge_index=[2, 11334] },
  (user, has, habit)={ edge_index=[2, 4393] },
  (user, takes, medicine)={ edge_index=[2, 830] }
)


In [20]:
torch.save(joined_graph, '../processed_data/joined_graph.pt')

#### **Note:** 

For step 1: 

* Right now, the metapath user-food-user is defined if there's at least one food in common between 2 users, and the metapath user-habit-user is defined if there's at least one habit in common between 2 users. 

* In reality, what we should do is to increase the threshold so that we don't end up with too many metapaths. 

* From previous work in the project (see `main.py`): 

    ```
    # metapath thresholds
    UFU_edge_list = generate_neighbors(graph, ('user', 'eats', 'food'), ('food', 'eaten', 'user'), shared_threshold=5)
    UHU_edge_list = generate_neighbors(graph, ('user', 'has', 'habit'), ('habit', 'from', 'user'), shared_threshold=8)
    ```

This logic will be incorporated in step 2 below

### **Step 2: Adding nutrition tags to the subgraphs and the joined graph**

Note: here, we're implemented the logic for the metapaths as follows: 

    ```
    # metapath thresholds
    UFU_edge_list = generate_neighbors(graph, ('user', 'eats', 'food'), ('food', 'eaten', 'user'), shared_threshold=5)
    UHU_edge_list = generate_neighbors(graph, ('user', 'has', 'habit'), ('habit', 'from', 'user'), shared_threshold=8)
    ```

In [25]:
def load_graphs_and_mappings(base_path):
    graph_active = torch.load(f'{base_path}/active_users_graph.pt')
    graph_recovered = torch.load(f'{base_path}/recovered_users_graph.pt')
    
    with open(f'{base_path}/active_users_mappings.pkl', 'rb') as f:
        mappings_active = pickle.load(f)
    with open(f'{base_path}/recovered_users_mappings.pkl', 'rb') as f:
        mappings_recovered = pickle.load(f)
    
    return graph_active, graph_recovered, mappings_active, mappings_recovered

base_path = '../processed_data'
graph_active, graph_recovered, mappings_active, mappings_recovered = load_graphs_and_mappings(base_path)

df_user_tags = pd.read_csv('../processed_data/user_tagging.csv', index_col='SEQN')

# Step 2: Add health tags to graph_active and graph_recovered
def add_health_tags_to_graph(graph, user_mapping, df_user_tags):
    tag_columns = [col for col in df_user_tags.columns if col.startswith('user_')]
    
    # Initialize a new feature matrix for users
    num_users = graph['user'].num_nodes
    num_tags = len(tag_columns)
    new_user_features = torch.zeros((num_users, graph['user'].x.shape[1] + num_tags))
    
    # Copy existing features
    new_user_features[:, :graph['user'].x.shape[1]] = graph['user'].x
    
    # Add health tags
    for seqn, node_id in user_mapping.items():
        if seqn in df_user_tags.index:
            tags = df_user_tags.loc[seqn, tag_columns].values
            new_user_features[node_id, graph['user'].x.shape[1]:] = torch.tensor(tags, dtype=torch.float)
    
    # Update the graph with new user features
    graph['user'].x = new_user_features
    
    return graph

graph_active_with_tags = add_health_tags_to_graph(graph_active, mappings_active['user'], df_user_tags)
graph_recovered_with_tags = add_health_tags_to_graph(graph_recovered, mappings_recovered['user'], df_user_tags)


def join_graphs_with_metapaths(graph_active, graph_recovered, mappings_active, mappings_recovered):
    joined_graph = HeteroData()

    # 1. Combine node IDs and features for each node type
    for node_type in graph_active.node_types:
        active_node_id = graph_active[node_type]['node_id']
        recovered_node_id = graph_recovered[node_type]['node_id']
        
        # Convert node IDs to list if needed
        if isinstance(active_node_id, torch.Tensor):
            active_node_id = active_node_id.tolist()
        if isinstance(recovered_node_id, torch.Tensor):
            recovered_node_id = recovered_node_id.tolist()
        
        # Filter out NaN values and preserve string IDs
        def process_node_ids(node_ids):
            processed = []
            for node_id in node_ids:
                if isinstance(node_id, float) and math.isnan(node_id):
                    continue
                if isinstance(node_id, (int, float)):
                    processed.append(int(node_id))
                else:
                    processed.append(str(node_id))
            return processed

        active_node_id = process_node_ids(active_node_id)
        recovered_node_id = process_node_ids(recovered_node_id)

        # Add an offset to recovered_node_id to avoid overlaps (only for numeric IDs)
        if all(isinstance(node_id, int) for node_id in active_node_id + recovered_node_id):
            offset = max(active_node_id) + 1 if active_node_id else 0
            recovered_node_id = [node_id + offset for node_id in recovered_node_id]

        # Combine the node IDs
        joined_graph[node_type]['node_id'] = active_node_id + recovered_node_id
        
        # If the node type has 'x' features, concatenate them
        if 'x' in graph_active[node_type]:
            active_x = graph_active[node_type].x
            recovered_x = graph_recovered[node_type].x
            
            # Ensure the feature dimensions match
            if active_x.size(1) != recovered_x.size(1):
                max_dim = max(active_x.size(1), recovered_x.size(1))
                active_x = torch.nn.functional.pad(active_x, (0, max_dim - active_x.size(1)))
                recovered_x = torch.nn.functional.pad(recovered_x, (0, max_dim - recovered_x.size(1)))
            
            # Concatenate the features
            joined_graph[node_type].x = torch.cat([active_x, recovered_x])

    # 2. Combine edge indices for each edge type
    for edge_type in graph_active.edge_types:
        src, rel, dst = edge_type
        
        # Get the number of nodes in the source and destination node types (for offsetting)
        offset = {node_type: len(graph_active[node_type]['node_id']) for node_type in graph_active.node_types}
        
        # Extract edge indices
        edge_index_active = graph_active[edge_type].edge_index
        edge_index_recovered = graph_recovered[edge_type].edge_index
        
        # Adjust the edge indices for the recovered graph by adding offsets
        edge_index_recovered = edge_index_recovered.clone()  # Create a copy to avoid modifying the original
        edge_index_recovered[0] += offset[src]  # Offset source node indices
        edge_index_recovered[1] += offset[dst]  # Offset destination node indices if needed
        
        # Concatenate the edges
        joined_graph[edge_type].edge_index = torch.cat([edge_index_active, edge_index_recovered], dim=1)

    # 3. Create metapaths
    def create_metapaths(graph, node_type1, intermediate_node_type, node_type2, forward_rel, reverse_rel, threshold):
        edge_index_1 = graph[node_type1, forward_rel, intermediate_node_type].edge_index
        edge_index_2 = graph[intermediate_node_type, reverse_rel, node_type2].edge_index
        
        # Convert to dense adjacency matrices
        n_users = graph[node_type1].num_nodes
        n_intermediate = graph[intermediate_node_type].num_nodes

        adj_1 = torch.zeros((n_users, n_intermediate), dtype=torch.float)
        adj_1[edge_index_1[0], edge_index_1[1]] = 1

        adj_2 = torch.zeros((n_intermediate, n_users), dtype=torch.float)
        adj_2[edge_index_2[0], edge_index_2[1]] = 1

        # Compute metapath adjacency
        metapath_adj = torch.matmul(adj_1, adj_2)
        
        # Apply threshold condition
        metapath_adj_thresholded = (metapath_adj >= threshold)
        
        # Convert back to edge index format
        metapath_edge_index = metapath_adj_thresholded.nonzero().t()

        graph[node_type1, f'metapath_{intermediate_node_type}', node_type2].edge_index = metapath_edge_index

    # Create reverse edges for metapath creation
    joined_graph['food', 'eats_by', 'user'].edge_index = joined_graph['user', 'eats', 'food'].edge_index[[1, 0]]
    joined_graph['habit', 'has_by', 'user'].edge_index = joined_graph['user', 'has', 'habit'].edge_index[[1, 0]]

    create_metapaths(joined_graph, 'user', 'food', 'user', 'eats', 'eats_by', threshold=5)
    create_metapaths(joined_graph, 'user', 'habit', 'user', 'has', 'has_by', threshold=8)

    return joined_graph

joined_graph_with_tags = join_graphs_with_metapaths(graph_active_with_tags, graph_recovered_with_tags, mappings_active, mappings_recovered)

torch.save(graph_active_with_tags, f'{base_path}/graph_active_with_tags.pt')
torch.save(graph_recovered_with_tags, f'{base_path}/graph_recovered_with_tags.pt')
torch.save(joined_graph_with_tags, f'{base_path}/joined_graph_with_tags.pt')

In [27]:
print(f"Active graph nodes: {graph_active['user'].num_nodes}")
print(f"Recovered graph nodes: {graph_recovered['user'].num_nodes}")

print(f"Active graph user features: {graph_active_with_tags['user'].x.shape}")
print(f"Recovered graph user features: {graph_recovered_with_tags['user'].x.shape}")

print(f"Joined graph nodes: {joined_graph_with_tags['user'].num_nodes}")
print(f"Joined graph user features: {joined_graph_with_tags['user'].x.shape}")

print(f"Metapath food: {joined_graph_with_tags['user', 'metapath_food', 'user'].edge_index.shape}")
print(f"Metapath habit: {joined_graph_with_tags['user', 'metapath_habit', 'user'].edge_index.shape}")

Active graph nodes: 2413
Recovered graph nodes: 406
Active graph user features: torch.Size([2413, 4040])
Recovered graph user features: torch.Size([406, 893])
Joined graph nodes: 2819
Joined graph user features: torch.Size([2819, 4040])
Metapath food: torch.Size([2, 226229])
Metapath habit: torch.Size([2, 198620])
