In [1]:
import itertools
import os
import random
import sys

In [2]:
# Make the local RobustConAcq checkout importable (the next cells import `pycona`,
# presumably from this checkout). On another machine, point the ROBUSTCONACQ_PATH
# environment variable at the checkout instead of editing this cell.
ROBUSTCONACQ_PATH = os.environ.get(
    "ROBUSTCONACQ_PATH",
    r"C:\Users\gertj\OneDrive\Bureaublad\School\Master\Thesis\RobustConAcq\RobustConAcq",
)
# Only append once, so re-running the cell does not keep growing sys.path.
if ROBUSTCONACQ_PATH not in sys.path:
    sys.path.append(ROBUSTCONACQ_PATH)

In [3]:
from pycona import *
import cpmpy as cp

In [4]:
from pycona.active_algorithms import RobustAcq
from pycona.answering_queries import MisclassifyingOracle

In [5]:
def construct_nurse_rostering(shifts_per_day, num_days, num_nurses, nurses_per_shift):
    """Create the nurse-rostering ProblemInstance used for constraint acquisition.

    Variables: an integer array named "shifts" of shape
    (num_days, shifts_per_day, nurses_per_shift) with domain 1..num_nurses; entry
    [d, s, k] is the nurse placed in slot k of shift s on day d.

    Language: the six binary comparisons ==, !=, <, >, >=, <= between two
    abstract variables, in that order.

    Returns:
        ProblemInstance built from those variables and that language.
    """
    roster_shape = (num_days, shifts_per_day, nurses_per_shift)
    shifts = cp.intvar(1, num_nurses, shape=roster_shape, name="shifts")

    # Two abstract variables are enough: every relation in the language is binary.
    abstract_vars = absvar(2)
    lhs, rhs = abstract_vars[0], abstract_vars[1]
    language = [
        lhs == rhs,
        lhs != rhs,
        lhs < rhs,
        lhs > rhs,
        lhs >= rhs,
        lhs <= rhs,
    ]

    return ProblemInstance(variables=shifts, language=language)

In [6]:
def construct_nurse_rostering_with_oracle(shifts_per_day, num_days, num_nurses, nurses_per_shift,
                                          misclassification_rate=0):
    """Build the nurse-rostering ProblemInstance together with a simulated oracle.

    The oracle answers according to the ground-truth constraint set C_T, a flat
    list of binary `!=` constraints expressing:

    * within each day, every (shift, slot) position holds a different nurse;
    * no nurse on the last shift of a day is also on the first shift of the
      following day.

    Args:
        shifts_per_day: number of shifts per day.
        num_days: number of days in the roster.
        num_nurses: number of available nurses (variable domain 1..num_nurses).
        nurses_per_shift: number of nurse slots per shift.
        misclassification_rate: forwarded to MisclassifyingOracle. The default 0
            (the previously hard-coded value) gives a noise-free oracle.

    Returns:
        (inst, oracle): the ProblemInstance from construct_nurse_rostering and a
        MisclassifyingOracle over C_T.
    """
    inst = construct_nurse_rostering(shifts_per_day, num_days, num_nurses, nurses_per_shift)
    roster_matrix = inst.variables

    # Build the pairwise decomposition explicitly instead of calling
    # AllDifferent(...).decompose(). Depending on the cpmpy version, decompose()
    # returns a list or a (constraints, defining_constraints) tuple, and
    # `list += tuple` would nest lists inside C_T.
    C_T = []

    # Each day: all positions across all shifts of that day are pairwise different
    # (same pairs, in the same order, as cpmpy's AllDifferent decomposition).
    for day in range(num_days):
        day_vars = list(roster_matrix[day, ...].flatten())
        C_T += [a != b for a, b in itertools.combinations(day_vars, 2)]

    # Consecutive days: last shift of `day` vs. first shift of `day + 1`. Only the
    # cross pairs are added. The within-shift pairs an AllDifferent over both shifts
    # would also produce are already in C_T from the loop above (duplicates).
    for day in range(num_days - 1):
        last_shift = roster_matrix[day, shifts_per_day - 1]
        next_first_shift = roster_matrix[day + 1, 0]
        C_T += [a != b for a in last_shift for b in next_first_shift]

    oracle = MisclassifyingOracle(C_T, misclassification_rate=misclassification_rate)

    return inst, oracle

In [7]:
# Small instance: 2 days x 3 shifts x 2 nurse slots per shift, 8 nurses available.
instance, oracle = construct_nurse_rostering_with_oracle(
    shifts_per_day=3,
    num_days=2,
    num_nurses=8,
    nurses_per_shift=2,
)

In [8]:
# Seed the stdlib RNG so the noisy membership queries are reproducible across runs.
# NOTE(review): this assumes RobustAcq / MisclassifyingOracle draw from the stdlib
# `random` module, as the "random: ..." lines in this cell's output suggest. If they
# use numpy's RNG instead, seed that as well. TODO confirm.
SEED = 42
random.seed(SEED)

ca = RobustAcq()
learned_instance = ca.learn(instance, oracle=oracle, verbose=1)

396
0
396
q1
noisy MQ
random: 0.9180495277294194
.q1 scope
findscope2
....before update Br type is: <class 'set'>
after update Br type is: <class 'set'>
.before update Br type is: <class 'set'>
after update Br type is: <class 'set'>
q1 cl
L393
2
396
q1
noisy MQ
random: 0.6965826445267628
.q1 scope
findscope2
...before update Br type is: <class 'set'>
after update Br type is: <class 'set'>
..before update Br type is: <class 'set'>
after update Br type is: <class 'set'>
.before update Br type is: <class 'set'>
after update Br type is: <class 'set'>
q1 cl
L388
6
396
q1
noisy MQ
random: 0.7176520615399868
.q1 scope
findscope2
...before update Br type is: <class 'set'>
after update Br type is: <class 'set'>
.before update Br type is: <class 'set'>
after update Br type is: <class 'set'>
.before update Br type is: <class 'set'>
after update Br type is: <class 'set'>
.before update Br type is: <class 'set'>
after update Br type is: <class 'set'>
q1 cl
L382
11
396
q1
noisy MQ
random: 0.51190455

AttributeError: 'list' object has no attribute 'update'

In [None]:
# Display the learned constraint collection (`cl`) of the instance returned by RobustAcq.learn.
learned_instance.cl

In [None]:
# Number of learned constraints in `cl`.
len(learned_instance.cl)