Combine scheduled assignments and nurse preferences into a labeled edge list

In [9]:
# Build a labeled (nurse, date, shift, ward) edge list for GAT training:
#   label 1 -> the assignment appears in the schedule,
#   label 0 -> the combination appears only in the preference file.
# Writes the result to <DATA_DIR>/edges_for_gat.csv.
from pathlib import Path

import pandas as pd

# Input/output directory, relative to this notebook's working directory.
DATA_DIR = Path('../../data')

# Columns that identify one assignment; both frames must provide all of them.
merge_keys = ['nurse_id', 'date', 'shift', 'ward']


def _require_columns(frame, columns, source_name):
    """Raise KeyError listing any of *columns* that *frame* lacks.

    Gives a readable error before the merge instead of pandas' generic one.
    """
    missing = [col for col in columns if col not in frame.columns]
    if missing:
        raise KeyError(f"{source_name} data is missing merge column(s): {missing}")


# Load datasets
schedule_df = pd.read_csv(DATA_DIR / 'data.csv')
preference_df = pd.read_csv(DATA_DIR / 'datapreference.csv')

# Standardize preference column names so they line up with the schedule.
# DataFrame.rename skips keys that are not present (errors='ignore' is the
# default), so one call handles files with or without the "preferred_" prefix.
preference_df = preference_df.rename(
    columns={'preferred_shift': 'shift', 'preferred_ward': 'ward'}
)

_require_columns(schedule_df, merge_keys, 'schedule')
_require_columns(preference_df, merge_keys, 'preference')

# Label assignments: 1 for assigned (in schedule), 0 for not assigned (only in preferences)
schedule_df['label'] = 1

# Anti-join: keep only preference rows whose key is NOT already in the schedule.
# Deduplicating the schedule keys first keeps the merge from fanning out
# matching preference rows when the schedule repeats an assignment. Those rows
# are discarded anyway, so the result is unchanged.
assigned_keys = schedule_df[merge_keys].drop_duplicates()
pref_unassigned = preference_df.merge(
    assigned_keys,
    on=merge_keys,
    how='left',
    indicator=True,
)
pref_unassigned = pref_unassigned.loc[
    pref_unassigned['_merge'] == 'left_only'
].drop(columns=['_merge'])
pref_unassigned['label'] = 0

# Concatenate into one edge list. Columns present in only one source are
# filled with NaN for rows from the other source.
edge_df = pd.concat([schedule_df, pref_unassigned], ignore_index=True)

# TODO: engineer additional edge features for the GAT here, e.g.
#   edge_df['is_preference'] = (edge_df['label'] == 0).astype(int)

# Check result
print(edge_df.head())
print("Total edges:", len(edge_df))
print(edge_df['label'].value_counts())

edge_df.to_csv(DATA_DIR / 'edges_for_gat.csv', index=False)


  nurse_id        date ward  shift  duration_hours  week       start_time  \
0     N008  2025-06-01    A  Night               8  22.0  0 days 00:00:00   
1     N030  2025-06-01    A  Flex1               4  22.0  0 days 08:00:00   
2     N004  2025-06-01    A  Flex2               4  22.0  0 days 12:00:00   
3     N008  2025-06-01    A  Flex3               4  22.0  0 days 16:00:00   
4     N003  2025-06-01    A  Flex4               4  22.0  0 days 20:00:00   

          end_time  label  
0  0 days 08:00:00      1  
1  0 days 12:00:00      1  
2  0 days 16:00:00      1  
3  0 days 20:00:00      1  
4  1 days 00:00:00      1  
Total edges: 3633
label
1    2800
0     833
Name: count, dtype: int64
