# TA Scheduling

## Naive Programming

In [None]:
from ortools.sat.python import cp_model

def schedule_tas(availability, ta_requirements, shift_requirements, *,
                 desired_weight=10, undesired_weight=-1):
    """
    Schedule TAs for shifts based on availability and constraints.

    Builds a CP-SAT model with one Boolean variable per (TA, shift) pair.
    Every TA i must work exactly ta_requirements[i] shifts, every shift j must
    be staffed by exactly shift_requirements[j] TAs, and unavailable pairs are
    forbidden. Among feasible schedules the solver maximizes the weighted sum
    of desired and undesired assignments.

    Args:
        availability: A 2D list where availability[i][j] is:
                      - 1: Desired
                      - 0: Undesired
                      - -1: Unavailable
                      The number of shifts is taken from the first row.
        ta_requirements: A list where ta_requirements[i] is the number of shifts TA i must work.
        shift_requirements: A list where shift_requirements[j] is the number of TAs needed for shift j.
        desired_weight: Objective coefficient for each desired assignment (default 10).
        undesired_weight: Objective coefficient for each undesired assignment
            (default -1, i.e. a small penalty).

    Returns:
        A dictionary mapping each TA index to the list of shift indices it is
        assigned to (in increasing order), or None if the solver finds no
        feasible schedule. With no TAs at all, returns {} when every shift
        needs 0 TAs and None otherwise.
    """
    num_tas = len(availability)
    if num_tas == 0:
        # availability[0] would raise IndexError; an empty roster can only
        # satisfy shifts that need nobody.
        return {} if all(req == 0 for req in shift_requirements) else None
    num_shifts = len(availability[0])

    # Create the CP-SAT model
    model = cp_model.CpModel()

    # Decision variables: x[i, j] is 1 iff TA i is assigned to shift j.
    x = {
        (i, j): model.NewBoolVar(f'x[{i}][{j}]')
        for i in range(num_tas)
        for j in range(num_shifts)
    }

    # Constraint: each TA works exactly their required number of shifts.
    for i in range(num_tas):
        model.Add(sum(x[i, j] for j in range(num_shifts)) == ta_requirements[i])

    # Constraint: each shift is staffed by exactly the required number of TAs.
    for j in range(num_shifts):
        model.Add(sum(x[i, j] for i in range(num_tas)) == shift_requirements[j])

    # Constraint: unavailable (-1) pairs can never be assigned.
    for i in range(num_tas):
        for j in range(num_shifts):
            if availability[i][j] == -1:
                model.Add(x[i, j] == 0)

    # Objective: reward desired (1) assignments and penalize undesired (0) ones.
    # Any other value (including -1) contributes nothing.
    weights = {1: desired_weight, 0: undesired_weight}
    objective_terms = [
        weights[availability[i][j]] * x[i, j]
        for i in range(num_tas)
        for j in range(num_shifts)
        if availability[i][j] in weights
    ]
    model.Maximize(sum(objective_terms))

    # Solve the model
    solver = cp_model.CpSolver()
    status = solver.Solve(model)

    if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        return None
    return {
        i: [j for j in range(num_shifts) if solver.Value(x[i, j]) == 1]
        for i in range(num_tas)
    }

# Example Input
availability = [
    [1, 0, -1, 1, 1],   # TA 0: Desired, Undesired, Unavailable, Desired, Desired
    [1, 1, 0, -1, 1],   # TA 1: Desired, Desired, Undesired, Unavailable, Desired
    [-1, 1, 1, 0, 1],   # TA 2: Unavailable, Desired, Desired, Undesired, Desired
]
# Totals must agree for feasibility: 2 + 3 + 2 == 2 + 1 + 1 + 1 + 2 == 7.
ta_requirements = [2, 3, 2]  # TA 0: 2 shifts, TA 1: 3 shifts, TA 2: 2 shifts
shift_requirements = [2, 1, 1, 1, 2]  # Shifts: 2 TAs, 1 TA, 1 TA, 1 TA, 2 TAs

# Solve
solution = schedule_tas(availability, ta_requirements, shift_requirements)
# Compare with None: an empty schedule {} is a valid (but falsy) result.
if solution is not None:
    for ta, shifts in solution.items():
        print(f"TA {ta} assigned to shifts: {shifts}")
else:
    print("No feasible solution found.")


In [None]:
from ortools.sat.python import cp_model

def schedule_tas_debug(availability, ta_requirements, shift_requirements, *,
                       desired_weight=10, undesired_weight=-1):
    """
    Verbose variant of schedule_tas that logs each modeling step.

    Prints the inputs and every constraint as it is added, then solves the
    CP-SAT model. TA i works exactly ta_requirements[i] shifts, and shift j is
    staffed by exactly shift_requirements[j] TAs. Unavailable (-1) pairs are
    forbidden. The objective adds desired_weight per desired (1) assignment
    and undesired_weight per undesired (0) assignment.

    Args:
        availability: 2D list; availability[i][j] is 1 (desired),
            0 (undesired) or -1 (unavailable). Shift count comes from row 0.
        ta_requirements: ta_requirements[i] is the number of shifts TA i must work.
        shift_requirements: shift_requirements[j] is the number of TAs needed for shift j.
        desired_weight: Objective coefficient for desired assignments (default 10).
        undesired_weight: Objective coefficient for undesired assignments (default -1).

    Returns:
        Dict mapping TA index -> list of assigned shift indices, or None if no
        feasible schedule exists. When infeasible, prints concrete diagnostics
        before returning None.
    """
    num_tas = len(availability)
    if num_tas == 0:
        # availability[0] below would raise IndexError on an empty roster.
        print("No TAs provided.")
        return {} if all(req == 0 for req in shift_requirements) else None
    num_shifts = len(availability[0])

    print("Availability Matrix:")
    for i in range(num_tas):
        print(f"TA {i}: {availability[i]}")

    print("\nTA Requirements:", ta_requirements)
    print("Shift Requirements:", shift_requirements)

    # Create the CP-SAT model
    model = cp_model.CpModel()

    # Decision variables: x[i, j] is 1 if TA i is assigned to shift j, else 0
    x = {}
    for i in range(num_tas):
        for j in range(num_shifts):
            x[i, j] = model.NewBoolVar(f'x[{i}][{j}]')

    # Constraint: each TA works exactly their required number of shifts
    for i in range(num_tas):
        model.Add(sum(x[i, j] for j in range(num_shifts)) == ta_requirements[i])
        print(f"Added TA {i} shift requirement: {ta_requirements[i]}")

    # Constraint: each shift gets exactly the required number of TAs
    for j in range(num_shifts):
        model.Add(sum(x[i, j] for i in range(num_tas)) == shift_requirements[j])
        print(f"Added shift {j} staffing requirement: {shift_requirements[j]}")

    # Constraint: unavailable pairs are never assigned
    for i in range(num_tas):
        for j in range(num_shifts):
            if availability[i][j] == -1:  # Unavailable
                model.Add(x[i, j] == 0)
                print(f"TA {i} cannot work shift {j}")

    # Objective: reward desired (1) and penalize undesired (0) assignments;
    # any other value contributes nothing.
    weights = {1: desired_weight, 0: undesired_weight}
    objective_terms = [
        weights[availability[i][j]] * x[i, j]
        for i in range(num_tas)
        for j in range(num_shifts)
        if availability[i][j] in weights
    ]
    model.Maximize(sum(objective_terms))

    # Solve the model
    solver = cp_model.CpSolver()
    status = solver.Solve(model)

    if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        return {
            i: [j for j in range(num_shifts) if solver.Value(x[i, j]) == 1]
            for i in range(num_tas)
        }

    # Infeasible: print the generic hints plus the specific violations we can
    # detect directly from the inputs.
    print("\nNo feasible solution found. Possible reasons:")
    print("- Check if total TA shift requirements match total shift needs.")
    total_supplied = sum(ta_requirements)
    total_needed = sum(shift_requirements)
    if total_supplied != total_needed:
        print(f"  -> TAs must work {total_supplied} shifts in total, "
              f"but shifts need {total_needed}.")
    print("- Ensure there are enough available TAs for each shift.")
    for j in range(num_shifts):
        available = sum(1 for i in range(num_tas) if availability[i][j] != -1)
        if available < shift_requirements[j]:
            print(f"  -> Shift {j} needs {shift_requirements[j]} TAs, "
                  f"but only {available} are available.")
    print("- Check if constraints are overly strict.")
    for i in range(num_tas):
        open_shifts = sum(1 for j in range(num_shifts) if availability[i][j] != -1)
        if open_shifts < ta_requirements[i]:
            print(f"  -> TA {i} must work {ta_requirements[i]} shifts, "
                  f"but is available for only {open_shifts}.")
    return None

# Example Input
availability = [
    [1, 0, -1, 1, 1],   # TA 0: Desired, Undesired, Unavailable, Desired, Desired
    [1, 1, 0, -1, 1],   # TA 1: Desired, Desired, Undesired, Unavailable, Desired
    [-1, 1, 1, 0, 1],   # TA 2: Unavailable, Desired, Desired, Undesired, Desired
]
# Totals must agree for feasibility: 2 + 3 + 2 == 2 + 1 + 1 + 1 + 2 == 7.
ta_requirements = [2, 3, 2]  # TA 0: 2 shifts, TA 1: 3 shifts, TA 2: 2 shifts
shift_requirements = [2, 1, 1, 1, 2]  # Shifts: 2 TAs, 1 TA, 1 TA, 1 TA, 2 TAs

# Solve
solution = schedule_tas_debug(availability, ta_requirements, shift_requirements)
# Compare with None: an empty schedule {} is a valid (but falsy) result.
if solution is not None:
    print("\nFeasible Solution:")
    for ta, shifts in solution.items():
        print(f"TA {ta} assigned to shifts: {shifts}")
else:
    print("No feasible solution found.")


## Class-based Implementation

In [None]:
# TODO

# General Optimization Practice

## CP-SAT - Nurse Scheduling

In [2]:
from ortools.sat.python import cp_model

### Data Creation

In [None]:
# Problem size: 4 nurses covering 3 shifts per day over 3 days.
num_nurses, num_shifts, num_days = 4, 3, 3
# Index ranges reused by the loops that build variables and constraints.
all_nurses, all_shifts, all_days = range(num_nurses), range(num_shifts), range(num_days)

### Create the model

In [None]:
# Empty CP-SAT model; variables and constraints are attached to it afterwards.
model: cp_model.CpModel = cp_model.CpModel()

### Create the variables
The dictionary `shifts` defines the assignment of shifts to nurses: `shifts[(n, d, s)]` equals 1 if shift `s` is assigned to nurse `n` on day `d`, and 0 otherwise.

In [None]:
# One Boolean literal per (nurse, day, shift): true iff that nurse covers
# that shift on that day. Creation order is nurse-major, then day, then shift.
shifts = {
    (n, d, s): model.new_bool_var(f"shift_n{n}_d{d}_s{s}")
    for n in all_nurses
    for d in all_days
    for s in all_shifts
}

### Code the constraints
Next, we show how to assign nurses to shifts subject to the following constraints:

- Each shift is assigned to exactly one nurse per day.
- Each nurse works at most one shift per day.

Here's the code that creates the first constraint:

In [None]:
# Every (day, shift) slot is covered by exactly one nurse.
for d in all_days:
    for s in all_shifts:
        model.add_exactly_one([shifts[(n, d, s)] for n in all_nurses])

The last line says that for each shift on each day, the sum of the nurses assigned to that shift is exactly 1.

Next, here's the code that requires that each nurse works at most one shift per day.

In [None]:
# A nurse takes at most one shift on any given day (none means a day off).
for n in all_nurses:
    for d in all_days:
        model.add_at_most_one([shifts[(n, d, s)] for s in all_shifts])

For each nurse and each day, the sum of shifts assigned to that nurse is at most 1 ("at most" because a nurse might have the day off).

**Assign shifts evenly**

In [None]:
# Balance the workload: every nurse gets at least floor(total / nurses)
# shifts, and at most one more than that when the total does not divide
# evenly among the nurses.
min_shifts_per_nurse = (num_shifts * num_days) // num_nurses
max_shifts_per_nurse = min_shifts_per_nurse + int(num_shifts * num_days % num_nurses != 0)
for n in all_nurses:
    # All (day, shift) literals for nurse n; stays bound to the last nurse's
    # list once the loop finishes.
    shifts_worked = [shifts[(n, d, s)] for d in all_days for s in all_shifts]
    model.add(sum(shifts_worked) >= min_shifts_per_nurse)
    model.add(sum(shifts_worked) <= max_shifts_per_nurse)

Since there are `num_shifts * num_days` total shifts in the schedule period, you can assign at least `(num_shifts * num_days) // num_nurses` shifts to each nurse, but some shifts may be left over. (Here `//` is the Python integer division operator, which returns the floor of the usual quotient.)

For the given values of num_nurses = 4, num_shifts = 3, and num_days = 3, the expression min_shifts_per_nurse has the value (3 * 3 // 4) = 2, so you can assign at least two shifts to each nurse. This is specified by the constraint (here in Python)

In [None]:
# Lower bound on workload. Here shifts_worked is whatever the balancing loop
# left behind (the last nurse), so re-running this only duplicates a
# constraint that loop already posted.
model.add(sum(shifts_worked) >= min_shifts_per_nurse)

Since there are nine total shifts over the three-day period, there is one remaining shift after assigning two shifts to each nurse. The extra shift can be assigned to any nurse.

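The final line (here in Python) ensures that no nurse works more than `max_shifts_per_nurse` shifts, i.e. no nurse is assigned more than one extra shift: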

In [None]:
# Upper bound on workload (at most one extra shift). As above, this applies
# to the last nurse only and duplicates the loop's constraint.
model.add(max_shifts_per_nurse >= sum(shifts_worked))

### Update solver parameters

In [None]:
# Solver configured to report every feasible assignment.
solver = cp_model.CpSolver()
solver.parameters.enumerate_all_solutions = True  # Enumerate all solutions.
solver.parameters.linearization_level = 0  # level 0: no LP relaxation

### Register a callback function
Register a callback on the solver that will be called for each solution it finds. (This is optional: it is only needed to print or inspect intermediate solutions.)

In [None]:
class NursesPartialSolutionPrinter(cp_model.CpSolverSolutionCallback):
    """Solution callback that prints each schedule and stops after a limit.

    Args:
        shifts: Mapping keyed by (nurse, day, shift) holding the assignment
            literals queried through ``self.value``.
        num_nurses, num_days, num_shifts: Dimensions used to walk the keys
            of ``shifts``.
        limit: Number of solutions after which the search is stopped.
    """

    def __init__(self, shifts, num_nurses, num_days, num_shifts, limit):
        super().__init__()
        self._shifts = shifts
        self._num_nurses = num_nurses
        self._num_days = num_days
        self._num_shifts = num_shifts
        self._solution_limit = limit
        self._solution_count = 0

    def on_solution_callback(self):
        """Print the current solution day by day; stop once the limit is hit."""
        self._solution_count += 1
        print(f"Solution {self._solution_count}")
        for day in range(self._num_days):
            print(f"Day {day}")
            for nurse in range(self._num_nurses):
                worked = [
                    shift
                    for shift in range(self._num_shifts)
                    if self.value(self._shifts[(nurse, day, shift)])
                ]
                for shift in worked:
                    print(f"  Nurse {nurse} works shift {shift}")
                if not worked:
                    print(f"  Nurse {nurse} does not work")
        if self._solution_count < self._solution_limit:
            return
        print(f"Stop search after {self._solution_limit} solutions")
        self.stop_search()

    def solutionCount(self):
        """Return how many solutions have been reported so far."""
        return self._solution_count

# Print at most this many solutions before the callback stops the search.
solution_limit = 5
solution_printer = NursesPartialSolutionPrinter(
    shifts, num_nurses, num_days, num_shifts, limit=solution_limit
)

### Invoke the solver

In [None]:
# Run the search; every solution found is reported through the callback.
solver.solve(model, solution_callback=solution_printer)

### Entire nurse scheduling program 

In [3]:
"""Example of a simple nurse scheduling problem."""
from ortools.sat.python import cp_model


def _build_nurse_model(num_nurses: int, num_shifts: int, num_days: int):
    """Build the nurse CP-SAT model; return (model, shifts) keyed by (n, d, s)."""
    all_nurses = range(num_nurses)
    all_shifts = range(num_shifts)
    all_days = range(num_days)

    model = cp_model.CpModel()

    # shifts[(n, d, s)]: nurse 'n' works shift 's' on day 'd'.
    shifts = {
        (n, d, s): model.new_bool_var(f"shift_n{n}_d{d}_s{s}")
        for n in all_nurses
        for d in all_days
        for s in all_shifts
    }

    # Each shift is assigned to exactly one nurse in the schedule period.
    for d in all_days:
        for s in all_shifts:
            model.add_exactly_one(shifts[(n, d, s)] for n in all_nurses)

    # Each nurse works at most one shift per day.
    for n in all_nurses:
        for d in all_days:
            model.add_at_most_one(shifts[(n, d, s)] for s in all_shifts)

    # Distribute the shifts evenly: each nurse works floor(total / nurses)
    # shifts, plus at most one extra when the total is not divisible.
    min_shifts_per_nurse, remainder = divmod(num_shifts * num_days, num_nurses)
    max_shifts_per_nurse = min_shifts_per_nurse + (1 if remainder else 0)
    for n in all_nurses:
        shifts_worked = [shifts[(n, d, s)] for d in all_days for s in all_shifts]
        model.add(min_shifts_per_nurse <= sum(shifts_worked))
        model.add(sum(shifts_worked) <= max_shifts_per_nurse)

    return model, shifts


def main(
    num_nurses: int = 4,
    num_shifts: int = 3,
    num_days: int = 3,
    solution_limit: int = 5,
) -> None:
    """Solve a small nurse-scheduling problem and print the first solutions.

    Each (day, shift) slot is covered by exactly one nurse, each nurse works
    at most one shift per day, and shifts are spread as evenly as possible.
    All solutions are enumerated, but the callback stops the search after
    ``solution_limit`` solutions; solver statistics are printed at the end.

    Args:
        num_nurses: Number of nurses; must be at least 1.
        num_shifts: Shifts per day.
        num_days: Days in the schedule period.
        solution_limit: Stop the search after printing this many solutions.

    Raises:
        ValueError: If ``num_nurses`` < 1 (the even-split bounds divide by it).
    """
    if num_nurses < 1:
        raise ValueError(f"num_nurses must be at least 1, got {num_nurses}")

    model, shifts = _build_nurse_model(num_nurses, num_shifts, num_days)

    # Creates the solver and solve.
    solver = cp_model.CpSolver()
    solver.parameters.linearization_level = 0
    # Enumerate all solutions.
    solver.parameters.enumerate_all_solutions = True

    class NursesPartialSolutionPrinter(cp_model.CpSolverSolutionCallback):
        """Print intermediate solutions."""

        def __init__(self, shifts, num_nurses, num_days, num_shifts, limit):
            super().__init__()
            self._shifts = shifts
            self._num_nurses = num_nurses
            self._num_days = num_days
            self._num_shifts = num_shifts
            self._solution_count = 0
            self._solution_limit = limit

        def on_solution_callback(self):
            self._solution_count += 1
            print(f"Solution {self._solution_count}")
            for d in range(self._num_days):
                print(f"Day {d}")
                for n in range(self._num_nurses):
                    is_working = False
                    for s in range(self._num_shifts):
                        if self.value(self._shifts[(n, d, s)]):
                            is_working = True
                            print(f"  Nurse {n} works shift {s}")
                    if not is_working:
                        print(f"  Nurse {n} does not work")
            if self._solution_count >= self._solution_limit:
                print(f"Stop search after {self._solution_limit} solutions")
                self.stop_search()

        def solutionCount(self):
            return self._solution_count

    # Display the first `solution_limit` solutions.
    solution_printer = NursesPartialSolutionPrinter(
        shifts, num_nurses, num_days, num_shifts, solution_limit
    )

    solver.solve(model, solution_printer)

    # Statistics.
    print("\nStatistics")
    print(f"  - conflicts      : {solver.num_conflicts}")
    print(f"  - branches       : {solver.num_branches}")
    print(f"  - wall time      : {solver.wall_time} s")
    print(f"  - solutions found: {solution_printer.solutionCount()}")


if __name__ == "__main__":
    main()

Solution 1
Day 0
  Nurse 0 does not work
  Nurse 1 works shift 0
  Nurse 2 works shift 1
  Nurse 3 works shift 2
Day 1
  Nurse 0 works shift 2
  Nurse 1 does not work
  Nurse 2 works shift 1
  Nurse 3 works shift 0
Day 2
  Nurse 0 works shift 2
  Nurse 1 works shift 1
  Nurse 2 works shift 0
  Nurse 3 does not work
Solution 2
Day 0
  Nurse 0 works shift 0
  Nurse 1 does not work
  Nurse 2 works shift 1
  Nurse 3 works shift 2
Day 1
  Nurse 0 does not work
  Nurse 1 works shift 2
  Nurse 2 works shift 1
  Nurse 3 works shift 0
Day 2
  Nurse 0 works shift 2
  Nurse 1 works shift 1
  Nurse 2 works shift 0
  Nurse 3 does not work
Solution 3
Day 0
  Nurse 0 works shift 0
  Nurse 1 does not work
  Nurse 2 works shift 1
  Nurse 3 works shift 2
Day 1
  Nurse 0 works shift 1
  Nurse 1 works shift 2
  Nurse 2 does not work
  Nurse 3 works shift 0
Day 2
  Nurse 0 works shift 2
  Nurse 1 works shift 1
  Nurse 2 works shift 0
  Nurse 3 does not work
Solution 4
Day 0
  Nurse 0 works shift 0
  Nurse 