In [1]:
import utils
from ortools.sat.python import cp_model
import pandas as pd
from params import *
import datetime
from itertools import islice
from workalendar.asia import Malaysia
malaysia = Malaysia()

In [2]:
shifts = [am, ams, pm, n]
leaves = [no, do, al]
date_range = utils.date_range(start=(2022,7,4), end=(2022,7,17))

In [3]:
df = pd.read_excel('data.xlsx', index_col=0)
prev_roster_date_range = list(df.columns.to_pydatetime())
prev_roster = utils.get_data_to_tuple('data.xlsx')

In [4]:
class Model:
    def __init__(self, _nurses, _roles, _shifts, _leaves, _date_range, _prev_roster,_prev_roster_date_range=None):
        self.model = cp_model.CpModel()
        self._nurses = _nurses
        self._roles = _roles
        self._shifts = _shifts
        self._leaves = _leaves
        self._date_range = _date_range
        self._prev_roster = _prev_roster
        self.collect_penalty()
        
        self._prev_roster_date_range = _prev_roster_date_range
        self._total_date_range = _date_range if _prev_roster_date_range is None else _prev_roster_date_range + _date_range
        self._works = {}
        self._slots = _shifts + _leaves
        self.obj_bool_vars_max = []
        self.obj_bool_coeffs_max = []
        self.obj_bool_vars_min = []
        self.obj_bool_coeffs_min = []
        
        self.obj_int_vars_max = []
        self.obj_int_coeffs_max = []
        self.obj_int_vars_min = []
        self.obj_int_coeffs_min = []
        self.logs = ''

################################################################################################
######################################        Init       #######################################
################################################################################################

        for nurse in self._nurses:
            for s in self._slots:
                for d in self._total_date_range:
                    self._works[(nurse, s, d)] = self.model.NewBoolVar(f'work_{nurse}_{s}_{d}')
                    
        for nurse in self._nurses:
            for d in self._date_range:
                self.model.AddExactlyOne(self._works[(nurse, s, d)] for s in self._slots)
                
################################################################################################
######################################      Logger       #######################################
################################################################################################

    def logger(self, text):
        self.logs += f'\n{text}'

################################################################################################
######################################      Off day      #######################################
################################################################################################
    
    def off_day_per_week(self, day=1):
        for nurse in self._nurses:
            for i, a_week in enumerate(utils.chunk(self._date_range, 7)):
                try:
                    self.model.Add(sum(self._works[(nurse, do, d)] for d in a_week) == day)
                except:
                    print(f'no index, {nurse}, {i}')

################################################################################################
######################################     Demand covers    ####################################
################################################################################################
    
    def number_workers_per_shift(self):
        for d in self._date_range:
            # night
            self.model.Add(sum(self._works[(nurse, n, d)] for nurse in self._nurses) == 1)
            self.model.Add(sum(self._works[(nurse, no, d)] for nurse in self._nurses) >= 0)
            if utils.is_saturday(d):
                # am
                self.model.Add(sum(self._works[(nurse, am, d)] for nurse in self._nurses) == 3)
                # pm
                self.model.Add(sum(self._works[(nurse, pm, d)] for nurse in self._nurses) == 1)
                
            if utils.is_sunday(d):
                # am
                self.model.Add(sum(self._works[(nurse, am, d)] for nurse in self._nurses) == 1)
                # pm
                self.model.Add(sum(self._works[(nurse, pm, d)] for nurse in self._nurses) == 1)
            
            if utils.is_weekday(d):
                # am
                self.model.Add(sum(self._works[(nurse, morning, d)] for nurse in self._nurses for morning in [am, ams]) >= 4)
                # pm
                self.model.Add(sum(self._works[(nurse, pm, d)] for nurse in self._nurses) == 3)
                
            # ams
            if utils.is_weekend(d):
                self.model.Add(sum(self._works[(nurse, ams, d)] for nurse in self._nurses) == 0)
            else:
                self.model.Add(sum(self._works[(nurse, ams, d)] for nurse in self._nurses) >= 1)
                self.model.Add(sum(self._works[(nurse, ams, d)] for nurse in self._nurses) <= 2)
                
                    
################################################################################################
###################################    Previous roster init    #################################
################################################################################################
    
    def previous_roster(self, day_1_data: list):
        for nurse in self._nurses:
            for s in self._slots:
                for d in self._total_date_range:
                    if (nurse, s, d) in day_1_data:
                        self.model.Add(self._works[(nurse, s, d)] == 1)
                        # self.model.AddHint(self._works[(nurse, s, d)], 1)
                    # else:
                    #     self.model.Add(self._works[(nurse, s, d)] == 0)
        # for nurse, s, d in day_1_data:
        #     try:
        #         self.model.Add(self._works[(nurse, s, d)] == 1)
        #     except:
        #         self.logger(f'no index, {nurse}, {s}')


################################################################################################
######################################      Requests     #######################################
################################################################################################

    def requests(self):
        for nurse in self._nurses:
            for d in self._date_range:
                self.model.Add(self._works[(nurse, al, d)] == 0)
                
################################################################################################
######################################      Seniority     ######################################
################################################################################################

    def sum_seniority(self, roles):
        senior_nurses = [name for name, role in roles.items() if role == senior]
        for d in self._date_range:
            if utils.is_weekday(d):
                self.model.Add(sum(self._works[(nurse, pm, d)] for nurse in senior_nurses) >= 1)
                self.model.Add(sum(self._works[(nurse, pm, d)] for nurse in senior_nurses) <= 2)
                self.model.Add(sum(self._works[(nurse, morning, d)] for nurse in senior_nurses for morning in [am, ams]) >= 1)
                self.model.Add(sum(self._works[(nurse, morning, d)] for nurse in senior_nurses for morning in [am, ams]) <= 2)
            
################################################################################################
######################################      Sum constraints     ################################
################################################################################################

    def implement_sum_constraint(self, sum_constraints):        
        for sum_const in sum_constraints: # need to build data model
            slot, slot_type, hard_min, soft_min, min_cost, soft_max, hard_max, max_cost = sum_const
            for nurse in self._nurses:
                dates = utils.chunk(self._total_date_range, 7) if slot_type == "weekly" else [self._total_date_range[:28]]
                for index, week in enumerate(dates):
                    self.sum_constraint(
                        nurse, 
                        slot,
                        week,
                        index,
                        hard_min, 
                        soft_min, 
                        min_cost, 
                        soft_max,
                        hard_max, 
                        max_cost
                    )
                            
    def sum_constraint(
        self,
        nurse, 
        slot, 
        date_list, 
        date_list_index,
        hard_min, 
        soft_min, 
        min_cost, 
        soft_max,
        hard_max, 
        max_cost
    ):
        try:
            works = [self._works[(nurse, slot, d)] for d in date_list]
            prefix = f'weekly_sum_constraint({nurse}, {slot}, {date_list_index})'
            variables, coeffs = utils.add_soft_sum_constraint(
                self.model, 
                works, 
                hard_min, 
                soft_min, 
                min_cost, 
                soft_max,
                hard_max, 
                max_cost, 
                prefix
            )
            self.obj_int_vars_min.extend(variables)
            self.obj_int_coeffs_min.extend(coeffs)
        except Exception as e:
            self.logger(f'{e}, outside of range')

################################################################################################
######################################      Objective     ######################################
################################################################################################


    def maximize(self):
        self.model.Maximize(
            sum(self.obj_bool_vars_max[i] * self.obj_bool_coeffs_max[i]
                for i in range(len(self.obj_bool_vars_max))) +
            sum(self.obj_int_vars_max[i] * self.obj_int_coeffs_max[i]
                for i in range(len(self.obj_int_vars_max)))
        )
        
    def minimize(self):
        
        self.model.Minimize(
            sum(self.obj_bool_vars_min[i] * self.obj_bool_coeffs_min[i]
            for i in range(len(self.obj_bool_vars_min))) +
            sum(self.obj_int_vars_min[i] * self.obj_int_coeffs_min[i]
                for i in range(len(self.obj_int_vars_min)))
        )


################################################################################################
######################################      SOLVE         ######################################
################################################################################################
    
    def solve(self, search_all=False):
        self.solver = cp_model.CpSolver()
        self.solver.parameters.enumerate_all_solutions = search_all
        self.solution_printer = cp_model.ObjectiveSolutionPrinter()
        self.status = self.solver.Solve(self.model, self.solution_printer)
        self.solution = {}
        self.current_solution = []
        if self.status == cp_model.OPTIMAL or self.status == cp_model.FEASIBLE:
            print()
            header = '                '
            for d in self._total_date_range:
                header += f'{d.strftime("%d-%m")}\t'
            print(header)
            for nurse in self._nurses:
                schedule = ''
                # for d in self._date_range:
                for d in self._total_date_range:
                    for s in self._slots:
                        if self.solver.BooleanValue(self._works[nurse, s, d]) == 1:
                            self.solution.setdefault(nurse,[])
                            self.solution[nurse].append(s)
                            self.current_solution.append((nurse, s, d))
                            schedule += s + '\t'
                    
                print(f'{nurse}:   \t{schedule}')
            print()
            print('Penalties:')
            for i, var in enumerate(self.obj_bool_vars_max):
                if self.solver.BooleanValue(var):
                    penalty = self.obj_bool_coeffs_max[i]
                    if penalty > 0:
                        print('  %s fulfilled, gain=%i' % (var.Name(), penalty))
                    else:
                        print('  %s violated, penalty=%i' % (var.Name(), -penalty))

            for i, var in enumerate(self.obj_int_vars_min):
                if self.solver.Value(var) > 0:
                    print('  %s violated by %i, linear penalty=%i' %
                          (var.Name(), self.solver.Value(var), self.obj_int_coeffs_min[i]))

        print()
        print('Statistics')
        print('  - status          : %s' % self.solver.StatusName(self.status))
        print('  - conflicts       : %i' % self.solver.NumConflicts())
        print('  - branches        : %i' % self.solver.NumBranches())
        print('  - wall time       : %f s' % self.solver.WallTime())
        print('\n\n\n')
        print(self.logs)          
        
################################################################################################
######################################      Transitions     ####################################
################################################################################################
            
    def implement_shift_transitions(self, transitions: list, strategy, cost):
        for n in self._nurses:
            self.implement_transition_for_each_nurse(n, transitions, strategy, cost)
                
    def implement_transition_for_each_nurse(self, nurse, transitions, strategy, cost):
        prev = len(transitions)+1
        for d in self._date_range:
            transition = []
            for day, slot in transitions.items():
                transition.append((nurse, slot, utils.add_days(d,day)))
            self.implement_transition_according_to_strat(transition, strategy, cost)
            
    def build_negation_list(self, final_list, negation_list, nurse, day):
        for neg in negation_list:
            final_list.append(self._works[(nurse, neg, day)].Not())
        return final_list
    
    def implement_transition_according_to_strat(self, transitions, strategy, cost):
        if strategy == "never":
            (n1, s1, d1), (n2, s2, d2) = transitions
            try:
                self.model.AddBoolOr(self._works[(n1, s1, d1)].Not(), self._works[(n2, s2, d2)].Not())
            except Exception as e:
                self.logger(f'{e}, never transition outside of range')
            
        elif strategy == "always":
            # {0: n,1: n,2: no}
            if len(transitions) == 3:
                (n1, s1, d1), (n2, s2, d2), (n3, s3, d3) = transitions
                
                work1, work2, work3, work4 = [], [], [], []
                for trans in transitions:
                    nurse, shift, day = trans
                    if day == d1:
                        try:
                            work1.append(self._works[trans])
                            work2.append(self._works[trans].Not())
                            work3.append(self._works[trans].Not())
                            work4.append(self._works[trans].Not())
                        except Exception as e:
                            self.logger(f'transition d1, {e}, outside of range')
                    elif day == d2:
                        try:
                            work1.append(self._works[trans])
                            work2.append(self._works[trans])
                            work3.append(self._works[trans].Not())
                            work4.append(self._works[trans].Not())
                        except Exception as e:
                            self.logger(f'transition d2, {e}, outside of range')
                    else:
                        try:
                            work1.append(self._works[trans])
                            work2.append(self._works[(n3, n, d3)])
                            work3.append(self._works[(n3, n, d3)])
                            work4.append(self._works[trans].Not())
                        except Exception as e:
                            self.logger(f'transition d1, {e}, outside of range')
                            
                            
                try:
                    final_works = [work1, work2, work3, work4]
                    self.model.AddBoolOr(final_works)
                
                except Exception as e:
                    self.logger(f'transition model {e}, outside of range')
            else:
                try:
                    t1, t2 = transitions
                    self.model.AddImplication(self._works[t1], self._works[t2])
                    self.model.AddImplication(self._works[t1].Not(), self._works[t2].Not())
                except Exception as e:
                    self.logger(f'{e}, always outside of range')
                
        elif strategy == 'max':
            assert len(transitions) == 2, "maximize transition should be only 2 days"
            (n1, s1, d1), (n2, s2, d2) = transitions
            
            try:
                # transition = [self._works[(n1, s1, d1)], self._works[(n2, s2, d2)]]
                trans_var = self.model.NewBoolVar(f'{n1}, {d1.strftime("%d/%m")}, shift {s1} to {s2}')
                transitions.append(trans_var)
                self.model.AddImplication(self._works[(n1, s1, d1)], self._works[(n2, s2, d2)]).OnlyEnforceIf(trans_var)
                self.model.AddImplication(self._works[(n1, s1, d1)].Not(), self._works[(n2, s2, d2)].Not()).OnlyEnforceIf(trans_var.Not())
                # self.model.AddBoolOr(transitions)

                self.obj_bool_vars_max.append(trans_var)
                self.obj_bool_coeffs_max.append(cost)
            except Exception as e:
                self.logger(f'{e}, max transition outside of range')
                
    def handle_shift_transition(self, transitions):
        for trans, strat, cost in transitions:
            self.implement_shift_transitions(trans, strat, cost)
            
################################################################################################
######################################      Sequence        ####################################
################################################################################################

    def get_night_shift_combinations_for_each_nurse_mode3(self, nurse, d):
        arr = []
        for i in range(3):
            try:
                if i != 2:
                    arr.append(self._works[nurse, n, d+datetime.timedelta(days=i)])
                else:
                    arr.append(self._works[nurse, no, d+datetime.timedelta(days=i)])
            except Exception as e:
                self.logger(f'night shift transition out of range, {e}')
        self.implement_night_shift_combinations_mode3(arr)
    
    def implement_night_shift_combinations_mode3(self, night_shift_arr):
        try:
            if len(night_shift_arr) > 2:
                self.model.AddBoolOr(night_shift_arr[1].Not(), night_shift_arr[2].Not(), night_shift_arr[0])
                self.model.AddImplication(night_shift_arr[0], night_shift_arr[1])
                self.model.AddImplication(night_shift_arr[0], night_shift_arr[2])
            else:
                self.model.AddBoolOr(night_shift_arr[1].Not(), night_shift_arr[0])
                self.model.AddImplication(night_shift_arr[0], night_shift_arr[1])
        except Exception as e:
                self.logger(f'model night shift transition out of range, {e}')
        pass
    
    def handle_night_shift_sequences_mode3(self):
        transition_len = 4
        dates = self._prev_roster_date_range[-4:]+self._date_range
        self.previous_roster(self._prev_roster)
        for nurse in self._nurses:
            all_night_permutation = []
            for d in dates:
                self.get_night_shift_combinations_for_each_nurse_mode3(nurse, d)

    def get_night_shift_combinations_for_each_nurse_mode2(self, nurse, d):
        arr = []
        for i in range(3):
            try:
                if i != 2:
                    arr.append(self._works[nurse, n, d+datetime.timedelta(days=i)])
                else:
                    arr.append(self._works[nurse, no, d+datetime.timedelta(days=i)])
            except Exception as e:
                self.logger(f'night shift transition out of range, {e}')
        return arr
    
    def handle_night_shift_sequences_mode2(self):
        transition_len = 4
        dates = self._prev_roster_date_range[-4:]+self._date_range
        self.previous_roster(self._prev_roster)
        for nurse in self._nurses:
            all_night_permutation = []
            for d in dates:
                comb = self.get_night_shift_combinations_for_each_nurse_mode2(nurse, d)
                all_night_permutation.append(comb)
            try:
                self.model.AddBoolOr(all_night_permutation)
            except Exception as e:
                self.logger(f'nightshift or {e}, outside of range')
            
    def get_night_shift_combinations_for_each_nurse(self, nurse, d):
        d0_n = (nurse, n, d)
        d1_n = (nurse, n, d+datetime.timedelta(days=1))
        d2_n = (nurse, no, d+datetime.timedelta(days=2))
        
        d0_no = (nurse, no, d)
        d1_no = (nurse, no, d+datetime.timedelta(days=1))
        d2_no = (nurse, no, d+datetime.timedelta(days=2))
        d3_no = (nurse, no, d+datetime.timedelta(days=3))
        
        try:
            var = [self._works[d1_n], self._works[d2_n]]
            var2 = [self._works[d1_n].Not(), self._works[d2_n]]
            var3 = [self._works[d1_n].Not(), self._works[d2_n].Not()]
            
            self.model.AddBoolOr([self._works[d2_no]]).OnlyEnforceIf([self._works[d0_n], self._works[d1_n]])
            self.model.AddImplication(self._works[d1_n], self._works[d2_no]).OnlyEnforceIf(self._works[d0_n])
            # self.model.AddImplication(self._wo/rks[d1_n], self._works[d2_no]).OnlyEnforceIf(self._works[d0_n])
            # self.model.AddBoolAnd(var).OnlyEnforceIf(self._works[d0_n].Not())
            # self.model.AddBoolAnd(var2).OnlyEnforceIf(self._works[d0_n].Not())
            # self.model.AddBoolAnd(var3).OnlyEnforceIf(self._works[d0_n].Not())
            
            # self.model.AddBoolOr(self._works[d1_n].Not(), self._works[d2_no].Not(), self._works[d0_n])
            # self.model.AddImplication(self._works[d0_n], self._works[d1_n])
            # self.model.AddImplication(self._works[d0_n], self._works[d2_no])
            
        except Exception as e:
            self.logger(f'night shift transition out of range, {e}')
        
    def handle_night_shift_sequences(self):
        transition_len = 4
        dates = self._date_range
        self.previous_roster(self._prev_roster)
        for nurse in self._nurses:
            # all_permutation = []
            for d in dates:
                self.get_night_shift_combinations_for_each_nurse(nurse, d)

################################################################################################
#########     FAIRNESS----------------Equal allocation        ##################################
################################################################################################

    def fairness_allocation(self, difference=1):
        fairshift = {}
        sum_of_shifts = {}
        num_days = len(self._date_range) 
        for nurse in self._nurses:
            for s in self._shifts:
                sum_of_shifts[(nurse, s)] = self.model.NewIntVar(0, num_days, f'sum_of_shifts_{n}_{s}')
                shift_list = []
                for d in self._date_range:
                    try:
                        shift_list.append(self._works[(nurse, d, s)])
                    except Exception as e:
                        pass
                self.model.Add(sum_of_shifts[(nurse, s)] == sum(shift_list))
                                
        for s in self._shifts:
            try:
                min_fair_shift = self.model.NewIntVar(0, num_days, f'min_fair_shift_{s}')
                max_fair_shift = self.model.NewIntVar(0, num_days, f'max_fair_shift_{s}')
                self.model.AddMinEquality(min_fair_shift, [sum_of_shifts[(nurse, s)] for nurse in self._nurses])
                self.model.AddMaxEquality(max_fair_shift, [sum_of_shifts[(nurse, s)] for nurse in self._nurses]) 

                self.model.Add(max_fair_shift - min_fair_shift <= difference)

            except Exception as e:
                pass
            
################################################################################################
#########     FAIRNESS---------COLLECT WEEKEND/PREWEEKEND PENALTY    ###########################
################################################################################################
            
    def collect_penalty(self):
        self.weekend_penalty = {}
        self.pre_weekend_penalty = {}
        self.ams_penalty = {}
        for n, s, d in self._prev_roster:
            self.apply_penalty_weekend(n,s,d)

    def apply_penalty(self, data, nurse, shift):
        data[nurse].setdefault('oncall', 0)
        data[nurse].setdefault('notoncall', 0)
        data[nurse].setdefault('ams', 0)
        if shift in [n, no]:
            data[nurse]['oncall'] += 10
        elif shift == ams:
            data[nurse]['ams'] += 3
        else:
            data[nurse]['notoncall'] += 50
    
    def pre_weekend(self, d):
        return d + datetime.timedelta(days=1)
    
    def apply_penalty_weekend(self, nurse, shift, day):
        self.weekend_penalty.setdefault(nurse, {})
        self.pre_weekend_penalty.setdefault(nurse, {})
        pre_day = self.pre_weekend(day)
        if utils.is_weekend(day) or malaysia.is_holiday(day):
            self.apply_penalty(self.weekend_penalty, nurse, shift)
        elif utils.is_weekend(pre_day) or malaysia.is_holiday(pre_day):
            self.apply_penalty(self.pre_weekend_penalty, nurse, shift)
        elif shift == ams:
            self.apply_penalty(self.pre_weekend_penalty, nurse, shift)
            
################################################################################################
#########     FAIRNESS---------IMPLEMENT WEEKEND/PREWEEKEND PENALTY    #########################
################################################################################################
    
    def implement_model_penalty(self, nurse, works, name, penalty_data, penalty):
        assert penalty > 0, "penalty should be more than 0"
        trans_var = self.model.NewBoolVar(name)
        works.append(trans_var)
        self.model.AddBoolOr(works)
        self.obj_bool_vars_min.append(trans_var)
        self.obj_bool_coeffs_min.append(penalty+penalty_data)
    
    def implement_wkend_hold_prehold_ams_penalty(self):
        for nurse in self._nurses:
            name1 = f'weekend penalty {nurse}, shift {n} and {no}'
            name2 = f'weekend penalty {nurse}, shift not {n} and {no}'
            name3 = f'preweekend penalty {nurse}, shift {n} and {no}'
            name4 = f'preweekend penalty {nurse}, shift not {n} and {no}'
            name5 = f'ams penalty {nurse}'
                        
            worked_night_weekend = [self._works[(nurse, s, d)] for s in self._shifts if s in [n, no] for d in self._date_range if utils.is_weekend(d) or malaysia.is_holiday(d)]
            worked_weekend = [self._works[(nurse, s, d)] for s in self._shifts if s not in [n, no] for d in self._date_range if utils.is_weekend(d) or malaysia.is_holiday(d)]
            worked_night_preweekend = [self._works[(nurse, s, d)] for s in self._shifts if s in [n, no] for d in self._date_range if utils.is_weekend(self.pre_weekend(d)) or malaysia.is_holiday(self.pre_weekend(d))]
            worked_preweekend = [self._works[(nurse, s, d)] for s in self._shifts if s not in [n, no] for d in self._date_range if utils.is_weekend(self.pre_weekend(d)) or malaysia.is_holiday(self.pre_weekend(d))]
            worked_ams = [self._works[(nurse, ams, d)] for d in self._date_range if utils.is_weekday(d) or not malaysia.is_holiday(d)]
            
            self.implement_model_penalty(nurse, worked_night_weekend, name1, self.weekend_penalty[nurse]['oncall'], 10)
            self.implement_model_penalty(nurse, worked_weekend, name2, self.weekend_penalty[nurse]['notoncall'], 5)
            self.implement_model_penalty(nurse, worked_night_preweekend, name3, self.pre_weekend_penalty[nurse]['oncall'], 10)
            self.implement_model_penalty(nurse, worked_preweekend, name4, self.pre_weekend_penalty[nurse]['notoncall'], 5)
            self.implement_model_penalty(nurse, worked_ams, name5, self.pre_weekend_penalty[nurse]['ams'], 3)

In [5]:
m = Model(nurses, roles, shifts, leaves, date_range, prev_roster, prev_roster_date_range)
m.requests()
m.previous_roster(m._prev_roster)
m.off_day_per_week(1)
m.number_workers_per_shift()
m.handle_shift_transition(shift_transition)
m.handle_night_shift_sequences()
# m.handle_night_shift_sequences_mode2()
# m.handle_night_shift_sequences_mode3()
# m.handle_slot_sequence_constraints([(n, 2)])
m.implement_sum_constraint(sum_constraints)
m.sum_seniority(roles)
m.fairness_allocation()
m.implement_wkend_hold_prehold_ams_penalty()
m.maximize()
m.minimize()

print('Title: ')
m.solve()



Title: 
Solution 0, time = 0.17 s, objective = 296
Solution 1, time = 0.17 s, objective = 250
Solution 2, time = 0.22 s, objective = 240
Solution 3, time = 0.23 s, objective = 230
Solution 4, time = 0.23 s, objective = 220
Solution 5, time = 0.23 s, objective = 210
Solution 6, time = 0.24 s, objective = 200
Solution 7, time = 0.25 s, objective = 190
Solution 8, time = 0.25 s, objective = 180
Solution 9, time = 0.27 s, objective = 170
Solution 10, time = 0.30 s, objective = 160
Solution 11, time = 0.30 s, objective = 150
Solution 12, time = 0.32 s, objective = 140

                27-06	28-06	29-06	30-06	01-07	02-07	03-07	04-07	05-07	06-07	07-07	08-07	09-07	10-07	11-07	12-07	13-07	14-07	15-07	16-07	17-07	
Azatuliana:   	AM	AMS	AM	AM	AM	AL	DO	PM	AM	AMS	NO	PM	PM	DO	N	NO	PM	AMS	NO	DO	PM	
Fatehah:   	N	N	NO	PM	PM	PM	DO	AM	AM	PM	AM	AMS	NO	DO	AMS	PM	AM	PM	PM	AM	DO	
Fatimah:   	AL	AL	AMS	AM	AM	AM	DO	AM	N	NO	PM	PM	DO	AM	DO	AM	AMS	NO	PM	N	NO	
Fazilawati:   	NO	DO	PM	AL	PM	AL	DO	AMS	NO	PM	N	NO	AM

In [6]:

class Stats:
    def __init__(self, solution, shift_transitions, date_range, shift_list):
        self.solution = solution
        self.nurses = list(solution.keys())
        self.date_range = date_range
        self.shift_list = shift_list
        self.shift_transitions = shift_transitions
        self.max_or_always_transition_correct = 0
        self.max_or_always_transition_wrong = 0
        self.never_transition_correct = 0
        self.never_transition_wrong = 0
        self.details = {}
        self.post_night_rest = {}
        self.df_solution = pd.DataFrame(self.solution).T.rename(columns={n:self.date_range[n] for n in range(len(self.date_range))})
        self.count_transition_violation_for_all()
        self.df_shift_count_each_day = self.shift_count_each_day()
        self.df_shift_count_each_nurse = self.shift_count_each_nurse()
        self.df_gap_between_shift_each_nurse = self.gap_between_shift_each_nurse()
        self.get_average_post_night_rest()
        
    def window(self, arr, k):
        for i in range(len(arr)-k+1):
            yield arr[i:i+k]

    def count_transition_violation_for_all(self):
        for name, shifts in self.solution.items():
            try:
                self.count_transition_violation_for_each_transition(shifts)
            except Exception as e:
                pass
                
    def count_transition_violation_for_each_transition(self, shifts):
        for transition_rule, strategy, cost in self.shift_transitions:
            self.check_shift_by_window(shifts, transition_rule, strategy)
            
    def check_shift_by_window(self, shifts, rule, strategy):
        window_size = list(rule.keys())[-1] + 1
        for window in self.window(shifts, window_size):
            self.transition_rule(window, rule, strategy)
            
    def transition_rule(self, window, rule, strategy):
        assert len(rule) == 2, f"{rule} has odd length"
        if rule[0] == window[0] and strategy in ["always", 'max'] and (window[0], window[-1]) == (rule[0], rule[1]):
            self.max_or_always_transition_correct += 1
            self.write_results(rule[0], rule[1], 'correct always/max')
            
        elif rule[0] == window[0] and strategy in ["always", 'max'] and (window[0], window[-1]) != (rule[0], rule[1]):
            self.max_or_always_transition_wrong += 1
            self.write_results(rule[0], rule[1], 'wrong always/max')
        # else:
        #     self.max_or_always_transition_correct += 1
        #     self.write_results(rule[0], rule[1], 'correct always/max')
            
        if rule[0] == window[0] and strategy in ["never"] and (window[0], window[-1]) == (rule[0], rule[1]):
            self.never_transition_wrong += 1
            self.write_results(rule[0], rule[1], 'wrong never')
            
        elif rule[0] == window[0] and strategy in ["never"] and (window[0], window[-1]) != (rule[0], rule[1]):
            self.never_transition_correct += 1
            self.write_results(rule[0], rule[1], 'correct never')
            
    def write_results(self, rule1, rule2, strat):
        name = f'{rule1} to {rule2} {strat}'
        self.details.setdefault(name, 0)
        self.details[name] += 1
        
    
    # def count_shift_for_each_nurse(self):
    #     self.shift_counts = {}
    #     for name, shifts in self.solution.items():
    #         self.shift_counts.setdefault(name, {})
    #         for shift in self.shift_list:
    #             shift_count = shifts.count(shift)
    #             self.shift_counts[name][shift] = shift_count
    # Shift distribution for each shift
    # Rest after night shift sequence
    
    def shift_count_each_day(self):
        data = {}
        for day in self.date_range:
            data[day] = {}
            for k, v in self.df_solution[day].value_counts(sort=False).items():
                data[day][k] = v
        return pd.DataFrame(data).fillna(0).astype(int)
        
    def shift_count_each_nurse(self):
        data2 = {}
        for nurse in self.nurses:
            data2[nurse] = {}
            for k, v in self.df_solution.T[nurse].value_counts(sort=False).items():
                data2[nurse][k] = v
        return pd.DataFrame(data2).fillna(0).astype(int)
        
    def nth_index(self, iterable, value, n):
        matches = (idx for idx, val in enumerate(iterable) if val == value)
        return next(islice(matches, n-1, n), None)

    def gap_between_shift(self, nurse, shift):
        num_of_shift = self.df_shift_count_each_nurse.loc[shift][nurse]
        return self.nth_index(self.solution[nurse], shift, num_of_shift) - self.nth_index(self.solution[nurse], shift, 1)

    def gap_between_shift_each_nurse(self):
        df = self.df_shift_count_each_nurse.copy()
        for nurse in self.nurses:
            for shift in self.shift_list:
                try:
                    gap = self.gap_between_shift(nurse, shift)
                    if gap == 1:
                        df.loc[shift][nurse] = 0
                    else:
                        df.loc[shift][nurse] = gap
                except:
                    df.loc[shift][nurse] = None
        return df
    
    def get_average_post_night_rest_for_each_nurse(self, nurse):
        arr = self.solution[nurse]
        rest_after_night = datetime.timedelta(days=0)
        count = 0
        for index, item in enumerate(self.window(arr, 4)):
            combine = datetime.datetime.combine
            timedelta_ = datetime.timedelta
            d1,d2,d3,d4 = item
            if (d1, d2) == (n, no):
                count += 1
                d2_end = combine(self.date_range[index+1],shift_timings[item[1]]['end'])
                if d3 in [al, do]:
                    d3_start = combine(self.date_range[index+2]+timedelta_(days=1),shift_timings[d4]['start'])
                else:
                    d3_start = combine(self.date_range[index+2],shift_timings[d3]['start'])

                rest_after_night += (d3_start-d2_end)
        
        try:
            self.post_night_rest[nurse] = utils.format_timedelta(rest_after_night/count)
        except ZeroDivisionError:
            self.post_night_rest[nurse] = 0
            
    def get_average_post_night_rest(self):
        for n in self.nurses:
            self.get_average_post_night_rest_for_each_nurse(n)

In [7]:
shift_transition.append(({0: ams, 1: no}, "never", 0))
d = Stats(m.solution, shift_transition, m._total_date_range, shifts)
d.df_solution

Unnamed: 0,2022-06-27,2022-06-28,2022-06-29,2022-06-30,2022-07-01,2022-07-02,2022-07-03,2022-07-04,2022-07-05,2022-07-06,...,2022-07-08,2022-07-09,2022-07-10,2022-07-11,2022-07-12,2022-07-13,2022-07-14,2022-07-15,2022-07-16,2022-07-17
Azatuliana,AM,AMS,AM,AM,AM,AL,DO,PM,AM,AMS,...,PM,PM,DO,N,NO,PM,AMS,NO,DO,PM
Fatehah,N,N,NO,PM,PM,PM,DO,AM,AM,PM,...,AMS,NO,DO,AMS,PM,AM,PM,PM,AM,DO
Fatimah,AL,AL,AMS,AM,AM,AM,DO,AM,N,NO,...,PM,DO,AM,DO,AM,AMS,NO,PM,N,NO
Fazilawati,NO,DO,PM,AL,PM,AL,DO,AMS,NO,PM,...,NO,AM,DO,DO,PM,AMS,N,NO,PM,AM
Mimi,AM,AMS,AM,AM,AM,AL,DO,PM,PM,N,...,NO,AM,DO,PM,AMS,NO,AM,N,NO,DO
Nuraimi,PM,PM,PM,AMS,AM,AM,N,AM,AM,AM,...,PM,DO,N,PM,AM,AM,AM,AMS,NO,DO
Sariah,PM,PM,N,N,NO,DO,PM,NO,PM,AM,...,N,NO,DO,PM,PM,N,NO,PM,AM,DO
Sitisakinah,PM,PM,PM,PM,PM,DO,AM,N,NO,PM,...,AMS,NO,DO,AMS,NO,PM,PM,AM,DO,N
Tina,AMS,AM,AM,AM,AMS,AM,DO,NO,AMS,NO,...,AM,AM,PM,AM,N,NO,PM,AMS,NO,DO
Wahidah,AMS,AM,AM,AM,N,N,NO,PM,PM,AM,...,AM,N,DO,AM,AMS,PM,AM,AM,AM,DO


In [8]:
d.df_shift_count_each_nurse

Unnamed: 0,Azatuliana,Fatehah,Fatimah,Fazilawati,Mimi,Nuraimi,Sariah,Sitisakinah,Tina,Wahidah
AM,5,5,6,2,6,8,2,2,7,10
AMS,3,2,2,2,3,2,1,2,4,2
AL,1,0,2,2,1,0,0,0,0,0
DO,3,3,3,4,3,2,3,3,3,2
PM,5,7,3,5,3,6,7,9,2,3
NO,3,2,3,4,3,1,4,3,4,1
N,1,2,2,2,2,2,4,2,1,3


In [9]:
d.df_gap_between_shift_each_nurse

Unnamed: 0,Azatuliana,Fatehah,Fatimah,Fazilawati,Mimi,Nuraimi,Sariah,Sitisakinah,Tina,Wahidah
AM,8,12,12,8,17,13,10,12,13,18
AMS,16,3,14,9,14,15,0,3,18,15
AL,1,0,2,2,1,0,0,0,0,0
DO,3,3,3,4,3,2,3,3,3,2
PM,13,15,8,17,7,14,18,17,4,9
NO,3,2,3,4,3,1,4,3,4,1
N,0,0,11,7,9,7,14,13,0,8


In [10]:
d.post_night_rest

{'Azatuliana': '1 days, 7 hours, 0 mins',
 'Fatehah': '1 days, 7 hours, 0 mins',
 'Fatimah': '1 days, 7 hours, 0 mins',
 'Fazilawati': '1 days, 3 hours, 30 mins',
 'Mimi': 0,
 'Nuraimi': 0,
 'Sariah': '1 days, 23 hours, 0 mins',
 'Sitisakinah': '1 days, 7 hours, 0 mins',
 'Tina': '1 days, 7 hours, 0 mins',
 'Wahidah': '1 days, 7 hours, 0 mins'}

In [11]:
print(f'always/max transition correct {d.max_or_always_transition_correct}')
print(f'always/max transition wrong {d.max_or_always_transition_wrong}')
print(f'never transition correct {d.never_transition_correct}')
print(f'never transition wrong {d.never_transition_wrong}')

always/max transition correct 13
always/max transition wrong 7
never transition correct 27
never transition wrong 0
