In [None]:
import json
import numpy as np
import random
import sys

## **Hyperparameters**

In [None]:
# OVERTIME_P: willingness of nurses to work overtime (acceptance probability).
# ABSENCE_P:  probability that an absence occurs on a simulated day.
OVERTIME_P, ABSENCE_P = 1.0, 0.5

In [4]:
# Shift codes that do not count as a worked assignment.
OFF_SHIFTS = [
    "R",  # resting
    "A",  # absent
    "S",  # standby
]

## **Nurse Class**

In [5]:
class Nurse():
    """One nurse's assignments and contract bookkeeping for a single schedule.

    Args:
        id: nurse identifier string, e.g. "Nurse_1".
        index: position of the nurse in the JSON lists (its column in the
            state matrix).
        assignments: list of per-day shift codes for this nurse.
        contract_id: contract identifier, e.g. "Contract_1".
        contract_details: contract dict; must contain
            "MaximumNumberOfAssignments".
        shiftoffrequests: this nurse's shift-off requests.
        overtime_willingness: value returned by `sample_overtime`. Defaults to
            1 (always accept overtime), the previously hard-coded behaviour.
        off_shifts: shift codes that do not count as worked assignments.
            Defaults to the module-level OFF_SHIFTS.
    """

    def __init__(self, id, index, assignments, contract_id, contract_details,
                 shiftoffrequests, overtime_willingness=1, off_shifts=None):
        self.id = id        # ex: Nurse_1
        self.index = index  # ex: 0 (associated with place in JSON lists)
        self.assignments = assignments  # list of shift codes, one per day

        self.overtime_shifts = 0  # nobody has overtime yet
        self.overtime_willingness = overtime_willingness

        # Contract details
        self.contract_id = contract_id
        self.contract_details = contract_details
        self.max_assignments = contract_details["MaximumNumberOfAssignments"]

        # `assignments_list` is the SAME list object as `assignments` (kept for
        # existing callers); mutating one mutates the other.
        self.assignments_list = assignments
        if off_shifts is None:
            off_shifts = OFF_SHIFTS
        self.num_assignments = sum(1 for shift in assignments if shift not in off_shifts)
        # May be negative if the schedule already exceeds the contract maximum.
        self.remaining_available_days = self.max_assignments - self.num_assignments

        # Shift-off requests
        self.shiftoffrequests = shiftoffrequests

    def sample_overtime(self):
        """Return this nurse's overtime willingness.

        TODO: sample from a distribution instead of returning a constant.
        """
        return self.overtime_willingness

## **Scheduling Simulation Environment**

In [6]:
class SchedulingEnv():
    def __init__(self, data_path, solution_path, OVERTIME_P=OVERTIME_P, ABSENCE_P=ABSENCE_P):
        """
        Load a problem instance and its baseline solution, then build the
        nurse objects, per-day status sets and initial state matrix.

        Args:
            data_path: path to the instance data JSON.
            solution_path: path to the baseline solution JSON.
            OVERTIME_P: probability that a nurse accepts overtime (stored on self).
            ABSENCE_P: probability of an absence on a simulated day (stored on self).

        Properties of SchedulingEnv:
            self.data: data JSON
            self.solution: solution JSON
            self.NUM_DAYS: number of days in instance (int)
            self.NUM_NURSES: number of nurses in instance (int)
            self.OFF_SHIFTS: shift codes that do not count as worked assignments
            self.shiftoffrequests: {nurse id (str): list of {day: shift} dicts}
            self.ALL_NURSEIDS_SET: set of all nurse ids - set(str)
            self.nurses_dict: {nurse id (str): nurse object (Nurse class)}
            self.nurse_status_dict: {status (str): list, indexed by day, of sets of nurse ids}
            self.init_state: NUM_DAYS x NUM_NURSES int matrix of status codes
            self.state: NUM_DAYS x NUM_NURSES matrix (set by subclasses such as RlSchedEnv)
        """
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(data_path, encoding="utf-8") as f:
            self.data = json.load(f)
        with open(solution_path, encoding="utf-8") as f:
            self.solution = json.load(f)

        self.NUM_DAYS = len(self.solution["Solution"][0]["Assignments"])
        self.NUM_NURSES = len(self.solution["Solution"])
        self.OVERTIME_P = OVERTIME_P
        self.ABSENCE_P = ABSENCE_P
        # Read by init_nurse_status_dict; previously never set (AttributeError).
        self.OFF_SHIFTS = OFF_SHIFTS

        self.reset_env()

    def reset_env(self):
        """Rebuild all derived structures from the loaded JSON."""
        self.init_shiftoffrequests()
        self.init_nurse_dict()
        self.init_nurse_status_dict()
        self.init_init_state()

    def init_shiftoffrequests(self):
        """Group shift-off requests by nurse id: {id: [{day: shift}, ...]}."""
        self.shiftoffrequests = dict()
        for request in self.data["Shiftoffrequests"]:
            self.shiftoffrequests.setdefault(request["id"], []).append(
                {request["day"]: request["shift"]}
            )

    def parse_contract(self, contract_id):
        """Return the contract dict for "Contract_1" or "Contract_2".

        Raises:
            ValueError: for any other contract id.
        """
        contract_index = {"Contract_1": 0, "Contract_2": 1}.get(contract_id)
        if contract_index is None:
            raise ValueError(f"Invalid contract id: {contract_id!r}")
        return self.data["Contracts"][contract_index]

    def init_nurse_dict(self):
        """Create a Nurse object per solution entry; fill nurses_dict and ALL_NURSEIDS_SET."""
        self.ALL_NURSEIDS_SET = set()
        self.nurses_dict = dict()

        for index, nurse_sol in enumerate(self.solution["Solution"]):
            nurse_id = nurse_sol["id"]
            assignments = [day["shift"] for day in nurse_sol["Assignments"]]
            # NOTE(review): assumes data["Nurses"] lists nurses in the same
            # order as solution["Solution"].
            contract_id = self.data["Nurses"][index]["contract_id"]
            contract_details = self.parse_contract(contract_id)
            nurse_requests = self.shiftoffrequests.get(nurse_id, [])

            self.nurses_dict[nurse_id] = Nurse(
                nurse_id, index, assignments, contract_id, contract_details, nurse_requests
            )
            self.ALL_NURSEIDS_SET.add(nurse_id)

    def init_nurse_status_dict(self):
        """Build {status: [set of nurse ids on day 0, day 1, ...]}.

        Statuses, in insertion order:
            "working"  - shift not in OFF_SHIFTS
            "absent"   - shift "A"
            "standby"  - shift "S"
            "resting"  - shift "R" with remaining_available_days > 0
            "overtime" - shift "R" with remaining_available_days == 0
        """
        off_shifts = self.OFF_SHIFTS
        classifiers = {
            "working": lambda shift, nurse: shift not in off_shifts,
            "absent": lambda shift, nurse: shift == "A",
            "standby": lambda shift, nurse: shift == "S",
            "resting": lambda shift, nurse: shift == "R" and nurse.remaining_available_days > 0,
            "overtime": lambda shift, nurse: shift == "R" and nurse.remaining_available_days == 0,
        }
        self.nurse_status_dict = {
            status: [
                {nurse_id for nurse_id, nurse in self.nurses_dict.items()
                 if matches(nurse.assignments[day], nurse)}
                for day in range(self.NUM_DAYS)
            ]
            for status, matches in classifiers.items()
        }

    def init_init_state(self):
        """Encode nurse_status_dict as a NUM_DAYS x NUM_NURSES int matrix.

        Codes: 0 absent, 1 working, 2 standby, 3 resting, 4 overtime. A cell
        that matches no status keeps the default 0. Statuses are applied in
        nurse_status_dict order, so a later status overwrites an earlier one.
        """
        status_codes = {"absent": 0, "working": 1, "standby": 2, "resting": 3, "overtime": 4}
        self.init_state = np.zeros((self.NUM_DAYS, self.NUM_NURSES), dtype=int)
        for status, nurses_per_day in self.nurse_status_dict.items():
            code = status_codes.get(status)
            if code is None:
                continue  # unknown status: leave the cells untouched
            for day, nurse_ids in enumerate(nurses_per_day):
                for nurse_id in nurse_ids:
                    self.init_state[day, self.nurses_dict[nurse_id].index] = code

In [1]:
# Instance data and its baseline solution, shared by the environments below.
data_path, solution_path = 'data/W1-01.json', 'solutions/sol-W1-01.json'

## **Reinforcement Learning Model**

In [None]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env

In [None]:
class RlSchedEnv(SchedulingEnv, gym.Env):
    """Gymnasium environment for learning how to respond to nurse absences.

    One step corresponds to one day of the schedule (``self.timestep``). At the
    start of each step a working nurse may be marked absent with probability
    ``ABSENCE_P``. The action then decides how that absence is handled:

        0: swap      -- bring in a standby nurse, otherwise a resting nurse
        1: overtime  -- ask a nurse with no remaining contract assignments,
                        who accepts with probability ``OVERTIME_P``
        2: no change -- leave the day as it is

    The observation is the NUM_DAYS x NUM_NURSES matrix of status codes
    (0: absent, 1: working, 2: standby, 3: resting, 4: overtime). The class
    follows the Gymnasium API: ``reset`` returns ``(obs, info)`` and ``step``
    returns ``(obs, reward, terminated, truncated, info)``.

    All randomness comes from ``self.np_random``, so ``reset(seed=...)`` makes
    an episode reproducible.
    """

    def __init__(self, data_path, solution_path):
        super().__init__(data_path, solution_path)

        # Constants
        self.NUM_ACTIONS = 3   # 0: swap, 1: overtime, 2: no change
        self.NUM_STATUSES = 5  # 0: absent, 1: working, 2: standby, 3: resting, 4: overtime
        # Reward components; a step's reward is the sum of those that apply.
        self.rewards = {
            "understaffed": -5,
            "overtime": -2,
            "change": -1,
            "standby": 5,
            "resting": 3,
        }

        # Action and observation space
        self.action_space = spaces.Discrete(self.NUM_ACTIONS)
        self.observation_space = spaces.Box(low=0, high=self.NUM_STATUSES - 1,
                                            shape=(self.NUM_DAYS, self.NUM_NURSES),
                                            dtype=np.uint32)

        self.reset()

    def step(self, action):
        """Advance one day: sample an absence, apply `action`, score the result.

        Without an absence, swap/overtime cost only the "change" penalty and
        "no change" costs nothing. With an absence, "no change" costs
        "understaffed".

        NOTE(review): the absence is sampled after the agent has chosen
        `action`, so the policy cannot condition on it. Consider sampling the
        next day's absence at the end of step() and in reset().

        Returns:
            (observation, reward, terminated, truncated, info). `reward` is
            this step's reward; the episode total accumulates in self.reward.
        """
        action = int(action)  # SB3 passes numpy integers
        self.handle_absence()
        absence = self.abs_nurse_id is not None

        if action == 0:  # attempt to swap shifts
            self.actions.append("Swap")
            step_reward = self.rewards["change"]
            if absence:
                step_reward += self.swap()
        elif action == 1:  # request overtime
            self.actions.append("Overtime")
            step_reward = self.rewards["change"]
            if absence:
                step_reward += self.overtime()
        elif action == 2:  # no change
            self.actions.append("Unchanged")
            step_reward = self.rewards["understaffed"] if absence else 0
        else:
            raise ValueError(f"Invalid action: {action}")

        self.reward += step_reward

        # Increment time step
        self.timestep += 1
        if self.timestep >= self.NUM_DAYS:
            self.done = True
            print(self.actions)

        return self.state.copy(), step_reward, self.done, False, {}

    def handle_absence(self):
        """
        Simulates a potential absence with probability `ABSENCE_P`.

        If an absence occurs, picks one nurse working on the current day and
        marks them absent. Their shift becomes "A" and their state cell becomes
        0. One contract assignment is freed for them, and the shift they were
        meant to work is remembered in `self.abs_shift`. Sets
        `self.abs_nurse_id` (and `self.abs_shift`) to None when nobody is
        absent.
        """
        day = self.timestep
        self.abs_nurse_id = None
        self.abs_shift = None

        working_today = self.nurse_status_dict["working"][day]
        if not working_today or self.np_random.random() >= self.ABSENCE_P:
            return

        self.abs_nurse_id = self._pick(working_today)
        working_today.remove(self.abs_nurse_id)
        self.nurse_status_dict["absent"][day].add(self.abs_nurse_id)

        nurse = self.nurses_dict[self.abs_nurse_id]
        self.abs_shift = nurse.assignments[day]
        nurse.assignments[day] = "A"  # assignments hold shift codes, not state codes
        nurse.num_assignments -= 1
        nurse.remaining_available_days += 1
        self.state[day, nurse.index] = 0  # 0 = absent

    def swap(self):
        """Cover today's absence with a standby nurse, else a resting nurse.

        Returns:
            The reward component for the outcome ("standby", "resting", or
            "understaffed" when nobody is available).
        """
        day = self.timestep
        if self.nurse_status_dict["standby"][day]:
            self.actions.append("Standby Swap")
            self._cover_absence("standby", day)
            return self.rewards["standby"]
        if self.nurse_status_dict["resting"][day]:
            self.actions.append("Resting Swap")
            self._cover_absence("resting", day)
            return self.rewards["resting"]
        self.actions.append("Failed Swap")
        return self.rewards["understaffed"]

    def overtime(self):
        """Ask a nurse with no remaining assignments to cover today's absence.

        Succeeds when such a nurse exists and accepts (probability OVERTIME_P).

        Returns:
            The "overtime" reward component on success, else "understaffed".
        """
        day = self.timestep
        if self.nurse_status_dict["overtime"][day] and self.np_random.random() < self.OVERTIME_P:
            self.actions.append("Successful Overtime")
            nurse = self._cover_absence("overtime", day)
            nurse.overtime_shifts += 1
            return self.rewards["overtime"]
        self.actions.append("Failed Overtime")
        return self.rewards["understaffed"]

    def _cover_absence(self, status, day):
        """Move one nurse with `status` on `day` into the absent nurse's shift.

        Updates the per-day status sets, the replacement's assignments and
        counters, and the state matrix. Returns the replacement Nurse.
        """
        pool = self.nurse_status_dict[status][day]
        rep_nurse_id = self._pick(pool)
        pool.remove(rep_nurse_id)
        self.nurse_status_dict["working"][day].add(rep_nurse_id)

        rep_nurse = self.nurses_dict[rep_nurse_id]
        rep_nurse.assignments[day] = self.abs_shift
        rep_nurse.num_assignments += 1
        rep_nurse.remaining_available_days -= 1
        self.state[day, rep_nurse.index] = 1  # 1 = working
        return rep_nurse

    def _pick(self, nurse_ids):
        """Return one id from `nurse_ids`, chosen uniformly with self.np_random."""
        # Sort first so the choice depends only on the RNG, not on set order.
        ordered = sorted(nurse_ids)
        return ordered[int(self.np_random.integers(len(ordered)))]

    def reset(self, *, seed=None, options=None):
        """Restore the baseline schedule and start a new episode at day 0.

        Args:
            seed: optional seed for self.np_random (Gymnasium convention).
            options: unused; accepted for Gymnasium API compatibility.

        Returns:
            (observation, info)
        """
        super().reset(seed=seed)  # resolves to gym.Env.reset: seeds np_random
        self.reset_env()
        # Cast so observations match observation_space's dtype (uint32).
        self.state = self.init_state.astype(self.observation_space.dtype)
        self.done = False
        self.timestep = 0
        self.actions = []
        self.abs_nurse_id = None
        self.abs_shift = None
        self.reward = 0  # cumulative episode reward
        return self.state.copy(), {}

In [None]:
env = RlSchedEnv(data_path, solution_path)
check_env(env)  # validate the environment against the SB3/Gymnasium API first

# total_timesteps=NUM_DAYS is a single episode: effectively a smoke run.
RL_model = A2C("MlpPolicy", env)
RL_model.learn(total_timesteps=env.NUM_DAYS)

## **Greedy Heuristic Model**

In [None]:
class GreedySchedEnv(SchedulingEnv):
    """Greedy-heuristic scheduling model (stub).

    For now it only loads the instance through SchedulingEnv.__init__.
    """

    def __init__(self, data_path, solution_path):
        # TODO: implement the greedy rescheduling heuristic.
        super(GreedySchedEnv, self).__init__(data_path=data_path, solution_path=solution_path)

____________________

In [None]:

# Assuming we are dealing with shifts for a given time period.
# NOTE(review): `shifts`, `nurses`, `shift_off_requests`, `contracts`,
# `solutions` and `nurse_availability` are not defined earlier in this
# notebook. They must exist before this cell runs (nurse_availability
# presumably maps nurse ID -> list of shift IDs; confirm).

# Loop-invariant lookups, computed once instead of per (shift, nurse) pair.
# Keep the FIRST contract listed for each nurse, as a linear search would.
contract_by_nurse = {}
for contract in contracts:
    contract_by_nurse.setdefault(contract["NurseID"], contract)

# Total number of shifts each nurse is already scheduled for in the period.
scheduled_shift_count = {}
for s in solutions:
    scheduled_shift_count[s["NurseID"]] = scheduled_shift_count.get(s["NurseID"], 0) + 1

# Shift IDs each nurse has asked to have off.
requested_off = {
    nurse["ID"]: {s["ShiftID"] for s in shift_off_requests.get(nurse["ID"], [])}
    for nurse in nurses
}

for shift in shifts:
    shift_id = shift["ShiftID"]
    for nurse in nurses:
        nurse_id = nurse["ID"]

        # Respect shift-off requests.
        if shift_id in requested_off[nurse_id]:
            continue

        # Contract constraint (max shifts per period): count ALL of the
        # nurse's scheduled shifts, not only assignments to this shift.
        contract = contract_by_nurse[nurse_id]
        if scheduled_shift_count.get(nurse_id, 0) >= contract["MaxShifts"]:
            continue

        nurse_availability[nurse_id].append(shift_id)

# State representation
def get_state(t, overtime_p=None):
    """Build a snapshot of the scheduling problem at time step `t`.

    Args:
        t: time step; compared against each solution entry's "ShiftID".
        overtime_p: mean of the sampled overtime willingness. Defaults to
            OVERTIME_P (the previous code referenced an undefined
            `overtime_willingness_p`).

    Returns:
        dict with keys 'current_schedule', 'absent_status',
        'nurse_availability', 'overtime_willingness', 'staffing_level' and
        'perturbation'.

    NOTE(review): reads the globals `solutions`, `nurses` and
    `nurse_availability`, and stores references to them, not copies.
    """
    mean_willingness = OVERTIME_P if overtime_p is None else overtime_p

    state = {}
    # Current schedule: the solution is the initial assignment.
    state['current_schedule'] = solutions

    # Absentee status: which nurses are absent and for which shifts.
    state['absent_status'] = {nurse["ID"]: nurse["Absences"] for nurse in nurses}

    # Nurses available to take on additional shifts.
    state['nurse_availability'] = nurse_availability

    # Gaussian stand-in for a willingness distribution, clipped to [0, 1] so
    # it is a valid probability.
    state['overtime_willingness'] = {
        nurse["ID"]: min(1.0, max(0.0, random.gauss(mean_willingness, 0.1)))
        for nurse in nurses
    }

    # Staffing level: assigned (non-'vacant') slots for shift t.
    state['staffing_level'] = sum(
        1 for s in solutions if s["ShiftID"] == t and s["NurseID"] != 'vacant'
    )

    # Perturbation: slots for shift t whose nurse differs from the original
    # assignment (assumes entries track "original_nurse_assignment").
    state['perturbation'] = sum(
        1 for s in solutions
        if s["ShiftID"] == t and s["NurseID"] != s["original_nurse_assignment"]
    )

    return state

# Reward function (stand-in for now)
def compute_reward(state, action):
    """Score `state` with a stand-in reward; `action` is currently unused.

    The reward is the negated sum of:
      * state['perturbation'],
      * one point per nurse whose sampled overtime willingness exceeds 0.5,
      * 10 when state['staffing_level'] is below half of len(shifts).

    NOTE(review): relies on the global `shifts`.
    """
    perturbation_penalty = state['perturbation']
    overtime_cost = sum(w > 0.5 for w in state['overtime_willingness'].values())
    understaffing_penalty = 10 if state['staffing_level'] < len(shifts) // 2 else 0
    return -perturbation_penalty - overtime_cost - understaffing_penalty

# Action space
def take_action(state, action):
    """Apply `action` to the global `solutions` and return (state, reward).

    Supported actions:
        {'type': 'swap', 'nurse1': id, 'nurse2': id}
            Swap the two nurses on every shift both are listed as available for.
        {'type': 'overtime', 'nurse': id, 'shift': shift_id}
            With probability equal to the nurse's sampled willingness, fill ONE
            vacant slot of that shift with the nurse. Otherwise reward -10.
        anything else
            No change; reward -1.

    NOTE(review): `state['staffing_level']` and `state['perturbation']` were
    computed before the action, so compute_reward scores the pre-action
    values. Rebuild the state (e.g. via get_state) if that matters.
    """
    if action['type'] == 'swap':
        nurse1 = action['nurse1']
        nurse2 = action['nurse2']

        # Swap their shifts (validity against constraints is not checked yet).
        for shift in shifts:
            if shift["ShiftID"] in state['nurse_availability'][nurse1] and shift["ShiftID"] in state['nurse_availability'][nurse2]:
                for s in solutions:
                    if s["NurseID"] == nurse1 and s["ShiftID"] == shift["ShiftID"]:
                        s["NurseID"] = nurse2
                    elif s["NurseID"] == nurse2 and s["ShiftID"] == shift["ShiftID"]:
                        s["NurseID"] = nurse1
        return state, compute_reward(state, action)

    elif action['type'] == 'overtime':
        overtime_nurse = action['nurse']
        overtime_shift = action['shift']

        # Probabilistic acceptance based on the nurse's willingness.
        willingness = state['overtime_willingness'][overtime_nurse]
        if random.random() < willingness:
            for s in solutions:
                if s["NurseID"] == 'vacant' and s["ShiftID"] == overtime_shift:
                    s["NurseID"] = overtime_nurse
                    # One nurse covers one vacancy; don't fill every vacant
                    # slot of the shift with the same person.
                    break
            return state, compute_reward(state, action)
        else:
            return state, -10  # Penalty for failed overtime assignment

    else:
        # No change, just return the same state
        return state, -1  # Understaffing penalty

# Example: build the state for one time step, then apply a swap action to it.
t = 1
state = get_state(t)

action = dict(type='swap', nurse1='nurse_1', nurse2='nurse_2')
state, reward = take_action(state, action)
print(state, reward)


In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

# Create the environment and wrap it in DummyVecEnv for training.
# NOTE: the `overtime_p` keyword is only accepted by the gym.Env-based
# SchedulingEnv defined in the last cell of this notebook, so run that cell first.
# Reuse the shared paths so every cell reads the same files (this cell
# previously pointed at 'solution/...' instead of 'solutions/...').
base_env = SchedulingEnv(data_path=data_path, solution_path=solution_path, overtime_p=1.0)
# Give the raw env its own name: `lambda: env` would capture the name `env`,
# which this same line rebinds to the DummyVecEnv.
env = DummyVecEnv([lambda: base_env])

# Initialize PPO model
model = PPO("MlpPolicy", env, verbose=1)

# Train the model
model.learn(total_timesteps=10000)

# Save the model after training
model.save("scheduling_ppo_model")


In [None]:
# Reload the saved PPO policy and roll it out for at most 100 steps,
# stopping early once the environment reports the episode as done.
model = PPO.load("scheduling_ppo_model")

obs = env.reset()
remaining_steps = 100
while remaining_steps > 0:
    remaining_steps -= 1
    action, _states = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    print(f"Action: {action}, Reward: {rewards}")
    if done:
        break


In [None]:

import gym
from gym import spaces
import numpy as np
import pandas as pd
import random
import json

class SchedulingEnv(gym.Env):
    """Simplified gym environment over a nurse -> shifts schedule.

    Uses the classic gym API: reset() -> obs, step() -> (obs, reward, done, info).
    Action space: 0 = no action (unchanged), 1 = swap, 2 = overtime.

    NOTE(review): this class reuses the name SchedulingEnv and shadows the
    JSON/Nurse-based SchedulingEnv once this cell runs.
    Assumes solution['schedule'] is a list (one entry per nurse) of lists of
    integer shift indices; TODO confirm against the solution files.
    """

    def __init__(self, data_path, solution_path, overtime_p=1.0):
        super().__init__()

        # Load the data and solution
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        with open(solution_path, 'r', encoding='utf-8') as f:
            self.solution = json.load(f)

        # Hyperparameter
        self.overtime_p = overtime_p  # Probability of a nurse accepting overtime

        # Initialize states
        self.current_schedule = self._initial_schedule()
        self.nurses = self.data['nurses']
        self.shifts = self.data['shifts']
        self.time_steps = len(self.shifts)  # Assuming each shift corresponds to a time step
        self.current_step = 0

        # Action space: 0 = no action (unchanged), 1 = swap, 2 = overtime
        self.action_space = spaces.Discrete(3)

        # Observation space: _get_obs() returns the flattened nurse x shift
        # assignment matrix, so the space must be 1-D with the same length.
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(len(self.nurses) * len(self.shifts),), dtype=np.int32
        )

    def _initial_schedule(self):
        """Return a fresh copy of the solution's schedule (one list per nurse).

        _perform_swap mutates these lists in place, so handing out the
        solution's own lists would make reset() unable to restore them.
        """
        return [list(assignments) for assignments in self.solution['schedule']]

    def reset(self):
        """Restore the initial schedule, rewind to step 0, return the observation."""
        self.current_schedule = self._initial_schedule()
        self.current_step = 0
        return self._get_obs()

    def _get_obs(self):
        """Return the flattened nurse x shift matrix (1 = assigned, 0 = not)."""
        state = np.zeros((len(self.nurses), len(self.shifts)), dtype=np.int32)
        for nurse_idx, assignments in enumerate(self.current_schedule):
            for shift in assignments:
                state[nurse_idx, shift] = 1
        return state.flatten()

    def step(self, action):
        """Apply `action`, advance one step; return (obs, reward, done, info)."""
        reward = 0
        done = False

        # Action Handling
        if action == 0:
            reward = -1  # No change (penalty for not addressing absence)
        elif action == 1:
            reward = self._perform_swap()
        elif action == 2:
            reward = self._offer_overtime()

        # Increment time step
        self.current_step += 1
        if self.current_step >= self.time_steps:
            done = True  # end of schedule window

        return self._get_obs(), reward, done, {}

    def _perform_swap(self):
        """Swap one random shift between two randomly chosen nurses.

        Returns -1 (perturbation cost). The schedule is left unchanged when
        there are fewer than two nurses or a chosen nurse has no shifts.
        """
        if len(self.nurses) < 2:
            return -1
        nurse_1, nurse_2 = random.sample(range(len(self.nurses)), 2)
        shifts_1 = self.current_schedule[nurse_1]
        shifts_2 = self.current_schedule[nurse_2]
        if not shifts_1 or not shifts_2:
            return -1

        shift_1 = random.choice(shifts_1)
        shift_2 = random.choice(shifts_2)

        # Swap the shifts
        shifts_1.remove(shift_1)
        shifts_1.append(shift_2)
        shifts_2.remove(shift_2)
        shifts_2.append(shift_1)

        return -1  # Perturbation penalty for changing the schedule

    def _offer_overtime(self):
        """Offer overtime; it is accepted with probability `overtime_p`.

        No specific nurse is chosen yet. Returns -2 (overtime cost) on
        acceptance, else -3 (understaffing cost).
        """
        overtime_acceptance = np.random.rand() < self.overtime_p

        if overtime_acceptance:
            return -2  # Overtime penalty (cost)
        else:
            return -3  # Understaffing penalty (cost)
