<a href="https://colab.research.google.com/github/SiddhiMarri/RosterScheduling/blob/main/DQN_witharrayloss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Assumptions:

Ratings==positions

Holidays is user input

Controller requirements class is also user input

recency is calculated separately

ignoring training schedule

In [77]:
import numpy as np
import pandas as pd

POSITIONS = {0: 'APP', 1: 'CWL DIR', 2: 'BKN DIR', 3: 'CWL', 4: 'BKH DEPS', 5: 'ADC', 6: 'GND'}
CONTROLLERS = {
    1: 'C', 2: 'D', 3: 'F', 4: 'G', 5: 'J', 6: 'K', 7: 'L', 8: 'N', 9: 'P',
    10: 'R', 11: 'S', 12: 'T', 13: 'V', 14: 'X', 15: 'Y', 16: 'O', 17: 'H',
    18: 'Z', 19: 'B', 20: 'E', 21: 'M', 22: 'Q', 23: 'W'
}

SHIFTS = {
    'MORNING': {'start': '07:30', 'end': '13:30', 'duration': 6},
    'AFTERNOON': {'start': '13:30', 'end': '20:00', 'duration': 6.5},
    'NIGHT': {'start': '20:00', 'end': '07:30', 'duration': 11},
    'NIGHT_OFF': None,
    'CLEAR_OFF': None
}

MAX_RATINGS_PER_CONTROLLER = 7
class ControllerRequirements:
    MEDICAL_FITNESS_REQUIRED = True
    ENGLISH_PROFICIENCY_REQUIRED = True
    recency = [True] * 7

# Duty constraints
DUTY_LIMITS = {
    'CONTINUOUS_DUTY_MAX': 12,
    'WEEKLY_DUTY_MAX': 48,
    'MONTHLY_DUTY_MAX': 190,
    'CONSECUTIVE_DAYS_MAX': 6,
    'MANDATORY_BREAK_AFTER_6_DAYS': 48,
    'BREAK_AFTER_NIGHT_DUTY': 48,
    'MIN_BREAK_BETWEEN_SHIFTS': 12,
}

class StaffingRules:
    POSITIONS_TO_PEOPLE_RATIO = {
        2: 3
    }
    BREAK_POLICIES = {
        'STANDARD': {'duty': 2, 'break': 0.5},  # 2 hours duty, 0.5 hours break
        'CURRENT': {'duty': 1.5, 'break': 0.75}  # 1.5 hours duty, 0.75 hours break
    }

In [78]:
controller_requirements = {
    1: {'medical_valid': True, 'english_valid': False, 'recency': [True, True, False, True, False, True, True]},
    2: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, True, False]},
    3: {'medical_valid': False, 'english_valid': True, 'recency': [True, True, True, False, True, True, True]},
    4: {'medical_valid': True, 'english_valid': True, 'recency': [True, False, True, True, False, True, True]},
    5: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, False, True]},
    6: {'medical_valid': True, 'english_valid': False, 'recency': [True, True, False, True, True, True, True]},
    7: {'medical_valid': False, 'english_valid': True, 'recency': [True, True, True, True, False, True, False]},
    8: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, True, True]},
    9: {'medical_valid': True, 'english_valid': True, 'recency': [True, True, False, True, True, False, True]},
    10: {'medical_valid': True, 'english_valid': False, 'recency': [True, False, True, True, True, True, True]},
    11: {'medical_valid': True, 'english_valid': True, 'recency': [True, True, True, False, True, True, False]},
    12: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, True, True]},
    13: {'medical_valid': False, 'english_valid': True, 'recency': [True, True, True, True, False, True, True]},
    14: {'medical_valid': True, 'english_valid': True, 'recency': [True, False, True, True, True, False, True]},
    15: {'medical_valid': True, 'english_valid': False, 'recency': [True, True, False, True, True, True, True]},
    16: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, True, False]},
    17: {'medical_valid': False, 'english_valid': True, 'recency': [True, True, True, False, True, True, True]},
    18: {'medical_valid': True, 'english_valid': True, 'recency': [True, True, True, True, False, True, True]},
    19: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, False, True]},
    20: {'medical_valid': True, 'english_valid': False, 'recency': [True, True, False, True, True, True, True]},
    21: {'medical_valid': False, 'english_valid': True, 'recency': [True, False, True, True, True, True, False]},
    22: {'medical_valid': True, 'english_valid': True, 'recency': [True, True, True, True, True, False, True]},
    23: {'medical_valid': True, 'english_valid': True, 'recency': [True, False, True, True, True, False, True]}
}

In [79]:
import random

POSITIONS = {
    0: 'APP', 1: 'CWL DIR', 2: 'BKN DIR', 3: 'CWL',
    4: 'BKH DEPS', 5: 'ADC', 6: 'GND'
}
CONTROLLERS = list(range(1, 24))
SLOTS_PER_DAY = 16
DAYS = 7
NUM_POSITIONS = len(POSITIONS)
EMPTY_SLOT_PROBABILITY = 0.2
roster = []

for day in range(DAYS):
    day_slots = []
    for position in range(NUM_POSITIONS):
        slots = []
        for _ in range(SLOTS_PER_DAY):
            if random.random() < EMPTY_SLOT_PROBABILITY:
                slots.append(None)
            else:
                slots.append(random.choice(CONTROLLERS))
        day_slots.append(slots)
    roster.append(day_slots)

# Print the roster
for day in range(DAYS):
    print(f"Day {day + 1}:")
    for position, slots in enumerate(roster[day]):
        slot_str = ', '.join(str(c) if c is not None else 'Empty' for c in slots)
        print(f"  {POSITIONS[position]}: {slot_str}")
    print()


Day 1:
  APP: 14, Empty, 14, Empty, 6, 6, 19, 13, 5, Empty, 21, 1, Empty, 5, Empty, 14
  CWL DIR: 22, 17, 15, Empty, 20, 23, 17, 16, 8, 13, 1, 17, 20, 16, 21, 23
  BKN DIR: 20, 11, 7, 12, Empty, 18, 16, 12, 15, 12, 1, 9, Empty, 3, 21, 6
  CWL: Empty, 7, 12, 8, 4, 21, 20, Empty, 9, 8, 1, 19, 14, 2, 14, 12
  BKH DEPS: 10, 14, 19, 20, 22, 20, 2, 17, 21, 6, Empty, Empty, Empty, 13, Empty, 21
  ADC: 13, 16, Empty, 5, Empty, 15, Empty, 2, 22, 17, 6, 7, 2, 9, 3, 4
  GND: 17, 4, 8, 18, 2, 1, 20, 3, 13, 1, 16, 12, Empty, 20, 10, 1

Day 2:
  APP: 13, 4, 14, 21, 15, 9, 11, 18, 4, 7, 20, 7, 11, 7, 5, 11
  CWL DIR: Empty, Empty, Empty, 9, 10, 20, 18, 10, Empty, Empty, 22, 16, 5, 16, 11, 23
  BKN DIR: 2, 15, 10, Empty, 18, 3, Empty, 22, 5, 20, 4, Empty, Empty, 1, 13, Empty
  CWL: Empty, 22, Empty, 17, 5, 11, 1, 2, 8, 23, Empty, 15, 7, 1, 23, 19
  BKH DEPS: 17, 15, 7, 2, Empty, 4, Empty, Empty, 2, 11, 10, 14, 6, Empty, 6, 19
  ADC: 5, 11, 7, 13, 2, 15, 1, 5, 15, 21, 13, 23, 5, 14, 18, 22
  GND: 9, 18

In [80]:
def calculate_roster_violations(roster, shifts, duty_limits, staffing_rules, controller_requirements):
    violations = {
        "medical_invalid": 0,
        "english_invalid": 0,
        "recency_violation": 0,
        "continuous_duty_violation": 0,
        "weekly_duty_violation": 0,
        "monthly_duty_violation": 0,
        "mandatory_break_violation": 0,
        "break_after_night_violation": 0,
        "empty_slot_violation": 0,
    }

    controller_duty_hours = {c: 0 for c in controller_requirements.keys()}
    controller_consecutive_days = {c: 0 for c in controller_requirements.keys()}
    last_shift_end = {c: None for c in controller_requirements.keys()}
    active_shifts = {k: v for k, v in shifts.items() if v is not None}

    for day_idx, day_schedule in enumerate(roster):
        daily_duty = {c: 0 for c in controller_requirements.keys()}

        for position_idx, slots in enumerate(day_schedule):
            for slot_idx, controller in enumerate(slots):
                if controller is None:
                    violations["empty_slot_violation"] += 1
                    continue

                # Check controller validity
                if not controller_requirements[controller]['medical_valid']:
                    violations["medical_invalid"] += 1
                if not controller_requirements[controller]['english_valid']:
                    violations["english_invalid"] += 1
                if not controller_requirements[controller]['recency'][position_idx]:
                    violations["recency_violation"] += 1

                # Determine shift and duration
                shift_idx = slot_idx % len(active_shifts)
                shift = list(active_shifts.values())[shift_idx]
                duration = shift['duration']
                daily_duty[controller] += duration
                controller_duty_hours[controller] += duration

                # Check for continuous duty violations
                if daily_duty[controller] > duty_limits['CONTINUOUS_DUTY_MAX']:
                    violations["continuous_duty_violation"] += 1

                # Check for break violations between shifts
                if last_shift_end[controller] is not None:
                    hours_since_last_shift = (day_idx * 24) - last_shift_end[controller]
                    if hours_since_last_shift < duty_limits['MIN_BREAK_BETWEEN_SHIFTS']:
                        violations["mandatory_break_violation"] += 1

                last_shift_end[controller] = (day_idx * 24) + shift['duration']

        # After processing all positions for a day, check for weekly/monthly duty violations
        for controller, hours in daily_duty.items():
            if controller_consecutive_days[controller] >= duty_limits['CONSECUTIVE_DAYS_MAX']:
                violations["mandatory_break_violation"] += 1

            # Reset consecutive days if no hours for this controller
            controller_consecutive_days[controller] = controller_consecutive_days[controller] + 1 if hours > 0 else 0
            if controller_duty_hours[controller] > duty_limits['WEEKLY_DUTY_MAX']:
                violations["weekly_duty_violation"] += 1
            if controller_duty_hours[controller] > duty_limits['MONTHLY_DUTY_MAX']:
                violations["monthly_duty_violation"] += 1

        # Check for break after night duty violations
        for controller in controller_requirements:
            if last_shift_end[controller] and last_shift_end[controller] % 24 == 0:
                hours_since_night = day_idx * 24 - last_shift_end[controller]
                if hours_since_night < duty_limits['BREAK_AFTER_NIGHT_DUTY']:
                    violations["break_after_night_violation"] += 1

    return violations

# Call the function with the updated roster
violations = calculate_roster_violations(
    roster=roster,
    shifts=SHIFTS,
    duty_limits=DUTY_LIMITS,
    staffing_rules=StaffingRules,
    controller_requirements=controller_requirements
)

# Print the violations
for violation_type, count in violations.items():
    print(f"{violation_type}: {count}")


medical_invalid: 137
english_invalid: 148
recency_violation: 121
continuous_duty_violation: 460
weekly_duty_violation: 137
monthly_duty_violation: 26
mandatory_break_violation: 496
break_after_night_violation: 0
empty_slot_violation: 151


In [81]:
import numpy as np
NUM_POSITIONS = 7
NUM_CONTROLLERS = 23
NUM_DAYS = 7
NUM_SHIFTS_PER_DAY = 16
REWARDS = {
    "medical_invalid": -100,
    "english_invalid": -80,
    "recency_violation": -50,
    "continuous_duty_violation": -200,
    "weekly_duty_violation": -150,
    "monthly_duty_violation": -100,
    "mandatory_break_violation": -250,
    "break_after_night_violation": -120,
    "empty_slot_violation": -50,
    "good_shift": 200,
    "low_violations_day": 50,
    "extra_duty_completed": 100,
}



In [82]:
from collections import deque


In [83]:
class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
    def push(self, transition):
        self.memory.append(transition)
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    def __len__(self):
        return len(self.memory)

In [84]:
# prompt: print the size of 3D roster

print(f"The size of the 3D roster is: {len(roster)} days x {len(roster[0])} positions x {len(roster[0][0])} slots")

The size of the 3D roster is: 7 days x 7 positions x 16 slots


In [87]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import matplotlib.pyplot as plt

# Define the DQN Network
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()

        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)
        self.fc5 = nn.Linear(64, output_dim)

        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        x = torch.relu(self.fc3(x))
        x = self.dropout(x)
        x = torch.relu(self.fc4(x))
        return self.fc5(x)

# Define the EmployeeSchedulingEnv class
class EmployeeSchedulingEnv:
    def __init__(self, initial_roster):
        self.initial_roster = initial_roster
        self.state = self.reset()
        self.n_positions = len(initial_roster[0])
        self.n_slots = len(initial_roster[0][0])
        self.n_days =len(initial_roster)

    def reset(self):
        self.state = np.copy(self.initial_roster)
        return self.get_state_representation()

    def get_state_representation(self):
        controller_map = {i: i + 1 for i in range(23)}  # Controller mapping (0-22 to 1-23)
        numerical_state = []

        # Adjust the iteration based on the correct shape (7, 7, 16)
        for day in range(7):  # Loop over days
            for position in range(7):  # Loop over positions
                for slot in range(16):  # Loop over slots
                    controller = self.state[day, position, slot]  # Access the state correctly
                    numerical_state.append(controller_map.get(controller, -1))  # -1 for unassigned
        print(numerical_state)

        return np.array(numerical_state)

    def check_if_done(self):
        return np.sum(self.state != -1) == self.n_positions * self.n_slots * self.n_days

    def step(self, action):
        position, slot, controller, day = action

        # Check if controller is already assigned too many times
        if np.sum(self.state[:, :, :] == controller + 1) >= MAX_RATINGS_PER_CONTROLLER:
            return self.get_state_representation(), -100, False, self.state

        # Assign controller to the position
        self.state[position, slot, day] = controller + 1

        # Calculate the roster violations (penalties)
        metrics = calculate_roster_violations(
            self.state, shifts=SHIFTS, duty_limits=DUTY_LIMITS, staffing_rules=StaffingRules.POSITIONS_TO_PEOPLE_RATIO,
            controller_requirements=controller_requirements
        )

        # Calculate reward
        reward = sum(REWARDS.get(violation, 0) * count for violation, count in metrics.items())

        # Penalty for consecutive slots
        if slot > 0 and self.state[position, slot - 1, day] == controller + 1:
            penalty = -50
        else:
            penalty = 0

        reward += penalty

        # Check if the schedule is complete (done)
        done = self.check_if_done()

        # If the schedule is done but has unfilled slots, apply penalty
        if done and not np.all(self.state != 0):
            unfilled_penalty = -100
            reward += unfilled_penalty

        return self.get_state_representation(), reward, done, metrics

# Loss calculation function for training
# Loss calculation function for training
def compute_action_index(position, slot, controller, day, n_positions=7, n_slots=16, n_controllers=23, n_days=7):
    # Flatten the (position, slot, controller, day) tuple to a single index
    return position * (n_slots * n_controllers * n_days) + slot * (n_controllers * n_days) + controller * n_days + day

def calculate_loss(metrics,q_values, next_q_values, reward_batch, action_batch):
    # Compute the index from the action_batch
    action_indices = [compute_action_index(*action) for action in action_batch]

    # Reshape the action indices to be compatible with the output dimension of q_values
    action_indices = torch.tensor(action_indices, dtype=torch.long).view(-1, 1)

    # Gather the Q-values for the selected actions
    q_value = q_values.gather(1, action_indices)  # action_indices should be of shape (batch_size, 1)

    # Compute the expected Q-value (target)
    next_q_value = next_q_values.max(1)[0].unsqueeze(1)
    expected_q_value = reward_batch.unsqueeze(1) + gamma * next_q_value

    # Compute the MSE loss between the predicted and expected Q-values
    return nn.MSELoss()(q_value, expected_q_value)



# Define Replay Memory class
class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, transition):
        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

# Training Loop
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
batch_size = 64
target_update_freq = 5
memory_capacity = 100
n_episodes = 100
epsilon_decay = 0.89
learning_rate = 5e-4
env = EmployeeSchedulingEnv(roster)
controllers = 23

input_dim = 7 * 7* 16
output_dim = 7 *7*16
dqn = DQN(input_dim, output_dim)
target_dqn = DQN(input_dim, output_dim)
target_dqn.load_state_dict(dqn.state_dict())
optimizer = optim.Adam(dqn.parameters(), lr=learning_rate)
memory = ReplayMemory(memory_capacity)

rewardsper_epoch = []
losses = []
for episode in range(n_episodes):
    state = env.reset()
    total_reward = 0

    for t in range(16):
        if random.random() < epsilon:
            action = (random.randint(0, 6), t, random.randint(0, 22), random.randint(0, 6))  # Random action
        else:
            with torch.no_grad():
                q_values = dqn(torch.tensor(state.flatten(), dtype=torch.float32))
                action_idx = q_values.argmax().item()

                # Validate the action index calculation
                position = action_idx // (controllers * env.n_days)
                controller = (action_idx // env.n_days) % controllers
                day = action_idx % env.n_days
                slot = t  # or another way to choose the slot, depending on the task

                # Ensure position, slot, controller, and day are within valid ranges
                assert 0 <= position < env.n_positions
                assert 0 <= slot < env.n_slots
                assert 0 <= controller < controllers
                assert 0 <= day < env.n_days

                action = (position, slot, controller, day)

        next_state, reward, done, metrics = env.step(action)

        memory.push((state, action, reward, next_state))
        state = next_state
        total_reward += reward

        if done:
            break

        if len(memory.memory) >= batch_size:
            transitions = memory.sample(batch_size)
            batch = list(zip(*transitions))

            state_batch = torch.tensor(np.array(batch[0]), dtype=torch.float32).view(batch_size, -1)
            action_batch = torch.tensor(np.array(batch[1]), dtype=torch.long)  # action_batch should be a 1D tensor of indices
            reward_batch = torch.tensor(np.array(batch[2]), dtype=torch.float32)
            next_state_batch = torch.tensor(np.array(batch[3]), dtype=torch.float32).view(batch_size, -1)

            # Calculate loss
            loss = calculate_loss(metrics, dqn(state_batch), target_dqn(next_state_batch), reward_batch, action_batch)
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        if epsilon > epsilon_min:
            epsilon *= epsilon_decay

    if episode % target_update_freq == 0:
        target_dqn.load_state_dict(dqn.state_dict())

    print(f"Episode {episode + 1}/{n_episodes}, Total Reward: {total_reward}")
    rewardsper_epoch.append(total_reward)

# Plot the loss curve
plt.plot(losses)
plt.xlabel('Training Steps')
plt.ylabel('Loss')
plt.title('Loss Trend During Training')
plt.show()


[15, -1, 15, -1, 7, 7, 20, 14, 6, -1, 22, 2, -1, 6, -1, 15, 23, 18, 16, -1, 21, -1, 18, 17, 9, 14, 2, 18, 21, 17, 22, -1, 21, 12, 8, 13, -1, 19, 17, 13, 16, 13, 2, 10, -1, 4, 22, 7, -1, 8, 13, 9, 5, 22, 21, -1, 10, 9, 2, 20, 15, 3, 15, 13, 11, 15, 20, 21, 23, 21, 3, 18, 22, 7, -1, -1, -1, 14, -1, 22, 14, 17, -1, 6, -1, 16, -1, 3, 23, 18, 7, 8, 3, 10, 4, 5, 18, 5, 9, 19, 3, 2, 21, 4, 14, 2, 17, 13, -1, 21, 11, 2, 14, 5, 15, 22, 16, 10, 12, 19, 5, 8, 21, 8, 12, 8, 6, 12, -1, -1, -1, 10, 11, 21, 19, 11, -1, -1, 23, 17, 6, 17, 12, -1, 3, 16, 11, -1, 19, 4, -1, 23, 6, 21, 5, -1, -1, 2, 14, -1, -1, 23, -1, 18, 6, 12, 2, 3, 9, -1, -1, 16, 8, 2, -1, 20, 18, 16, 8, 3, -1, 5, -1, -1, 3, 12, 11, 15, 7, -1, 7, 20, 6, 12, 8, 14, 3, 16, 2, 6, 16, 22, 14, -1, 6, 15, 19, 23, 10, 19, 20, 6, 12, -1, 5, 21, 19, 7, 18, 16, 22, 18, -1, 2, 2, -1, -1, 17, 10, 23, 12, -1, 7, 21, 20, 20, 4, 21, 21, -1, 9, 3, -1, 2, 8, 21, 21, 12, 20, 5, 5, -1, 18, 19, 18, 9, -1, 7, -1, 14, 13, 7, -1, 18, 4, 16, 21, 19, 18, 11,

RuntimeError: index 1492 is out of bounds for dimension 1 with size 784

In [None]:
import pandas as pd
pd.set_option('display.max_columns', None)
positions = ['APP', 'CWL_DIR', 'BKN_DIR', 'CWL', 'BKH_DEPS', 'ADC', 'GND']
time_slots = [f"S{i+1}" for i in range(NUM_SLOTS)]
def print_learned_roster(dqn, env, controllers):
    weekly_schedule = []
    state = env.reset()
    learned_roster = np.zeros_like(env.initial_roster, dtype=int)
    for day in range(env.n_days):
        for slot in range(env.n_slots):
            state_tensor = torch.tensor(state.flatten(), dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                q_values = dqn(state_tensor)
            action_idx = q_values.argmax().item()
            position = action_idx // (len(controllers) * env.n_days)
            controller = (action_idx // env.n_days) % len(controllers)
            day_of_week = action_idx % env.n_days
            learned_roster[position, slot, day_of_week] = controller + 1
            next_state, _, _, _ = env.step((position, slot, controller, day_of_week))
            state = next_state

    for day in range(env.n_days):
        daily_schedule = pd.DataFrame(learned_roster[day].T, index=positions, columns=time_slots)
        weekly_schedule.append(daily_schedule)
    for day, schedule in enumerate(weekly_schedule, start=1):
        print(f"Day {day} Schedule:")
        print(schedule)
        print("\n" + "="*50 + "\n")

print_learned_roster(dqn, env, controllers)
