In [1]:
import numpy as np
np.random.seed(0)
from scipy.stats import beta as sp_beta

import matplotlib.pyplot as plt
plt.style.use('seaborn-darkgrid')
import matplotlib
matplotlib.rcParams.update({'font.size': 12})

import sys
sys.path.insert(0, '../../..')
import assignment
import leximin_assignment
import matrix_gen
import kde_utils

from tqdm import tqdm

import warnings

In [22]:
N = 10
N_INTVS = 5
N_EXPERIMENTS = 500  # 10 and 500

CAPACITIES = (np.ones((N_INTVS,)) * (N // N_INTVS)).astype(int)

In [12]:
params = (0, 1)
uniform_distribution = np.random.uniform

matrix_generator = matrix_gen.MatrixGenerator(
    uniform_distribution,
    params,
    N, N_INTVS
)

for _ in tqdm(range(N_EXPERIMENTS)):
    cost_matrix = matrix_generator.get_new_matrix(sort_rows=True)
    
    # Leximin assignment
    lex_assigner = leximin_assignment.LeximinAssignmentHelperV3(
        cost_matrix, CAPACITIES
    )
    with warnings.catch_warnings():  # temporarily suspense warnings
        warnings.simplefilter('ignore')
        lex_assignments = lex_assigner.solve()
    
    lowest_agents = np.argsort(cost_matrix[:, -1])[: CAPACITIES[-1]]
    if np.any(lex_assignments[lowest_agents] != N_INTVS - 1):
        print('here')

100%|██████████| 500/500 [06:13<00:00,  1.34it/s]


In [19]:
params = (0, 1)
uniform_distribution = np.random.uniform

matrix_generator = matrix_gen.MatrixGenerator(
    uniform_distribution,
    params,
    N, N_INTVS
)

for _ in tqdm(range(N_EXPERIMENTS)):
    cost_matrix = matrix_generator.get_new_matrix(sort_rows=True)
    
    # Leximin assignment
    lex_assigner = leximin_assignment.LeximinAssignmentHelperV3(
        cost_matrix, CAPACITIES
    )
    with warnings.catch_warnings():  # temporarily suspense warnings
        warnings.simplefilter('ignore')
        lex_assignments = lex_assigner.solve()
    
    running_agent_indices = [i for i in range(N)]
    
    lowest_agents = np.argsort(cost_matrix[:, -1])[: CAPACITIES[-1]]
    if np.any(lex_assignments[lowest_agents] != N_INTVS - 1):
        print('here')
        
    for j in range(N_INTVS - 1, -1, -1):
        lowest_cost_agents = np.argsort(cost_matrix[running_agent_indices, j])[: CAPACITIES[j]]
        if np.any(lex_assignments[lowest_cost_agents] != j):
            print(cost_matrix)
            print(lex_assignments)
            break
        
        running_agent_indices = np.delete(running_agent_indices, lowest_cost_agents)

  0%|          | 1/500 [00:00<06:24,  1.30it/s]

[[0.27442 0.27463 0.35402 0.42406 0.48464]
 [0.06097 0.24357 0.41644 0.51649 0.59298]
 [0.29731 0.50768 0.54179 0.6958  0.8338 ]
 [0.15591 0.35308 0.39609 0.67896 0.90309]
 [0.10774 0.13896 0.74052 0.82974 0.89903]
 [0.03496 0.26501 0.47816 0.50891 0.57938]
 [0.26664 0.34993 0.37069 0.65123 0.74693]
 [0.6975  0.79111 0.82104 0.82383 0.86565]
 [0.14776 0.3024  0.43181 0.60276 0.69544]
 [0.26491 0.34913 0.47844 0.80019 0.89788]
 [0.06314 0.22783 0.23258 0.81215 0.93338]
 [0.23497 0.65784 0.67748 0.8905  0.91031]
 [0.06348 0.09528 0.2545  0.34148 0.5271 ]
 [0.04665 0.06421 0.23681 0.54979 0.75312]
 [0.4707  0.48667 0.48817 0.48895 0.67338]
 [0.10751 0.16744 0.29397 0.56775 0.57325]
 [0.48173 0.60932 0.80684 0.88329 0.88485]
 [0.08827 0.13083 0.16572 0.30555 0.82309]
 [0.00136 0.19157 0.5021  0.53852 0.8571 ]
 [0.24109 0.81816 0.91908 0.92124 0.98494]
 [0.12549 0.59831 0.76958 0.87195 0.93324]
 [0.07249 0.07952 0.48612 0.58243 0.69666]
 [0.15811 0.2724  0.37572 0.39855 0.49687]
 [0.04256 0

  0%|          | 2/500 [00:01<06:25,  1.29it/s]

[[0.33308 0.54508 0.61849 0.6388  0.82875]
 [0.16342 0.55475 0.76653 0.80252 0.96463]
 [0.02909 0.4955  0.56314 0.74639 0.80594]
 [0.41231 0.70852 0.80084 0.87732 0.96349]
 [0.41512 0.49555 0.50871 0.74029 0.99792]
 [0.1147  0.23305 0.29254 0.40949 0.66232]
 [0.18053 0.23934 0.32536 0.7095  0.85559]
 [0.18106 0.25205 0.40453 0.508   0.75479]
 [0.0103  0.11035 0.39151 0.41134 0.83379]
 [0.19992 0.71547 0.78845 0.79683 0.89101]
 [0.03432 0.04883 0.39272 0.58907 0.95087]
 [0.13772 0.29759 0.33902 0.6718  0.84514]
 [0.03696 0.32496 0.37964 0.48087 0.53149]
 [0.17392 0.35912 0.50274 0.87008 0.88457]
 [0.1898  0.32048 0.7319  0.82556 0.9827 ]
 [0.13207 0.16772 0.38781 0.4947  0.79915]
 [0.0721  0.17239 0.47093 0.57422 0.70995]
 [0.01217 0.43788 0.4855  0.66504 0.8156 ]
 [0.02497 0.17396 0.20109 0.39775 0.71978]
 [0.18633 0.26095 0.30779 0.664   0.71519]
 [0.10088 0.28917 0.43723 0.9626  0.98903]
 [0.19513 0.72395 0.82568 0.83434 0.98309]
 [0.12011 0.2278  0.44243 0.88389 0.98703]
 [0.15091 0

  1%|          | 3/500 [00:02<06:32,  1.27it/s]

[[0.24973 0.38658 0.52054 0.64248 0.8638 ]
 [0.40946 0.46135 0.46815 0.70279 0.73187]
 [0.12462 0.14736 0.46435 0.53623 0.57927]
 [0.04591 0.17586 0.49875 0.52951 0.82797]
 [0.3061  0.4699  0.6321  0.77517 0.99941]
 [0.10312 0.64979 0.69278 0.7788  0.90481]
 [0.05328 0.33062 0.54068 0.67721 0.99675]
 [0.01689 0.26523 0.37802 0.4813  0.54483]
 [0.22035 0.39724 0.42521 0.46668 0.9039 ]
 [0.18669 0.32074 0.43022 0.63138 0.74422]
 [0.04214 0.05614 0.35449 0.88778 0.98959]
 [0.2011  0.55996 0.58693 0.68604 0.93002]
 [0.01228 0.08881 0.20005 0.62537 0.83978]
 [0.28072 0.40548 0.7629  0.90453 0.97969]
 [0.14386 0.24346 0.34832 0.71802 0.81305]
 [0.36754 0.52306 0.56092 0.62512 0.92557]
 [0.66513 0.66797 0.77042 0.83887 0.88483]
 [0.05973 0.60129 0.60704 0.87739 0.93522]
 [0.2875  0.29422 0.33786 0.7774  0.91816]
 [0.09519 0.38623 0.68934 0.70079 0.72585]
 [0.02987 0.46544 0.66563 0.75306 0.86633]
 [0.1085  0.15817 0.25193 0.46686 0.90029]
 [0.33087 0.47653 0.50563 0.61444 0.87612]
 [0.23702 0

  1%|          | 4/500 [00:03<06:33,  1.26it/s]

[[0.01409 0.3995  0.68713 0.75005 0.79566]
 [0.03235 0.47998 0.63284 0.68731 0.87729]
 [0.17072 0.27612 0.53504 0.87911 0.93202]
 [0.54526 0.67909 0.83415 0.95055 0.99648]
 [0.2152  0.29766 0.31555 0.3783  0.71572]
 [0.36035 0.37565 0.5317  0.68226 0.95745]
 [0.08859 0.14826 0.22061 0.38317 0.42588]
 [0.06089 0.11362 0.34018 0.34416 0.51005]
 [0.3276  0.68913 0.70458 0.87135 0.96078]
 [0.08679 0.65542 0.78066 0.85489 0.88615]
 [0.30155 0.58731 0.59795 0.87513 0.9287 ]
 [0.01015 0.18159 0.319   0.45443 0.52862]
 [0.06642 0.2833  0.62132 0.95989 0.99961]
 [0.17201 0.22978 0.32958 0.57246 0.80363]
 [0.47792 0.65587 0.67896 0.92065 0.95532]
 [0.10564 0.31646 0.40763 0.48756 0.92792]
 [0.13297 0.27252 0.48924 0.7484  0.84075]
 [0.06737 0.19947 0.44615 0.53508 0.97361]
 [0.21501 0.23361 0.24646 0.48882 0.6604 ]
 [0.12429 0.52699 0.74347 0.83298 0.90897]
 [0.21538 0.27789 0.30418 0.54058 0.75918]
 [0.32462 0.37425 0.50747 0.64237 0.7322 ]
 [0.0801  0.31937 0.62854 0.84684 0.88535]
 [0.36621 0

  1%|          | 5/500 [00:04<06:41,  1.23it/s]

[[0.01626 0.50712 0.63791 0.7186  0.72206]
 [0.39543 0.61154 0.64565 0.84036 0.99588]
 [0.14008 0.17194 0.27571 0.67644 0.92905]
 [0.0082  0.18746 0.3572  0.69363 0.8301 ]
 [0.05171 0.13002 0.5656  0.63964 0.85656]
 [0.43137 0.78941 0.8619  0.91274 0.95807]
 [0.04436 0.0838  0.2279  0.63981 0.86079]
 [0.11192 0.23046 0.29282 0.43522 0.51355]
 [0.0423  0.17529 0.57336 0.77576 0.9321 ]
 [0.18079 0.21876 0.21954 0.42219 0.90889]
 [0.03482 0.09306 0.64181 0.87143 0.98391]
 [0.02665 0.32269 0.70507 0.75989 0.82782]
 [0.16109 0.54289 0.56637 0.6935  0.95065]
 [0.03506 0.188   0.26515 0.54138 0.65324]
 [0.17285 0.47571 0.59605 0.63893 0.66343]
 [0.01012 0.30333 0.43583 0.52164 0.93232]
 [0.37897 0.46641 0.4984  0.6625  0.75846]
 [0.20913 0.71169 0.78565 0.78703 0.86168]
 [0.18234 0.22976 0.25486 0.55701 0.79702]
 [0.22583 0.29174 0.57061 0.86184 0.99275]
 [0.08824 0.41423 0.60236 0.64784 0.77442]
 [0.01172 0.34247 0.4691  0.70387 0.8439 ]
 [0.10425 0.27116 0.33866 0.574   0.61304]
 [0.07005 0

  1%|          | 6/500 [00:04<06:42,  1.23it/s]

[[0.01705 0.53349 0.54049 0.62932 0.99643]
 [0.20805 0.34077 0.42702 0.69575 0.98095]
 [0.20076 0.21462 0.38053 0.83339 0.90501]
 [0.25612 0.4307  0.55081 0.70361 0.9735 ]
 [0.16665 0.20668 0.37982 0.51543 0.81744]
 [0.13285 0.19681 0.64619 0.67618 0.95548]
 [0.0954  0.15893 0.1778  0.82907 0.85157]
 [0.0357  0.05716 0.57567 0.65937 0.73366]
 [0.07702 0.4562  0.76944 0.82103 0.916  ]
 [0.5518  0.64742 0.66451 0.7338  0.91211]
 [0.17453 0.32893 0.52037 0.5563  0.56768]
 [0.02216 0.19925 0.2653  0.3411  0.41287]
 [0.17815 0.29904 0.8403  0.97766 0.98319]
 [0.07308 0.09722 0.09784 0.22584 0.67731]
 [0.35443 0.39604 0.69795 0.88029 0.92622]
 [0.33545 0.34728 0.74608 0.75475 0.95533]
 [0.2203  0.365   0.4159  0.59608 0.99952]
 [0.03393 0.10254 0.5943  0.74595 0.78017]
 [0.12923 0.44626 0.46172 0.59161 0.79663]
 [0.00619 0.28084 0.34644 0.74114 0.77569]
 [0.0682  0.29094 0.43181 0.51438 0.67378]
 [0.01926 0.35878 0.57553 0.63515 0.99125]
 [0.02842 0.05103 0.332   0.72988 0.89906]
 [0.19701 0

  1%|▏         | 7/500 [00:05<06:51,  1.20it/s]

[[0.22028 0.44015 0.63521 0.7098  0.91727]
 [0.20996 0.46823 0.56665 0.81507 0.88622]
 [0.21766 0.51974 0.5769  0.61567 0.8076 ]
 [0.16612 0.19162 0.53477 0.74359 0.75071]
 [0.03743 0.14821 0.44936 0.53964 0.6876 ]
 [0.04719 0.15777 0.41774 0.48587 0.7295 ]
 [0.01032 0.02572 0.22271 0.75778 0.98319]
 [0.11733 0.2803  0.39785 0.82376 0.99472]
 [0.06522 0.14655 0.36101 0.36465 0.81113]
 [0.02555 0.13199 0.2647  0.49174 0.86109]
 [0.02689 0.22083 0.7007  0.93702 0.9769 ]
 [0.20268 0.60892 0.61426 0.65701 0.83515]
 [0.18955 0.67546 0.91311 0.9326  0.95212]
 [0.13019 0.24377 0.5117  0.56257 0.63341]
 [0.05235 0.22969 0.35475 0.45207 0.56745]
 [0.19275 0.51262 0.7886  0.84318 0.91006]
 [0.15801 0.29308 0.32735 0.39505 0.6672 ]
 [0.00285 0.21295 0.21666 0.47366 0.71797]
 [0.02662 0.05602 0.16911 0.31703 0.78203]
 [0.03058 0.28908 0.41268 0.6594  0.82647]
 [0.05877 0.47152 0.58021 0.77966 0.95509]
 [0.07388 0.09248 0.32558 0.43515 0.63521]
 [0.0371  0.51652 0.75904 0.83329 0.87622]
 [0.21718 0




TypeError: only integer scalar arrays can be converted to a scalar index

In [37]:
params = (0, 1)
uniform_distribution = np.random.uniform

matrix_generator = matrix_gen.MatrixGenerator(
    uniform_distribution,
    params,
    N, N_INTVS
)

cost_matrix = matrix_generator.get_new_matrix(sort_rows=True)

# Leximin assignment
lex_assigner = leximin_assignment.LeximinAssignmentHelperV3(
    cost_matrix, CAPACITIES
)
with warnings.catch_warnings():  # temporarily suspense warnings
    warnings.simplefilter('ignore')
    lex_assignments = lex_assigner.solve()

running_assigned_agents = []
agents = [i for i in range(N)]

for j in range(N_INTVS - 1, -1, -1):
    lowest_cost_agents = np.argsort(
        cost_matrix[np.isin(agents, running_assigned_agents, invert=True), j]
    )[: CAPACITIES[j]]
    if np.any(lex_assignments[lowest_cost_agents] != j):
        print(j)
        print(cost_matrix)
        print(lex_assignments)
        break

    running_assigned_agents += list(lowest_cost_agents)

3
[[0.19513 0.40333 0.66096 0.71211 0.86341]
 [0.04091 0.32399 0.38177 0.73939 0.98142]
 [0.05431 0.0641  0.19637 0.30165 0.92575]
 [0.12493 0.55796 0.70613 0.95257 0.96849]
 [0.09637 0.22632 0.54106 0.55554 0.91837]
 [0.2398  0.61091 0.75205 0.90945 0.93692]
 [0.22523 0.53709 0.62229 0.68687 0.76582]
 [0.27167 0.33972 0.40519 0.91354 0.93223]
 [0.02699 0.0681  0.62126 0.64044 0.95154]
 [0.28991 0.54203 0.62675 0.794   0.91647]]
[4 2 3 0 3 0 4 2 1 1]


In [39]:
lowest_cost_agents

array([1, 3])