In [581]:
from ortools.sat.python import cp_model
import random
from collections import defaultdict

In [582]:
from typing import List

class NursePatientMatcher:
    """CP-SAT model that assigns every patient to exactly one nurse.

    The constructor only stores the problem data. The model itself is built by
    the methods attached to this class in later cells (``solve`` creates
    ``self.model`` and ``self.solver`` and then calls them).

    Args:
        num_patients: total number of patients to assign.
        num_nurses: number of available nurses.
        max_patients: upper bound on the number of patients per nurse.
        min_patients: accepted but not stored; the matching minimum-workload
            constraint is commented out.
        patient_types: number of patients of each type, in type order.
        acuities: one row per nurse, indexed by patient type.
        patient_times: time slot of each patient, indexed by patient.
    """

    def __init__(self, num_patients, num_nurses, max_patients, min_patients, patient_types, acuities, patient_times):
        # Problem dimensions.
        self.n: int = num_nurses
        self.p: int = num_patients
        # Per-nurse capacity (min_patients is deliberately not kept).
        self.max_patients_per_nurse: int = max_patients
        # Raw problem data, stored unchanged.
        self.patients_of_each_type: List[int] = patient_types
        self.patient_nurse_acuities: List[List[int]] = acuities
        self.patient_times: List[int] = patient_times

In [583]:
def create_variables(self):
    """Create the CP-SAT decision variables and the static lookup tables.

    Sets on ``self``:
        penalty: integer variable in [0, 1]; used by ``minimize_objectives``
            and replaced when ``double_shift_penalty`` is called.
        x[nurse, patient]: boolean, true when ``patient`` goes to ``nurse``.
        patients_per_nurse[nurse]: integer in [0, max_patients_per_nurse].
        patient_types[patient]: type index. Patients are numbered type by
            type: the first ``patients_of_each_type[0]`` are type 0, the next
            block is type 1, and so on.
        patients_per_time[slot]: patients whose time equals ``slot`` (0-9),
            in increasing patient order.
        nurse_sched[nurse][slot]: integer in [0, p]; linked to ``x`` by
            ``fill_nurse_schedule``.

    Raises:
        ValueError: if ``patients_of_each_type`` accounts for fewer than
            ``p`` patients.
    """
    model: cp_model.CpModel = self.model
    n: int = self.n
    p: int = self.p
    num_slots = 10  # time slots 0..9; the other methods also iterate range(10)

    # Placeholder objective term; double_shift_penalty() overwrites it.
    self.penalty = model.NewIntVar(0, 1, 'penalty')

    self.x = {}
    for nurse in range(n):
        for patient in range(p):
            self.x[nurse, patient] = model.NewBoolVar(f'x[Nurse: {nurse}, Patient: {patient}]')

    self.patients_per_nurse = {}
    for nurse in range(n):
        self.patients_per_nurse[nurse] = model.NewIntVar(0, self.max_patients_per_nurse, f'Patients for Nurse {nurse}')

    # Expand the per-type counts into one type label per patient. Doing this
    # directly (instead of advancing a counter at most once per patient)
    # correctly skips types that have zero patients.
    type_sequence = [t for t, count in enumerate(self.patients_of_each_type) for _ in range(count)]
    if len(type_sequence) < p:
        raise ValueError(
            f'patients_of_each_type accounts for {len(type_sequence)} patients, '
            f'but num_patients is {p}'
        )
    self.patient_types = {patient: type_sequence[patient] for patient in range(p)}

    # Group patients by time slot in a single pass. A time outside 0..9 is not
    # placed in any slot, so that patient is not schedule-constrained.
    self.patients_per_time = {slot: [] for slot in range(num_slots)}
    for patient in range(p):
        slot = self.patient_times[patient]
        if slot in self.patients_per_time:
            self.patients_per_time[slot].append(patient)

    self.nurse_sched = [
        [model.NewIntVar(0, p, f'Hour {i} for nurse {nurse}') for i in range(num_slots)]
        for nurse in range(n)
    ]
            
        

# Attach the function defined in this cell as a NursePatientMatcher method;
# the notebook defines each method in its own cell, outside the class body.
NursePatientMatcher.create_variables = create_variables

In [584]:
def bound_patients_per_nurse(self):
    """Cap every nurse's caseload at ``max_patients_per_nurse``.

    Adds one linear constraint per nurse: the sum of the assignment booleans
    ``x[nurse, patient]`` over all patients must not exceed the cap. A
    matching lower bound based on ``min_patients`` is currently disabled.
    """
    cap: int = self.max_patients_per_nurse
    model: cp_model.CpModel = self.model
    assignment = self.x
    patient_ids = range(self.p)
    for nurse in range(self.n):
        caseload = sum(assignment[nurse, patient] for patient in patient_ids)
        model.Add(caseload <= cap)
        # Disabled lower bound: model.Add(caseload >= min_patients)

# Attach the function defined in this cell as a NursePatientMatcher method;
# the notebook defines each method in its own cell, outside the class body.
NursePatientMatcher.bound_patients_per_nurse = bound_patients_per_nurse

In [585]:
def one_nurse_per_patient(self):
    """Require that each patient is assigned to exactly one nurse.

    For every patient, the assignment booleans ``x[nurse, patient]`` summed
    over all nurses must equal 1.
    """
    model: cp_model.CpModel = self.model
    nurses = range(self.n)
    for patient in range(self.p):
        covering = [self.x[nurse, patient] for nurse in nurses]
        model.Add(sum(covering) == 1)


# Attach the function defined in this cell as a NursePatientMatcher method;
# the notebook defines each method in its own cell, outside the class body.
NursePatientMatcher.one_nurse_per_patient = one_nurse_per_patient

In [586]:
def fill_nurse_schedule(self):
    """Link each nurse's per-slot counters to the assignment variables.

    For every nurse and each of the 10 time slots, ``nurse_sched[nurse][slot]``
    must equal the number of that nurse's patients whose time is ``slot``, and
    that number may be at most 1 (one patient per nurse per slot).
    """
    model: cp_model.CpModel = self.model
    assignment = self.x
    by_slot = self.patients_per_time
    for nurse in range(self.n):
        schedule = self.nurse_sched[nurse]
        for slot in range(10):
            load = sum(assignment[nurse, patient] for patient in by_slot[slot])
            model.Add(schedule[slot] == load)
            model.Add(schedule[slot] <= 1)

# Attach the function defined in this cell as a NursePatientMatcher method;
# the notebook defines each method in its own cell, outside the class body.
NursePatientMatcher.fill_nurse_schedule = fill_nurse_schedule
    

In [587]:
def track_patients_per_nurse(self):
    """Make ``patients_per_nurse[nurse]`` equal that nurse's number of assigned patients.

    The count is the sum of ``x[nurse, patient]`` over all patients.
    """
    model = self.model
    for nurse in range(self.n):
        assigned = [self.x[nurse, patient] for patient in range(self.p)]
        model.Add(sum(assigned) == self.patients_per_nurse[nurse])


# Attach the function defined in this cell as a NursePatientMatcher method;
# the notebook defines each method in its own cell, outside the class body.
NursePatientMatcher.track_patients_per_nurse = track_patients_per_nurse

In [588]:
def at_most_one_double_shift(self):
    """Allow each nurse at most one double shift.

    A double shift is a pair of consecutive time slots ``(i-1, i)`` in which
    the nurse is busy in both, i.e. ``nurse_sched[nurse][i-1] +
    nurse_sched[nurse][i] == 2``. One boolean per adjacent pair (9 pairs for
    10 slots) is reified to that condition, and at most one of them may be
    true.
    """
    model: cp_model.CpModel = self.model
    for nurse in range(self.n):
        schedule = self.nurse_sched[nurse]
        doubles = [model.NewBoolVar(f'Double shifts for nurse {nurse}') for _ in range(9)]
        for i in range(1, 10):
            busy_pair = schedule[i] + schedule[i - 1]
            model.Add(busy_pair == 2).OnlyEnforceIf(doubles[i - 1])
            model.Add(busy_pair < 2).OnlyEnforceIf(doubles[i - 1].Not())
        # "No two double shifts at once" over every pair is the same as "at
        # most one"; a single linear constraint replaces the 36 reified pair
        # literals (and the stray range(11)) per nurse.
        model.Add(sum(doubles) <= 1)
        
# Attach the function defined in this cell as a NursePatientMatcher method;
# the notebook defines each method in its own cell, outside the class body.
NursePatientMatcher.at_most_one_double_shift = at_most_one_double_shift

In [589]:
def double_shift_penalty(self):
    """Replace ``self.penalty`` with the total number of double shifts.

    This is a soft counterpart to ``at_most_one_double_shift``. For every nurse
    and each adjacent slot pair ``(i-1, i)``, a boolean is reified to "both
    slots are busy". ``self.penalty`` (a new integer variable with domain
    [0, 10*n]) is constrained to equal the sum of all these booleans.
    """
    model: cp_model.CpModel = self.model
    n: int = self.n
    self.penalty = model.NewIntVar(0, 10 * n, 'penalty')
    trackers = {
        nurse: [model.NewBoolVar(f'Double shifts for nurse {nurse}') for _ in range(9)]
        for nurse in range(n)
    }
    for nurse, flags in trackers.items():
        row = self.nurse_sched[nurse]
        for slot in range(1, 10):
            pair_load = row[slot] + row[slot - 1]
            model.Add(pair_load == 2).OnlyEnforceIf(flags[slot - 1])
            model.Add(pair_load < 2).OnlyEnforceIf(flags[slot - 1].Not())
    model.Add(self.penalty == sum(flag for flags in trackers.values() for flag in flags))
        
# Attach the function defined in this cell as a NursePatientMatcher method;
# the notebook defines each method in its own cell, outside the class body.
NursePatientMatcher.double_shift_penalty = double_shift_penalty

In [590]:
def minimize_objectives(self):
    """Set the objective: balance workloads first, then maximise acuity.

    Objective (minimised)::

        M * sum_nurse |patients_per_nurse[nurse] - p // n|
          - acuity_score
          + M * penalty

    ``acuity_score`` is the sum of ``acuities[nurse][type(patient)]`` over all
    chosen assignments. ``M`` is one more than the widest possible spread of
    ``acuity_score``, so a one-unit change in workload imbalance (or in the
    penalty) always outweighs any change in acuity.
    """
    model: cp_model.CpModel = self.model
    x = self.x
    n: int = self.n
    p: int = self.p
    max_patients: int = self.max_patients_per_nurse
    acuities = self.patient_nurse_acuities
    patient_types = self.patient_types
    patients_per_nurse = self.patients_per_nurse
    penalty = self.penalty

    # Integer (floored) average workload per nurse.
    target = p // n

    absolute_diff = [model.NewIntVar(0, max_patients, "Absolute difference between nurse n workload and average workload") for _ in range(n)]
    for nurse in range(n):
        difference = model.NewIntVar(-max_patients, max_patients, 'Difference between nurse n workload and average workload')
        model.Add(difference == patients_per_nurse[nurse] - target)
        model.AddAbsEquality(absolute_diff[nurse], difference)

    acuity_score = sum(acuities[nurse][patient_types[patient]] * x[nurse, patient] for nurse in range(n) for patient in range(p))

    # Bounds on acuity_score. Each patient contributes exactly one nurse's
    # acuity for its type (relies on the exactly-one-nurse-per-patient
    # constraint), so the score lies between the per-patient sums of the
    # smallest and largest acuity any nurse has for that type.
    best_by_type = {}
    worst_by_type = {}
    for t in {patient_types[patient] for patient in range(p)}:
        column = [acuities[nurse][t] for nurse in range(n)]
        best_by_type[t] = max(column)
        worst_by_type[t] = min(column)
    acuity_ub = sum(best_by_type[patient_types[patient]] for patient in range(p))
    acuity_lb = sum(worst_by_type[patient_types[patient]] for patient in range(p))

    M = acuity_ub - acuity_lb + 1

    model.Minimize(M * sum(absolute_diff) - acuity_score + M * penalty)
    

# Attach the function defined in this cell as a NursePatientMatcher method;
# the notebook defines each method in its own cell, outside the class body.
NursePatientMatcher.minimize_objectives = minimize_objectives

In [591]:
def solve(self, time_limit=None):
    """Build the CP-SAT model, solve it, and print each nurse's schedule.

    Args:
        time_limit: optional wall-clock limit in seconds, passed to the solver
            as ``max_time_in_seconds``. ``None`` (the default) sets no limit,
            as before.

    Returns:
        The CP-SAT status (``cp_model.OPTIMAL`` or ``cp_model.FEASIBLE``).

    Raises:
        ValueError: if the solver found no solution (e.g. infeasible, invalid
            model, or a time limit reached before any solution). The message
            includes the solver status name.
    """
    self.model = cp_model.CpModel()
    self.solver = cp_model.CpSolver()
    if time_limit is not None:
        # Bound the search; without a limit the solver runs until it proves
        # optimality or infeasibility.
        self.solver.parameters.max_time_in_seconds = time_limit

    self.create_variables()
    self.bound_patients_per_nurse()
    self.one_nurse_per_patient()
    self.fill_nurse_schedule()
    self.track_patients_per_nurse()
    self.at_most_one_double_shift()
    # Soft alternative to the hard double-shift limit (sets self.penalty to
    # the number of double shifts): self.double_shift_penalty()
    self.minimize_objectives()

    status = self.solver.Solve(self.model)
    if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        raise ValueError(f'Modeling error! Solver status: {self.solver.StatusName(status)}')

    print('Solved!' if status == cp_model.OPTIMAL else 'Feasible solution found (optimality not proven).')
    for nurse in range(self.n):
        print(f'Patients for nurse {nurse}: {self.solver.Value(self.patients_per_nurse[nurse])}')
        for i in range(10):
            print(f'{self.solver.Value(self.nurse_sched[nurse][i])}', end=' ')
        print()
    return status

# Attach the function defined in this cell as a NursePatientMatcher method;
# the notebook defines each method in its own cell, outside the class body.
NursePatientMatcher.solve = solve

In [592]:
def read_data_and_solve(path, patient_times=None):
    """Parse a hospital data file, build a NursePatientMatcher, and solve it.

    Expected file layout (0-based line numbers, whitespace-separated ints):
        line 0: num_nurses num_patients num_patient_types
        line 1: min_patients max_patients
        line 3: number of patients of each type
        lines 5 .. 4 + num_nurses: one row of acuities per nurse
    Lines 2 and 4 are not read.

    Args:
        path: path to the data file.
        patient_times: optional time slot (0-9) for each patient. When omitted,
            the fixed list below is used. Each nurse can take at most one
            patient per slot, so a slot holding more than num_nurses patients
            makes the model infeasible.

    Returns:
        The value returned by ``NursePatientMatcher.solve``.

    Raises:
        ValueError: if fewer patient times than patients are available, or if
            a header line does not have the expected number of fields.
    """
    with open(path, encoding='utf-8') as f:
        lines = f.readlines()

    # num_patient_types is parsed only to validate the header line.
    num_nurses, num_patients, num_patient_types = (int(v) for v in lines[0].split())
    min_patients, max_patients = (int(v) for v in lines[1].split())
    patient_types = [int(v) for v in lines[3].split()]
    acuities = [[int(v) for v in lines[5 + i].split()] for i in range(num_nurses)]

    if patient_times is None:
        # Fixed slot assignment kept from the original experiment.
        patient_times = [9, 2, 0, 9, 8, 4, 9, 5, 5, 2, 0, 6, 9, 4, 6, 2, 4, 6, 8, 2, 8, 3, 5, 7, 7, 1, 5, 0, 0, 0, 0, 7, 8, 4, 7, 6, 7, 9, 7, 3, 1, 2, 7, 7, 4, 0, 9, 4, 0, 6, 1, 8, 9, 4, 5, 1, 6, 6, 7, 1, 5, 5, 2, 6, 7, 3, 5, 0, 3, 5, 1, 7, 0, 8, 8, 8, 0, 6, 2, 7, 6, 4, 9, 4, 5, 0, 5, 6, 5, 8, 2, 1, 3, 9, 0, 6, 6, 8, 1, 8, 8, 0, 2, 2, 8, 3, 4, 3, 7, 3, 2, 8, 3, 7, 8, 1, 7, 6, 1, 0, 8, 1, 3, 3, 4, 9, 1, 4, 5, 6, 7, 2, 8, 5, 5, 2, 3, 1, 4, 1, 6, 7, 0, 4, 9, 1, 0, 8, 0, 9, 7, 7, 1, 5, 0, 9, 5, 1, 9, 0, 8, 8, 6, 2, 1, 5, 6, 8, 9, 5, 5, 4, 1, 4, 8, 0, 5, 7, 7, 0, 2, 8, 1, 5, 7, 8, 5, 8, 4, 4, 4, 0, 4, 4, 4, 2, 1, 3, 6, 7, 1, 9, 8, 1, 1, 9, 5, 4, 1, 4, 5, 5, 4, 6, 4, 0, 4, 0, 5, 4, 3, 4, 2, 7, 8, 1, 4, 1, 9, 8, 8, 5, 1, 2, 5, 4, 9, 6, 2, 6, 6, 6, 1, 1, 3, 1, 7, 6, 9, 7, 0, 2, 3, 0, 0, 9, 0, 7, 6, 9, 7, 4, 6, 3, 2, 9, 4, 7, 3, 1, 4, 9, 3, 6, 0, 7, 1, 7, 4, 6, 0, 4, 4, 4, 2, 9, 4, 5, 4, 3, 7, 4, 7, 3, 6, 1, 4, 8, 9, 5, 8, 6, 7, 5, 7, 4, 6, 6, 7, 4, 7, 4, 8, 4, 9, 1, 0, 5, 8, 4, 0, 8, 3, 1, 9, 2, 8, 7, 4, 8, 0, 1, 7, 0, 7, 5, 1, 0, 8, 4, 8, 3, 5, 6, 6, 6, 2, 6, 1, 1, 7, 9, 1, 8, 9, 7, 3, 5, 4, 5, 3, 7, 6, 7, 4, 7, 1, 3, 1, 7, 0, 5, 0, 5, 5, 1, 6, 9, 5, 1, 7, 4, 3, 7, 2, 7, 7, 0, 0, 9, 9, 0, 4, 3, 1, 8, 1, 2, 6, 4, 6, 5, 9, 8, 0, 6, 1, 0, 9, 1, 3, 0, 1, 3, 6, 4, 2, 5, 1, 8, 8, 6, 4, 0, 6, 1, 2, 6, 2, 3, 2, 7, 2, 9, 9, 4, 4, 5, 7, 4, 3, 9, 0, 7, 7, 1, 4, 4, 8, 6, 4, 6, 1, 9, 3, 9, 5, 9, 1, 3, 3, 3, 6, 7, 3, 0, 7, 4, 5, 9, 1, 1, 4, 6, 4, 4, 6, 8, 8, 9, 6, 0, 3, 0, 4, 3, 1, 9, 2, 5, 5, 5, 9, 1, 3, 4, 9, 5, 3, 8, 7, 3, 4, 6, 5, 2, 5, 5, 6, 7, 4, 6, 5, 5, 2, 0, 4, 4, 5, 2, 5, 2, 4, 0, 4, 8, 7, 4, 6, 4, 2, 5, 0, 9, 1, 5, 9, 8, 8, 7, 3, 4, 7, 6, 4, 3, 2, 8, 3, 0, 6, 7, 3, 1, 7, 1, 7, 3, 1, 0, 0, 1, 3, 4, 1, 0, 7, 5, 8, 9, 9, 6, 3, 0, 9, 6, 8, 4, 5, 4, 5, 6, 2, 6, 5, 6, 5, 8, 8, 9, 4, 1, 3, 4, 9, 3, 7, 6, 9, 2, 6, 9, 0, 9, 1, 6, 4, 5, 6, 1, 0, 9, 1, 3, 9, 8, 4, 4, 7, 8, 9, 2, 9, 7, 7, 3, 1, 5, 2, 3, 2, 7, 5, 2, 3, 8, 9, 4, 3, 0, 8, 7, 8, 0, 3, 0, 4, 9, 9, 9, 5, 9, 8, 5, 8, 6, 3, 2, 5, 
            4, 9, 2, 7, 3, 0, 8, 2, 5, 2, 2, 9, 7, 2, 3, 5, 1, 7, 4, 7, 3, 0, 1, 6, 3, 3, 1, 4, 0, 5, 3, 4, 6, 6, 6, 8, 6, 6, 5, 6, 9]

    # Fail early with a clear message instead of an IndexError deep inside
    # create_variables().
    if len(patient_times) < num_patients:
        raise ValueError(f'{len(patient_times)} patient times given for {num_patients} patients')

    print(patient_times)

    matcher = NursePatientMatcher(num_patients, num_nurses, max_patients, min_patients, patient_types, acuities, patient_times)
    return matcher.solve()

In [593]:
# Run the whole pipeline on the hospital dataset (path is relative to the
# current working directory).
read_data_and_solve(path='data/hospital_data.txt')

[9, 2, 0, 9, 8, 4, 9, 5, 5, 2, 0, 6, 9, 4, 6, 2, 4, 6, 8, 2, 8, 3, 5, 7, 7, 1, 5, 0, 0, 0, 0, 7, 8, 4, 7, 6, 7, 9, 7, 3, 1, 2, 7, 7, 4, 0, 9, 4, 0, 6, 1, 8, 9, 4, 5, 1, 6, 6, 7, 1, 5, 5, 2, 6, 7, 3, 5, 0, 3, 5, 1, 7, 0, 8, 8, 8, 0, 6, 2, 7, 6, 4, 9, 4, 5, 0, 5, 6, 5, 8, 2, 1, 3, 9, 0, 6, 6, 8, 1, 8, 8, 0, 2, 2, 8, 3, 4, 3, 7, 3, 2, 8, 3, 7, 8, 1, 7, 6, 1, 0, 8, 1, 3, 3, 4, 9, 1, 4, 5, 6, 7, 2, 8, 5, 5, 2, 3, 1, 4, 1, 6, 7, 0, 4, 9, 1, 0, 8, 0, 9, 7, 7, 1, 5, 0, 9, 5, 1, 9, 0, 8, 8, 6, 2, 1, 5, 6, 8, 9, 5, 5, 4, 1, 4, 8, 0, 5, 7, 7, 0, 2, 8, 1, 5, 7, 8, 5, 8, 4, 4, 4, 0, 4, 4, 4, 2, 1, 3, 6, 7, 1, 9, 8, 1, 1, 9, 5, 4, 1, 4, 5, 5, 4, 6, 4, 0, 4, 0, 5, 4, 3, 4, 2, 7, 8, 1, 4, 1, 9, 8, 8, 5, 1, 2, 5, 4, 9, 6, 2, 6, 6, 6, 1, 1, 3, 1, 7, 6, 9, 7, 0, 2, 3, 0, 0, 9, 0, 7, 6, 9, 7, 4, 6, 3, 2, 9, 4, 7, 3, 1, 4, 9, 3, 6, 0, 7, 1, 7, 4, 6, 0, 4, 4, 4, 2, 9, 4, 5, 4, 3, 7, 4, 7, 3, 6, 1, 4, 8, 9, 5, 8, 6, 7, 5, 7, 4, 6, 6, 7, 4, 7, 4, 8, 4, 9, 1, 0, 5, 8, 4, 0, 8, 3, 1, 9, 2, 8, 7, 4, 8, 0, 1, 7, 

KeyboardInterrupt: 