<a href="https://colab.research.google.com/github/SiddhiMarri/RosterScheduling/blob/main/DQN_witharrayloss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Assumptions:

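Ratings correspond one-to-one to positions (ratings == positions).

Holidays are provided as user input.

Controller requirements (the `ControllerRequirements` class) are also user input.

Recency is calculated separately.

The training schedule is ignored.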

In [None]:
import numpy as np
import pandas as pd

# Positions (ratings) and controllers, each keyed by a 0-based integer index
POSITIONS = dict(enumerate(['APP', 'CWL DIR', 'BKN DIR', 'CWL', 'BKH DEPS', 'ADC', 'GND']))
CONTROLLERS = dict(enumerate('CDFGJKLNPRSTVXYOHZBEMQW'))

# Shift definitions; the two "off" entries carry no times.
SHIFTS = dict(
    MORNING=dict(start='07:30', end='13:30', duration=6),
    AFTERNOON=dict(start='13:30', end='20:00', duration=6.5),
    NIGHT=dict(start='20:00', end='07:30', duration=11),
    NIGHT_OFF=None,
    CLEAR_OFF=None,
)
TEAMS = list('ABCDE')
MAX_RATINGS_PER_CONTROLLER = 7
class ControllerRequirements:
    """Class-level default eligibility flags for a controller."""

    # Both certifications are required by default.
    MEDICAL_FITNESS_REQUIRED = ENGLISH_PROFICIENCY_REQUIRED = True
    # Ten recency flags, all initially current (shared class attribute).
    recency = [True for _ in range(10)]

# Duty constraints
DUTY_LIMITS = dict(
    CONTINUOUS_DUTY_MAX=12,            # hours
    WEEKLY_DUTY_MAX=48,                # hours in 7 days
    MONTHLY_DUTY_MAX=190,              # hours in 30 days
    CONSECUTIVE_DAYS_MAX=6,            # days
    MANDATORY_BREAK_AFTER_6_DAYS=48,   # hours
    BREAK_AFTER_NIGHT_DUTY=48,         # hours
    MIN_BREAK_BETWEEN_SHIFTS=12,       # hours
)

# Staffing rules
class StaffingRules:
    """Staffing ratio and duty/break rotation policies."""

    # For every 2 positions, 3 people are needed.
    POSITIONS_TO_PEOPLE_RATIO = dict([(2, 3)])
    # Hours of position duty followed by hours of break, per policy name.
    BREAK_POLICIES = {
        name: {'duty': duty_hours, 'break': break_hours}
        for name, duty_hours, break_hours in (
            ('STANDARD', 2, 0.5),    # 2 hours duty, 0.5 hours break
            ('CURRENT', 1.5, 0.75),  # 1.5 hours duty, 0.75 hours break
        )
    }

# Negative rewards (penalties) for violations
NEGATIVE_REWARDS = dict(
    CONTINUOUS_12_HOURS=-100,  # exceeding 12 continuous hours
    EXTRA_DUTY=-50,            # assigning extras
    CYCLE_DISRUPTION=-30,      # disrupting the team's cycle
    CONSECUTIVE_NIGHTS=-40,    # consecutive night duties
)

# Team rotation order; every entry is a key of SHIFTS
TEAM_CYCLE = 'MORNING AFTERNOON NIGHT NIGHT_OFF CLEAR_OFF'.split()

# Example of a dynamic controller assignment tracker: one independent record per
# controller id. dict(...) runs per id, so every record gets fresh lists.
controller_schedule = {
    cid: dict(team=None, assigned_shifts=[], total_hours=0, violations=[])
    for cid in CONTROLLERS
}

# Function to initialize shifts for a team
def initialize_team_schedule(team_name, cycle=None):
    """Build a fresh schedule record for ``team_name``.

    Args:
        team_name: Team label (for example one of ``TEAMS``).
        cycle: Optional sequence of shift names. Defaults to ``TEAM_CYCLE``.

    Returns:
        dict with keys:
          * ``'team'``: the given name;
          * ``'shifts'``: a new list copied from the cycle;
          * ``'controllers'``: an empty list.
        Copying the cycle means that editing one team's shifts cannot mutate
        the shared ``TEAM_CYCLE`` or any other team's record.
    """
    if cycle is None:
        cycle = TEAM_CYCLE
    return {
        'team': team_name,
        'shifts': list(cycle),  # copy, never alias the shared cycle
        'controllers': [],      # Controllers assigned to this team
    }



In [None]:
# Sample per-controller eligibility data, keyed by integer controller id 0..22
# (presumably the same ids as CONTROLLERS - confirm).
# 'medical_valid' / 'english_valid' correspond to the two certification flags on
# ControllerRequirements. 'recency' holds one boolean per rating.
# NOTE(review): every 'recency' list here has 7 entries (one per POSITIONS entry),
# but ControllerRequirements.recency defaults to 10 entries. Confirm the intended
# length.
# NOTE(review): no code visible in this notebook reads this dict yet.
controller_requirements = {
    0: {'medical_valid': True, 'english_valid': True, 'recency': [True, False, True, True, True, False, True]},
    1: {'medical_valid': True, 'english_valid': False, 'recency': [True, True, False, True, False, True, True]},
    2: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, True, False]},
    3: {'medical_valid': False, 'english_valid': True, 'recency': [True, True, True, False, True, True, True]},
    4: {'medical_valid': True, 'english_valid': True, 'recency': [True, False, True, True, False, True, True]},
    5: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, False, True]},
    6: {'medical_valid': True, 'english_valid': False, 'recency': [True, True, False, True, True, True, True]},
    7: {'medical_valid': False, 'english_valid': True, 'recency': [True, True, True, True, False, True, False]},
    8: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, True, True]},
    9: {'medical_valid': True, 'english_valid': True, 'recency': [True, True, False, True, True, False, True]},
    10: {'medical_valid': True, 'english_valid': False, 'recency': [True, False, True, True, True, True, True]},
    11: {'medical_valid': True, 'english_valid': True, 'recency': [True, True, True, False, True, True, False]},
    12: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, True, True]},
    13: {'medical_valid': False, 'english_valid': True, 'recency': [True, True, True, True, False, True, True]},
    14: {'medical_valid': True, 'english_valid': True, 'recency': [True, False, True, True, True, False, True]},
    15: {'medical_valid': True, 'english_valid': False, 'recency': [True, True, False, True, True, True, True]},
    16: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, True, False]},
    17: {'medical_valid': False, 'english_valid': True, 'recency': [True, True, True, False, True, True, True]},
    18: {'medical_valid': True, 'english_valid': True, 'recency': [True, True, True, True, False, True, True]},
    19: {'medical_valid': True, 'english_valid': True, 'recency': [False, True, True, True, True, False, True]},
    20: {'medical_valid': True, 'english_valid': False, 'recency': [True, True, False, True, True, True, True]},
    21: {'medical_valid': False, 'english_valid': True, 'recency': [True, False, True, True, True, True, False]},
    22: {'medical_valid': True, 'english_valid': True, 'recency': [True, True, True, True, True, False, True]}
}

In [None]:
import random  # used below; the cell that otherwise imports it runs later in the notebook

# NOTE(review): weekly_roster is assigned by `weekly_roster = create_weekly_roster()`
# in a later cell, so run that cell first.

# Encode each letter as a 1-based controller id taken from CONTROLLERS, which
# leaves 0 free to mean "unassigned".
# A fixed 'A'..'W' alphabet would have missed 'X', 'Y' and 'Z', which the roster
# uses, and would have collided with the 0 sentinel.
controller_map = {letter: idx + 1 for idx, letter in CONTROLLERS.items()}
numerical_roster = np.vectorize(lambda x: controller_map.get(x, 0), otypes=[int])(weekly_roster)

# Randomly set some controller values to 0 to simulate unassigned slots
num_positions_to_clear = 20  # Number of distinct cells to set to 0
days, slots, positions = numerical_roster.shape

# Sample distinct flat indices so exactly num_positions_to_clear cells are cleared
# (independent randint draws could pick the same cell twice).
for flat_index in random.sample(range(days * slots * positions), num_positions_to_clear):
    day, slot, position = np.unravel_index(flat_index, (days, slots, positions))
    numerical_roster[day, slot, position] = 0

In [None]:
# Display the letter roster; axes are (day, slot, position).
# NOTE(review): weekly_roster is assigned by `weekly_roster = create_weekly_roster()`
# in a later cell, so this cell only works once that has run.
weekly_roster

array([[['S', 'T', 'O', 'H', 'Z', 'B', 'E'],
        ['S', 'T', 'O', 'H', 'Z', 'B', 'E'],
        ['S', 'T', 'Y', 'S', 'Y', 'E', 'E'],
        ['X', 'O', 'H', 'S', 'Y', 'E', 'Q'],
        ['X', 'O', 'H', 'O', 'T', 'E', 'Q'],
        ['X', 'O', 'H', 'O', 'T', 'E', 'Q'],
        ['K', 'X', 'S', 'O', 'Z', 'H', 'M'],
        ['K', 'X', 'S', 'X', 'Z', 'H', 'M'],
        ['K', 'X', 'S', 'X', 'H', 'H', 'M'],
        ['K', 'X', 'O', 'X', 'H', 'E', 'W'],
        ['Y', 'K', 'O', 'Z', 'X', 'E', 'W'],
        ['Y', 'K', 'O', 'Z', 'X', 'E', 'W'],
        ['Y', 'K', 'O', 'Z', 'Y', 'E', 'T'],
        ['Y', 'Y', 'R', 'Z', 'Y', 'Z', 'T'],
        ['S', 'Y', 'R', 'Y', 'Y', 'Z', 'T'],
        ['S', 'Y', 'R', 'Y', 'Y', 'Z', 'Z']],

       [['S', 'Y', 'R', 'Y', 'Y', 'Z', 'T'],
        ['S', 'Y', 'R', 'Y', 'Y', 'Z', 'Z'],
        ['S', 'T', 'O', 'H', 'Z', 'B', 'E'],
        ['S', 'T', 'O', 'H', 'Z', 'B', 'E'],
        ['S', 'T', 'Y', 'S', 'Y', 'E', 'E'],
        ['X', 'O', 'H', 'S', 'Y', 'E', 'Q'],
        

In [None]:
import numpy as np

# Roster dimensions: positions, slots per day, days per week
NUM_POSITIONS, NUM_SLOTS, NUM_DAYS = 7, 16, 7

def create_weekly_roster():
    """Build the baseline letter roster of shape (NUM_DAYS, NUM_SLOTS, NUM_POSITIONS).

    Day 0 follows a fixed per-position pattern with one controller letter per
    slot. Each later day is the previous day shifted forward by two slots along
    the slot axis, with the last two slots wrapping round to the start.
    """
    day_zero_columns = (
        'SSSXXXKKKKYYYYSS',  # APP
        'TTTOOOXXXXKKKYYY',  # CWL_DIR
        'OOYHHHSSSOOOORRR',  # BKN_DIR
        'HHSSOOOXXXZZZZYY',  # CWL
        'ZZYYTTZZHHXXYYYY',  # BKH_DEPS
        'BBEEEEHHHEEEEZZZ',  # ADC
        'EEEQQQMMMWWWTTTZ',  # GND
    )
    roster = np.full((NUM_DAYS, NUM_SLOTS, NUM_POSITIONS), '', dtype='<U1')
    for column, letters in enumerate(day_zero_columns):
        roster[0, :, column] = list(letters)

    for d in range(1, NUM_DAYS):
        roster[d] = np.roll(roster[d - 1], shift=2, axis=0)
    return roster

from datetime import timedelta, datetime
def parse_time(time_str):
    """Parse an ``"HH:MM"`` string into a naive datetime dated 1900-01-01."""
    clock_format = "%H:%M"
    parsed = datetime.strptime(time_str, clock_format)
    return parsed
def calculate_roster_metrics(roster):
    """Compute duty-hour, coverage and rule-violation metrics for a roster.

    Args:
        roster: Array indexed ``roster[day, slot, position]`` whose cells hold
            controller letters, e.g. the output of ``create_weekly_roster``.
            Only the first NUM_DAYS x NUM_SLOTS x NUM_POSITIONS cells are read.
            Every recognised cell counts as 1.5 hours of duty. Slot ``s`` of
            day ``d`` is treated as starting at hour ``d*24 + s*1.5``.

    Returns:
        1-D numpy array, concatenated in this order:
          * 23 per-controller duty-hour totals (index = the letter's position
            in 'ABCDEFGHIJKLMNOPQRSTUVW');
          * NUM_DAYS * NUM_POSITIONS recognised-cell counts, flattened by day;
          * 8 violation counters: continuous_duty, weekly_hours,
            monthly_hours, consecutive_days, mandatory_break,
            night_duty_break, min_break_between_shifts, recency.

    NOTE(review): known issues in the current logic (left unchanged here):
      * Only letters 'A'..'W' are recognised. 'X', 'Y' and 'Z', which the
        roster and CONTROLLERS both use, are silently skipped.
      * continuous_duty and mandatory_break are never incremented.
        shift_transitions and daily_controllers are never used.
      * last_shift_end is refreshed after every slot, and slots are visited
        position by position. As a result, back-to-back slots (gap 0) and
        slots seen out of time order (negative gap) are both counted as
        min_break_between_shifts violations.
      * With the current SHIFTS, the night_duty_break test reduces to
        ``slot >= 13 and slot < 4``, which is never true.
      * The recency counter increments unconditionally, so exactly the first
        assignment to each position is counted as a violation.
      * days_worked only counts days on which the controller holds slot 0.
      * weekly_hours and monthly_hours compare the total over the whole
        roster, not over rolling 7- or 30-day windows.
    """
    metrics = {
        'controller_duty_hours': np.zeros(23),  # For 23 controllers
        'position_coverage': np.zeros((NUM_DAYS, NUM_POSITIONS)),
        'shift_transitions': np.zeros(23),
        'violations': {
            'continuous_duty': 0,
            'weekly_hours': 0,
            'monthly_hours': 0,
            'consecutive_days': 0,
            'mandatory_break': 0,
            'night_duty_break': 0,
            'min_break_between_shifts': 0,
            'recency': 0
        }
    }
    # Letter -> 0-based controller index ('A' -> 0 ... 'W' -> 22).
    controller_map = {letter: idx for idx, letter in enumerate('ABCDEFGHIJKLMNOPQRSTUVW')}
    controller_schedule = {idx: {'total_hours': 0, 'days_worked': 0, 'last_shift_end': None} for idx in range(23)}
    # Keyed by position index (only 0..NUM_POSITIONS-1 are ever touched).
    recency_counter = {idx: 0 for idx in range(10)}

    # Calculate metrics
    for day in range(NUM_DAYS):
        daily_controllers = set()
        for position in range(NUM_POSITIONS):
            for slot in range(NUM_SLOTS):
                controller = roster[day, slot, position]
                if controller in controller_map:
                    controller_id = controller_map[controller]
                    # Each slot is 1.5 hours of duty.
                    metrics['controller_duty_hours'][controller_id] += 1.5
                    controller_schedule[controller_id]['total_hours'] += 1.5
                    # total_hours was just increased, so this always adds 1.
                    recency_counter[position] += 1 if controller_schedule[controller_id]['total_hours'] > 0 else 0
                    if recency_counter[position] < 2:
                        metrics['violations']['recency'] += 1
                    if slot == 0:
                        controller_schedule[controller_id]['days_worked'] += 1
                    if controller_schedule[controller_id]['last_shift_end']:
                        # Gap in hours between the end of this controller's last
                        # recorded slot and the start of the current one.
                        last_end = controller_schedule[controller_id]['last_shift_end']
                        current_start = timedelta(hours=day*24 + slot*1.5)
                        if (current_start - last_end).total_seconds() / 3600 < DUTY_LIMITS['MIN_BREAK_BETWEEN_SHIFTS']:
                            metrics['violations']['min_break_between_shifts'] += 1
                    # Hour // 1.5 converts a clock hour to a slot index.
                    if slot >= parse_time(SHIFTS['NIGHT']['start']).hour // 1.5:
                        if day + 2 < NUM_DAYS and slot < parse_time(SHIFTS['MORNING']['start']).hour // 1.5:
                            metrics['violations']['night_duty_break'] += 1

                    controller_schedule[controller_id]['last_shift_end'] = timedelta(hours=day*24 + (slot+1)*1.5)
                    metrics['position_coverage'][day, position] += 1

    # Per-controller limit checks on the totals accumulated above.
    for controller_id in range(23):
        total_hours = controller_schedule[controller_id]['total_hours']
        days_worked = controller_schedule[controller_id]['days_worked']

        if total_hours > DUTY_LIMITS['WEEKLY_DUTY_MAX']:
            metrics['violations']['weekly_hours'] += 1
        if total_hours > DUTY_LIMITS['MONTHLY_DUTY_MAX']:
            metrics['violations']['monthly_hours'] += 1
        if days_worked > DUTY_LIMITS['CONSECUTIVE_DAYS_MAX']:
            metrics['violations']['consecutive_days'] += 1

    # Flat layout: 23 duty hours, NUM_DAYS*NUM_POSITIONS coverage, 8 violation counters.
    result_array = np.concatenate([
        metrics['controller_duty_hours'],
        metrics['position_coverage'].flatten(),
        [metrics['violations']['continuous_duty']],
        [metrics['violations']['weekly_hours']],
        [metrics['violations']['monthly_hours']],
        [metrics['violations']['consecutive_days']],
        [metrics['violations']['mandatory_break']],
        [metrics['violations']['night_duty_break']],
        [metrics['violations']['min_break_between_shifts']],
        [metrics['violations']['recency']]
    ])
    return result_array

# Build the baseline letter roster, then compute its metrics array.
weekly_roster = create_weekly_roster()
metrics = calculate_roster_metrics(roster=weekly_roster)

In [None]:
import pandas as pd
pd.set_option('display.max_columns', None)
positions = ['APP', 'CWL_DIR', 'BKN_DIR', 'CWL', 'BKH_DEPS', 'ADC', 'GND']
time_slots = [f"S{n}" for n in range(1, NUM_SLOTS + 1)]

# One DataFrame per day: rows are positions, columns are time slots.
roster = create_weekly_roster()
weekly_schedule = [
    pd.DataFrame(roster[day].T, index=positions, columns=time_slots)
    for day in range(NUM_DAYS)
]
for day, schedule in enumerate(weekly_schedule, start=1):
    print(f"Day {day} Schedule:")
    print(schedule)
    print(f"\n{'=' * 50}\n")


Day 1 Schedule:
         S1 S2 S3 S4 S5 S6 S7 S8 S9 S10 S11 S12 S13 S14 S15 S16
APP       S  S  S  X  X  X  K  K  K   K   Y   Y   Y   Y   S   S
CWL_DIR   T  T  T  O  O  O  X  X  X   X   K   K   K   Y   Y   Y
BKN_DIR   O  O  Y  H  H  H  S  S  S   O   O   O   O   R   R   R
CWL       H  H  S  S  O  O  O  X  X   X   Z   Z   Z   Z   Y   Y
BKH_DEPS  Z  Z  Y  Y  T  T  Z  Z  H   H   X   X   Y   Y   Y   Y
ADC       B  B  E  E  E  E  H  H  H   E   E   E   E   Z   Z   Z
GND       E  E  E  Q  Q  Q  M  M  M   W   W   W   T   T   T   Z


Day 2 Schedule:
         S1 S2 S3 S4 S5 S6 S7 S8 S9 S10 S11 S12 S13 S14 S15 S16
APP       S  S  S  S  S  X  X  X  K   K   K   K   Y   Y   Y   Y
CWL_DIR   Y  Y  T  T  T  O  O  O  X   X   X   X   K   K   K   Y
BKN_DIR   R  R  O  O  Y  H  H  H  S   S   S   O   O   O   O   R
CWL       Y  Y  H  H  S  S  O  O  O   X   X   X   Z   Z   Z   Z
BKH_DEPS  Y  Y  Z  Z  Y  Y  T  T  Z   Z   H   H   X   X   Y   Y
ADC       Z  Z  B  B  E  E  E  E  H   H   H   E   E   E   E   Z
GND   

In [None]:
def print_metrics(metrics):
    """Print a readable summary of the flat array from ``calculate_roster_metrics``.

    Expected layout:
      * 23 controller duty-hour totals;
      * NUM_DAYS * NUM_POSITIONS position-coverage counts;
      * the violation counters.
    Controllers are labelled 'A', 'B', ... by index.
    """
    head = 23
    coverage_len = NUM_DAYS * NUM_POSITIONS
    duty_hours = metrics[:head]
    # Coverage is reshaped (which also validates its size) but is not printed.
    _coverage = metrics[head:head + coverage_len].reshape(NUM_DAYS, NUM_POSITIONS)
    counters = metrics[head + coverage_len:]

    print("Controller Duty Hours (Total hours each controller worked over the week):")
    for idx, hrs in enumerate(duty_hours):
        print(f"  Controller {chr(65 + idx)}: {hrs:.1f} hours")

    violation_labels = (
        "Continuous Duty Violation Count (Exceeds 12 hours)",
        "Weekly Hours Violation Count (Exceeds 48 hours)",
        "Monthly Hours Violation Count (Exceeds 190 hours)",
        "Consecutive Days Worked Violation Count (Exceeds 6 days)",
        "Mandatory 48-hour Break Violation Count (After 6 consecutive days)",
        "Night Duty Break Violation Count (Less than 48 hours)",
        "Minimum Break Between Shifts Violation Count (Less than 12 hours)",
        "Recency Violation Count (Position not staffed frequently enough)",
    )
    print("\nViolations Summary:")
    for k, label in enumerate(violation_labels):
        print(f"  {label}: {int(counters[k])}")


print_metrics(metrics)


Controller Duty Hours (Total hours each controller worked over the week):
  Controller A: 0.0 hours
  Controller B: 21.0 hours
  Controller C: 0.0 hours
  Controller D: 0.0 hours
  Controller E: 115.5 hours
  Controller F: 0.0 hours
  Controller G: 0.0 hours
  Controller H: 105.0 hours
  Controller I: 0.0 hours
  Controller J: 0.0 hours
  Controller K: 73.5 hours
  Controller L: 0.0 hours
  Controller M: 31.5 hours
  Controller N: 0.0 hours
  Controller O: 126.0 hours
  Controller P: 0.0 hours
  Controller Q: 31.5 hours
  Controller R: 31.5 hours
  Controller S: 105.0 hours
  Controller T: 84.0 hours
  Controller U: 0.0 hours
  Controller V: 0.0 hours
  Controller W: 31.5 hours

Violations Summary:
  Continuous Duty Violation Count (Exceeds 12 hours): 0
  Weekly Hours Violation Count (Exceeds 48 hours): 6
  Monthly Hours Violation Count (Exceeds 190 hours): 0
  Consecutive Days Worked Violation Count (Exceeds 6 days): 0
  Mandatory 48-hour Break Violation Count (After 6 consecutive day

In [None]:
import numpy as np
NUM_POSITIONS, NUM_CONTROLLERS, NUM_DAYS, NUM_SHIFTS_PER_DAY = 7, 23, 7, 16

# Penalty per violation type; this rebinds the earlier NEGATIVE_REWARDS table.
NEGATIVE_REWARDS = dict(
    CONTINUOUS_12_HOURS=-100,
    WEEKLY_HOURS=-50,
    MONTHLY_HOURS=-50,
    CONSECUTIVE_DAYS=-20,
    MANDATORY_BREAK=-30,
    NIGHT_DUTY_BREAK=-40,
    MIN_BREAK_BETWEEN_SHIFTS=-25,
    RECENCY=-15,
)


In [None]:
from collections import deque


In [None]:
class ReplayMemory:
    """Bounded FIFO store of transitions for experience replay.

    Once ``capacity`` items are held, pushing another evicts the oldest one.
    """

    def __init__(self, capacity):
        self.memory = deque([], capacity)

    def push(self, transition):
        """Store one transition (any object, typically a tuple)."""
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn uniformly without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        return self.memory.__len__()

In [None]:
# Roster shape: (days, slots per day, positions).
weekly_roster.shape

(7, 16, 7)

In [None]:
# NOTE(review): sympy is imported in the next cell but never used, so this
# install may be unnecessary.
!pip install --upgrade sympy




In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import sympy
class DQN(nn.Module):
    """Fully connected Q-network: flat state in, one Q-value per action out.

    Architecture: four hidden ReLU layers (512 -> 256 -> 128 -> 64), with
    dropout after the first three, followed by a linear output layer.

    Args:
        input_dim: Length of the flattened state vector.
        output_dim: Number of discrete actions (Q-values produced).
        dropout: Dropout probability after hidden layers 1-3 (default 0.5).
    """

    def __init__(self, input_dim, output_dim, dropout=0.5):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, output_dim)

        # No BatchNorm: none was applied in forward(). Greedy action selection
        # also feeds single 1-D samples, which BatchNorm1d cannot normalise in
        # training mode.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return Q-values for ``x`` of shape (input_dim,) or (batch, input_dim)."""
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.dropout(torch.relu(self.fc2(x)))
        x = self.dropout(torch.relu(self.fc3(x)))
        x = torch.relu(self.fc4(x))
        return self.fc5(x)


class EmployeeSchedulingEnv:
    """Minimal RL environment that edits a (days, slots, positions) roster.

    The roster may use either of two encodings:
      * controller letters (``'<U1'`` dtype, ``''`` = empty), as built by
        ``create_weekly_roster``;
      * integer ids (``0`` = empty, ``k`` = controller ``k - 1``).
    Assignments are written in the same encoding as the initial roster.
    Actions are ``(position, slot, controller, day)`` tuples, where
    ``controller`` is a 0-based index.
    """

    def __init__(self, initial_roster):
        self.initial_roster = initial_roster
        self.n_positions = initial_roster.shape[2]
        self.n_slots = initial_roster.shape[1]
        self.n_days = initial_roster.shape[0]
        # reset() stores a copy of the roster in self.state. Its return value is
        # the flat observation, so it must not be assigned to self.state.
        self.reset()

    def _is_text(self):
        """True when the roster stores controller letters rather than ids."""
        return self.initial_roster.dtype.kind == 'U'

    def _empty(self):
        """Cell value that marks an unassigned slot for this roster's encoding."""
        return '' if self._is_text() else 0

    def _encode(self, controller):
        """Cell value for the 0-based ``controller`` index."""
        # Letters follow the same 'A'..'W' mapping as get_state_representation().
        return chr(65 + controller) if self._is_text() else controller + 1

    def reset(self):
        """Restore the initial roster and return the flat observation."""
        self.state = np.copy(self.initial_roster)
        return self.get_state_representation()

    def get_state_representation(self):
        """Return the roster as a flat integer vector.

        For letter rosters, 'A'..'W' map to 0..22 and any other value maps to
        -1. Integer rosters are returned unchanged (flattened).
        """
        if not self._is_text():
            return np.asarray(self.state, dtype=int).flatten()
        controller_map = {chr(65 + i): i for i in range(23)}
        numerical_state = np.vectorize(lambda x: controller_map.get(x, -1), otypes=[int])(self.state)
        return numerical_state.flatten()

    def check_if_done(self, metrics):
        """Return True once every roster cell holds a controller (``metrics`` is unused)."""
        return bool(np.all(self.state != self._empty()))

    def step(self, action):
        """Assign ``controller`` to (day, slot, position) and score the roster.

        Args:
            action: ``(position, slot, controller, day)``.

        Returns:
            ``(observation, reward, done, info)``. ``info`` is the metrics
            array from ``calculate_roster_metrics``.
            If the controller already fills 20 or more cells, the assignment
            is rejected instead: the roster is left unchanged, the reward is
            -100 and ``info`` is the roster array itself.
        """
        position, slot, controller, day = action
        code = self._encode(controller)

        # Check if the controller has been assigned too many times
        if np.sum(self.state == code) >= 20:
            return self.get_state_representation(), -100, False, self.state

        # Roster axes are (days, slots, positions).
        self.state[day, slot, position] = code

        # Calculate metrics (e.g., continuous duty, weekly hours, etc.)
        metrics = calculate_roster_metrics(self.state)

        # Penalise the same controller on back-to-back slots of one position.
        if slot > 0 and self.state[day, slot - 1, position] == code:
            penalty = -50
        else:
            penalty = 0

        # Calculate the base reward based on metrics
        reward = calculate_loss(metrics) + penalty

        done = self.check_if_done(metrics)
        # NOTE(review): check_if_done() is True only when no cell is empty, so
        # this unfilled-slot penalty can never trigger as written.
        if done and not np.all(self.state != self._empty()):
            reward += -100

        return self.get_state_representation(), reward, done, metrics




def calculate_loss(metrics):
    """Score a metrics array by its rule-violation counters (higher is better).

    ``calculate_roster_metrics`` puts controller duty hours and position
    coverage first and the eight violation counters last. The counters are
    therefore read from the tail of the array (``metrics[-8:]``), in order:
    continuous_duty, weekly_hours, monthly_hours, consecutive_days,
    mandatory_break, night_duty_break, min_break_between_shifts, recency.
    A bare length-8 array of counters is accepted too.

    Returns:
        The weighted sum of counts using the negative weights below, so the
        result is <= 0 for non-negative counts. monthly_hours and
        mandatory_break are read but carry no weight.

    Raises:
        ValueError: if ``metrics`` has fewer than eight entries.
    """
    names = (
        'continuous_duty',
        'weekly_hours',
        'monthly_hours',
        'consecutive_days',
        'mandatory_break',
        'night_duty_break',
        'min_break_between_shifts',
        'recency',
    )
    if len(metrics) < len(names):
        raise ValueError(f"expected at least {len(names)} metric values, got {len(metrics)}")
    violations = {name: int(value) for name, value in zip(names, metrics[-len(names):])}

    # Weight per violation type; larger magnitude = stronger penalty.
    weights = {
        'continuous_duty': -1.5,
        'weekly_hours': -0.5,
        'consecutive_days': -0.75,
        'night_duty_break': -0.25,
        'min_break_between_shifts': -2,  # strongest penalty
        'recency': -1,
    }
    return sum(violations[name] * weight for name, weight in weights.items())




import matplotlib.pyplot as plt

# TD-loss value of every optimisation step taken by train_dqn (plotted after training)
losses = []


def train_dqn(env, dqn, target_dqn, optimizer, memory, batch_size, gamma, n_controllers=23):
    """Run one DQN update on a minibatch sampled from ``memory``.

    Each stored transition is ``(state, action, reward, next_state[, done])``.
    ``action`` is ``(position, slot, controller, day)``. It is flattened to the
    index the greedy policy decodes:
    ``(position * n_controllers + controller) * env.n_days + day``.
    The slot is not part of the index.

    Roster violations (the "array loss") reach the network through the
    rewards. The optimised loss is the Huber TD error against the target
    network, so gradients actually flow into ``dqn``.

    Args:
        env: Environment; only ``env.n_days`` is read.
        dqn: Online Q-network being optimised.
        target_dqn: Frozen copy used for bootstrap targets.
        optimizer: Optimiser over ``dqn``'s parameters.
        memory: Object with ``__len__`` and ``sample(k)``.
        batch_size: Minibatch size. Nothing happens until memory holds this many.
        gamma: Discount factor.
        n_controllers: Number of controllers used in the action flattening.

    Returns:
        None. Appends the step's loss to the module-level ``losses`` list.
    """
    if len(memory) < batch_size:
        return

    transitions = memory.sample(batch_size)
    batch = list(zip(*transitions))

    state_batch = torch.tensor(np.array(batch[0]), dtype=torch.float32).view(batch_size, -1)
    reward_batch = torch.tensor(np.array(batch[2]), dtype=torch.float32)
    next_state_batch = torch.tensor(np.array(batch[3]), dtype=torch.float32).view(batch_size, -1)
    # A fifth element, if stored, marks terminal transitions (no bootstrap).
    if len(batch) > 4:
        done_batch = torch.tensor(np.array(batch[4]), dtype=torch.float32)
    else:
        done_batch = torch.zeros(batch_size)

    n_days = env.n_days
    action_index = torch.tensor(
        [(a[0] * n_controllers + a[2]) * n_days + a[3] for a in batch[1]],
        dtype=torch.long,
    ).unsqueeze(1)  # (batch, 1) so gather has the same rank as the Q matrix

    q_taken = dqn(state_batch).gather(1, action_index).squeeze(1)
    with torch.no_grad():
        next_q = target_dqn(next_state_batch).max(dim=1).values
    td_target = reward_batch + gamma * next_q * (1.0 - done_batch)

    loss = nn.functional.smooth_l1_loss(q_taken, td_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Append the current loss value to the list
    losses.append(loss.item())

# After training, plot the losses




# Training loop
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
batch_size = 64
target_update_freq = 5
memory_capacity = 100
n_episodes = 100

epsilon_decay = 0.89  # multiplicative decay applied once per episode
learning_rate = 5e-4


env = EmployeeSchedulingEnv(weekly_roster)
input_dim = 7*7*16  # Flattened state size for 7 days
output_dim = 7*7*16  # Number of possible actions
# Controller count for the action decoding below (`controllers` was never defined).
n_controllers = NUM_CONTROLLERS
dqn = DQN(input_dim, output_dim)
target_dqn = DQN(input_dim, output_dim)
target_dqn.load_state_dict(dqn.state_dict())
optimizer = optim.Adam(dqn.parameters(), lr=learning_rate)
memory = ReplayMemory(memory_capacity)
rewardsper_epoch = []

for episode in range(n_episodes):
    state = env.reset()
    total_reward = 0

    for t in range(16):
        # Epsilon-greedy action selection; the slot is always t.
        if random.random() < epsilon:
            position = random.randint(0, 1)
            controller = random.randint(0, n_controllers - 1)
            day = random.randint(0, env.n_days - 1)
        else:
            dqn.eval()  # disable dropout while choosing the greedy action
            with torch.no_grad():
                q_values = dqn(torch.tensor(state.flatten(), dtype=torch.float32))
            dqn.train()
            action_idx = q_values.argmax().item()
            position = action_idx // (n_controllers * env.n_days)
            controller = (action_idx // env.n_days) % n_controllers
            day = action_idx % env.n_days
        action = (position, t, controller, day)

        # env.step already adds the consecutive-slot penalty, so it is not
        # applied again here.
        next_state, reward, done, metrics = env.step(action)

        # `done` is stored as an optional 5th element. train_dqn uses it when
        # present to stop bootstrapping at episode end.
        memory.push((state, action, reward, next_state, done))
        state = next_state
        total_reward += reward

        train_dqn(env, dqn, target_dqn, optimizer, memory, batch_size, gamma)

        if done:
            break

    # Epsilon decay, never below epsilon_min
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    # Update target network
    if episode % target_update_freq == 0:
        target_dqn.load_state_dict(dqn.state_dict())

    print(f"Episode {episode + 1}/{n_episodes}, Total Reward: {total_reward}")
    rewardsper_epoch.append(total_reward)
plt.plot(losses)
plt.xlabel('Training Steps')
plt.ylabel('Loss')
plt.title('Loss Trend During Training')
plt.show()



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  ['X' 'O' 'H' 'O' 'T' 'E' 'Q']
  ['X' 'O' 'H' 'O' 'T' 'E' 'Q']
  ['K' 'X' 'S' 'O' 'Z' 'H' 'M']
  ['K' 'X' 'S' 'X' 'Z' 'H' 'M']]

 [['K' 'X' 'S' 'O' 'Z' 'H' 'M']
  ['K' 'X' 'S' 'X' 'Z' 'H' 'M']
  ['K' 'X' 'S' 'X' 'H' 'H' 'M']
  ['K' 'X' 'O' 'X' 'H' 'E' 'W']
  ['Y' 'K' 'O' 'Z' 'X' 'E' 'W']
  ['Y' 'K' 'O' 'Z' 'X' 'E' 'W']
  ['Y' 'K' 'O' 'Z' 'Y' 'E' 'T']
  ['Y' 'Y' 'R' 'Z' 'Y' 'Z' 'T']
  ['S' 'Y' 'R' 'Y' 'Y' 'Z' 'T']
  ['S' 'Y' 'R' 'Y' 'Y' 'Z' 'Z']
  ['S' 'T' 'O' 'H' 'Z' 'B' 'E']
  ['S' 'T' 'O' 'H' 'Z' 'B' 'E']
  ['S' 'T' 'Y' 'S' 'Y' 'E' 'E']
  ['X' 'O' 'H' 'S' 'Y' 'E' 'Q']
  ['X' 'O' 'H' 'O' 'T' 'E' 'Q']
  ['X' 'O' 'H' 'O' 'T' 'E' 'Q']]

 [['X' 'O' 'H' 'O' 'T' 'E' 'Q']
  ['X' 'O' 'H' 'O' 'T' 'E' 'Q']
  ['K' 'X' 'S' 'O' 'Z' 'H' 'M']
  ['K' 'X' 'S' 'X' 'Z' 'H' 'M']
  ['K' 'X' 'S' 'X' 'H' 'H' 'M']
  ['K' 'X' 'O' 'X' 'H' 'E' 'W']
  ['Y' 'K' 'O' 'Z' 'X' 'E' 'W']
  ['Y' 'K' 'O' 'Z' 'X' 'E' 'W']
  ['Y' 'K' 'O' 'Z' 'Y' 'E' 'T']
  [

RuntimeError: Index tensor must have the same number of dimensions as input tensor

In [None]:
import pandas as pd
pd.set_option('display.max_columns', None)
positions = ['APP', 'CWL_DIR', 'BKN_DIR', 'CWL', 'BKH_DEPS', 'ADC', 'GND']
time_slots = [f"S{i+1}" for i in range(NUM_SLOTS)]


def print_learned_roster(dqn, env, controllers):
    """Greedily roll out ``dqn`` over ``env`` and print the resulting roster.

    The roll-out runs ``env.n_days * env.n_slots`` steps. At each step the
    network's arg-max action is decoded into (position, controller,
    day_of_week), using the same formula as the training loop; the slot comes
    from the loop. Each assigned cell stores ``controller + 1``, and 0 means
    the cell was never assigned.

    Args:
        dqn: Q-network mapping a flat state to one Q-value per action.
        env: EmployeeSchedulingEnv to step through (it is reset first).
        controllers: Any sized collection of controllers; only ``len`` is used.
    """
    n_controllers = len(controllers)
    was_training = dqn.training
    dqn.eval()  # disable dropout so the greedy roll-out is deterministic
    try:
        state = env.reset()
        # Same axis order as the environment roster: (days, slots, positions).
        learned_roster = np.zeros_like(env.initial_roster, dtype=int)
        for _ in range(env.n_days):
            for slot in range(env.n_slots):
                state_tensor = torch.tensor(state.flatten(), dtype=torch.float32).unsqueeze(0)
                with torch.no_grad():
                    q_values = dqn(state_tensor)
                action_idx = q_values.argmax().item()
                position = action_idx // (n_controllers * env.n_days)
                controller = (action_idx // env.n_days) % n_controllers
                day_of_week = action_idx % env.n_days
                learned_roster[day_of_week, slot, position] = controller + 1
                state, _, _, _ = env.step((position, slot, controller, day_of_week))
    finally:
        dqn.train(was_training)

    for day in range(env.n_days):
        schedule = pd.DataFrame(learned_roster[day].T, index=positions, columns=time_slots)
        print(f"Day {day + 1} Schedule:")
        print(schedule)
        print("\n" + "="*50 + "\n")


print_learned_roster(dqn, env, CONTROLLERS)
