In [1]:
import numpy as np
import argparse
import time
import os
import collections
import json

from utils.data_utils import load_dataset_numpy

import scipy.spatial.distance

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_in", default='MNIST',
                    help="dataset to be used")
parser.add_argument("--norm", default='l2',
                    help="norm to be used")
parser.add_argument('--num_samples', type=int, default=None)
parser.add_argument('--n_classes', type=int, default=2)
parser.add_argument('--eps', type=float, default=None)
parser.add_argument('--approx_only', dest='approx_only', action='store_true')
parser.add_argument('--use_test', dest='use_test', action='store_true')
parser.add_argument('--track_hard', dest='track_hard', action='store_true')
parser.add_argument('--new_marking_strat', type=str, default=None)

_StoreAction(option_strings=['--new_marking_strat'], dest='new_marking_strat', nargs=None, const=None, default=None, type=<class 'str'>, choices=None, help=None, metavar=None)

In [3]:
args = parser.parse_args("--dataset_in=MNIST --num_samples=2000".split())

In [4]:
train_data, test_data, data_details = load_dataset_numpy(args, data_dir='data',
														training_time=False)
DATA_DIM = data_details['n_channels']*data_details['h_in']*data_details['w_in']

X = []
Y = []

# Pytorch normalizes tensors (so need manual here!)
if args.use_test:
	for (x,y,_, _, _) in test_data:
		X.append(x/255.)
		Y.append(y)
else:
	for (x,y,_, _, _) in train_data:
		X.append(x/255.)
		Y.append(y)

X = np.array(X)
Y = np.array(Y)

num_samples = int(len(X)/2)
print(num_samples)

if not os.path.exists('distances'):
	os.makedirs('distances')

if 'MNIST' in args.dataset_in or 'CIFAR-10' in args.dataset_in:
	class_1 = 3
	class_2 = 7
	if args.use_test:
		dist_mat_name = args.dataset_in + '_test_' + str(class_1) + '_' + str(class_2) + '_' + str(num_samples) + '_' + args.norm + '.npy'
	else:
		dist_mat_name = args.dataset_in + '_' + str(class_1) + '_' + str(class_2) + '_' + str(num_samples) + '_' + args.norm + '.npy'
	X_c1 = X[:num_samples].reshape(num_samples, DATA_DIM)
	X_c2 = X[num_samples:].reshape(num_samples, DATA_DIM)
	if os.path.exists(dist_mat_name):
		D_12 = np.load('distances/' + dist_mat_name)
	else:
		if args.norm == 'l2':
			D_12 = scipy.spatial.distance.cdist(X_c1,X_c2,metric='euclidean')
		elif args.norm == 'linf':
			D_12 = scipy.spatial.distance.cdist(X_c1,X_c2,metric='chebyshev')
		np.save('distances/' + dist_mat_name, D_12)


2000


In [27]:
if args.norm == 'l2' and 'MNIST' in args.dataset_in:
	# eps_list = np.linspace(3.2,3.8,4)
	eps_list = np.linspace(3.8,3.8,1)
	# eps_list=[2.6,2.8]
elif args.norm == 'l2' and 'CIFAR-10' in args.dataset_in:
	eps_list = np.linspace(4.0,10.0,13)
elif args.norm == 'linf' and 'MNIST' in args.dataset_in:
	eps_list = np.linspace(0.1,0.5,5)
elif args.norm == 'linf' and 'CIFAR-10' in args.dataset_in:
	eps_list = np.linspace(0.1,0.5,5)

if args.eps is not None:
	eps_list = [args.eps]

print(eps_list)

[3.8]


In [28]:
for eps in eps_list:
	print(eps)
    # Add edge if cost 0
	edge_matrix = D_12 <= 2*eps
	edge_matrix = edge_matrix.astype(float)

3.8


In [29]:
n_1 = args.num_samples
n_2 = args.num_samples

In [30]:
from scipy.sparse import csr_matrix
# from scipy.sparse.csgraph import maximum_flow
# import pyximport
# pyximport.install()
from utils.flow import maximum_flow
import numpy as np

n_1 = 9
n_2 = 9

edge_matrix=np.zeros((n_1,n_2))
edge_matrix[0,0]=1
edge_matrix[1,0]=1
edge_matrix[2,1]=1
edge_matrix[3,1]=1
edge_matrix[3,2]=1
edge_matrix[4,2]=1
edge_matrix[5,3]=1
edge_matrix[6,4]=1
edge_matrix[7,5]=1
edge_matrix[7,6]=1
edge_matrix[7,7]=1

In [31]:
graph_rep = []
for i in range(n_1+n_2+2):
    graph_rep.append([])
    if i==0:
        #source
        for j in range(n_1+n_2+2):
            if j==0:
                graph_rep[i].append(0)
            elif 1<=j<=n_1:
                graph_rep[i].append(n_2)
            elif n_1<j<=n_1+n_2+1:
                graph_rep[i].append(0)
    elif 1<=i<=n_1:
        # LHS vertices
        for j in range(n_1+n_2+2):
            if j<=n_1:
                graph_rep[i].append(0)
            elif n_1<j<=n_1+n_2:
                if edge_matrix[i-1,j-n_1-1]:
                    graph_rep[i].append(n_1*n_2)
                else:
                    graph_rep[i].append(0)
            elif n_1+n_2<j:
                graph_rep[i].append(0)
    elif n_1<i<=n_1+n_2:
        #RHS vertices
        for j in range(n_1+n_2+2):
            if j<=n_1+n_2:
                graph_rep[i].append(0)
            elif j>n_1+n_2:
                graph_rep[i].append(n_1)
    elif i==n_1+n_2+1:
        #Sink
        for j in range(n_1+n_2+2):
            graph_rep[i].append(0)

In [32]:
graph_rep_array=np.array(graph_rep)

In [33]:
graph_rep_array

array([[   0, 2000, 2000, ...,    0,    0,    0],
       [   0,    0,    0, ...,    0,    0,    0],
       [   0,    0,    0, ...,    0,    0,    0],
       ...,
       [   0,    0,    0, ...,    0,    0, 2000],
       [   0,    0,    0, ...,    0,    0, 2000],
       [   0,    0,    0, ...,    0,    0,    0]])

In [12]:
graph=csr_matrix(graph_rep_array)

In [13]:
A=maximum_flow(graph, 0, n_1+n_2+1)

In [14]:
A.flow_value

1402000

In [15]:
A.pred_edge

array([-1,  0, -1, ..., -1, -1, -1], dtype=int32)

In [16]:
A.path_edges

[(0, 1),
 (0, 3),
 (0, 4),
 (0, 7),
 (0, 10),
 (0, 12),
 (0, 13),
 (0, 15),
 (0, 17),
 (0, 18),
 (0, 19),
 (0, 22),
 (0, 24),
 (0, 26),
 (0, 27),
 (0, 29),
 (0, 30),
 (0, 32),
 (0, 34),
 (0, 35),
 (0, 36),
 (0, 37),
 (0, 38),
 (0, 40),
 (0, 41),
 (0, 42),
 (0, 43),
 (0, 44),
 (0, 45),
 (0, 47),
 (0, 48),
 (0, 49),
 (0, 50),
 (0, 54),
 (0, 56),
 (0, 59),
 (0, 62),
 (0, 63),
 (0, 65),
 (0, 66),
 (0, 67),
 (0, 68),
 (0, 69),
 (0, 70),
 (0, 72),
 (0, 73),
 (0, 75),
 (0, 76),
 (0, 79),
 (0, 80),
 (0, 81),
 (0, 82),
 (0, 83),
 (0, 85),
 (0, 86),
 (0, 87),
 (0, 89),
 (0, 90),
 (0, 92),
 (0, 93),
 (0, 99),
 (0, 100),
 (0, 101),
 (0, 102),
 (0, 103),
 (0, 105),
 (0, 106),
 (0, 107),
 (0, 108),
 (0, 109),
 (0, 111),
 (0, 114),
 (0, 117),
 (0, 120),
 (0, 121),
 (0, 122),
 (0, 123),
 (0, 124),
 (0, 129),
 (0, 132),
 (0, 133),
 (0, 134),
 (0, 135),
 (0, 137),
 (0, 138),
 (0, 144),
 (0, 145),
 (0, 147),
 (0, 148),
 (0, 149),
 (0, 150),
 (0, 152),
 (0, 153),
 (0, 154),
 (0, 155),
 (0, 158),
 (0, 159)

In [17]:
A.residual.toarray()

array([[    0,     0,  2000, ...,     0,     0,     0],
       [    0,     0,     0, ...,     0,     0,     0],
       [-2000,     0,     0, ...,     0,     0,     0],
       ...,
       [    0,     0,     0, ...,     0,     0,     0],
       [    0,     0,     0, ...,     0,     0,     0],
       [    0,     0,     0, ...,     0,     0,     0]], dtype=int32)

In [18]:
remainder_array=graph_rep_array-A.residual.toarray()
remainder_array

array([[   0, 2000,    0, ...,    0,    0,    0],
       [   0,    0,    0, ...,    0,    0,    0],
       [2000,    0,    0, ...,    0,    0,    0],
       ...,
       [   0,    0,    0, ...,    0,    0, 2000],
       [   0,    0,    0, ...,    0,    0, 2000],
       [   0,    0,    0, ...,    0,    0,    0]])

import scipy
nz_flow_tuple=scipy.sparse.find(A.residual)

gz_idx=np.where(nz_flow_tuple[2]>0)

gz_row_idx=nz_flow_tuple[0][gz_idx]
gz_col_idx=nz_flow_tuple[1][gz_idx]

len(gz_idx[0])

gz_row_idx

gz_col_idx

slice_row=np.unique(gz_row_idx)

slice_col=np.unique(gz_col_idx)

overall_slice=np.union1d(slice_row,slice_col)

graph_rep_array=np.array(graph_rep)

graph_lvl_one=graph_rep_array[overall_slice]

graph_lvl_one[:,overall_slice]

In [None]:
# Sort into lhs and rhs, rescale and run again; add termination condition for recursion down the graph; retain original numbering of array elements

In [34]:
def set_classifier_prob_full_flow(top_level_vertices,n_1_curr,n_2_curr):
    for item in top_level_vertices:
        if item !=0 and item != sink_idx:
            classifier_probs[item-1,0]=n_1_curr/(n_1_curr+n_2_curr)
            classifier_probs[item-1,1]=n_2_curr/(n_1_curr+n_2_curr)

In [35]:
def set_classifier_prob_no_flow(top_level_vertices):
    for item in top_level_vertices:
        if item !=0 and item != sink_idx:
            if item<=n_1:
                classifier_probs[item-1,0]=1
                classifier_probs[item-1,1]=0
            elif item>n_1:
                classifier_probs[item-1,0]=0
                classifier_probs[item-1,1]=1

In [36]:
def graph_rescale(graph_rep_curr,top_level_indices):
    n_1_curr=len(np.where(top_level_indices<=n_1)[0])-1
    n_2_curr=len(np.where(top_level_indices>n_1)[0])-1
    # source rescale
    print(graph_rep_curr[0])
    graph_rep_curr[0,:]=graph_rep_curr[0,:]/n_2
    graph_rep_curr[0,:]*=n_2_curr
    print(graph_rep_curr[0])
    # bipartite graph edge scale
    graph_rep_curr[1:n_1_curr+1,:]=graph_rep_curr[1:n_1_curr+1,:]/(n_1*n_2)
    graph_rep_curr[1:n_1_curr+1,:]*=(n_1_curr*n_2_curr)
    # sink edges rescale
    graph_rep_curr[n_1_curr+1:,:]=graph_rep_curr[n_1_curr+1:,:]/n_1
    graph_rep_curr[n_1_curr+1:,:]*=n_1_curr
    return graph_rep_curr,n_1_curr,n_2_curr

In [37]:
def find_flow_and_split(top_level_indices):
    top_level_indices_1=None
    top_level_indices_2=None
    #Create subgraph from index array provided
    graph_rep_curr = graph_rep_array[top_level_indices]
    graph_rep_curr = graph_rep_curr[:,top_level_indices]
    graph_rep_curr,n_1_curr,n_2_curr = graph_rescale(graph_rep_curr,top_level_indices)
    graph_curr=csr_matrix(graph_rep_curr)
    flow_curr = maximum_flow(graph_curr,0,len(top_level_indices)-1)
    # Checking if full flow occurred, so no need to split
    if flow_curr.flow_value==n_1_curr*n_2_curr:
        set_classifier_prob_full_flow(top_level_indices,n_1_curr,n_2_curr)
        return top_level_indices_1,top_level_indices_2, flow_curr
    elif flow_curr.flow_value==0:
        set_classifier_prob_no_flow(top_level_indices)
        return top_level_indices_1,top_level_indices_2, flow_curr
    # Finding remaining capacity edges
    edge_list_curr=flow_curr.path_edges
#     print(edge_list_curr)
    gz_idx = []
    for item in edge_list_curr:
        gz_idx.append(item[0])
        gz_idx.append(item[1])
    if len(gz_idx)>0:
        gz_idx=np.array(gz_idx)
        gz_idx_unique=np.unique(gz_idx)
        top_level_gz_idx=top_level_indices[gz_idx_unique]
        top_level_gz_idx=np.insert(top_level_gz_idx,len(top_level_gz_idx),sink_idx)
        top_level_indices_1=top_level_gz_idx
    else:
        top_level_gz_idx=np.array([0,sink_idx])
    # Indices without flow
    top_level_z_idx=np.setdiff1d(top_level_indices,top_level_gz_idx)
    if len(top_level_z_idx)>0:
        # Add source and sink back to zero flow idx array
        top_level_z_idx=np.insert(top_level_z_idx,0,0)
        top_level_z_idx=np.insert(top_level_z_idx,len(top_level_z_idx),sink_idx)
        top_level_indices_2=top_level_z_idx
    
    return top_level_indices_1,top_level_indices_2, flow_curr

In [38]:
import queue
q = queue.Queue()
# Initial graph indices
q.put(np.arange(n_1+n_2+2))
sink_idx=n_1+n_2+1
count=0
classifier_probs=np.zeros((n_1+n_2,2))
while not q.empty():
    print(q.qsize())
    curr_idx_list=q.get()
    print(q.qsize())
    list_1, list_2, flow_curr=find_flow_and_split(curr_idx_list)
    print(list_1,list_2,flow_curr.flow_value)
#     if list_1 is None and list_2 is None:
#         trial_var=set_global_probs(flow_curr,curr_idx_list)
    if list_1 is not None:
#         flow_possible_1=check_flow_possible(list_1)
        q.put(list_1)
    if list_2 is not None:
#         flow_possible_2=check_flow_possible(list_2)
        q.put(list_2)

1
0
[   0 2000 2000 ...    0    0    0]
[   0 2000 2000 ...    0    0    0]
[   0    1    2 ... 3980 3985 4001] [   0    5    6 ... 3999 4000 4001] 956000
2
1
[   0 2000 2000 ...    0    0    0]
[  0 112 112 ...   0   0   0]
[   0    1    2 ... 1999 2000 4001] [   0    9   11   16   33   55   60   98  116  125  126  131  140  157
  160  188  197  201  203  206  207  212  214  224  227  238  244  269
  273  274  279  303  320  323  337  343  348  364  366  375  380  400
  435  476  512  515  519  521  529  540  559  590  613  641  646  650
  658  665  669  670  684  690  693  698  717  719  728  732  747  749
  759  764  766  769  772  779  785  807  814  816  847  849  867  874
  897  899  900  902  910  915  924  939  944  949  972  976  977  984
  990  998 1001 1005 1017 1030 1046 1048 1058 1059 1060 1074 1077 1096
 1116 1126 1127 1135 1136 1139 1143 1165 1167 1182 1187 1188 1189 1194
 1199 1206 1208 1216 1227 1228 1234 1235 1237 1242 1248 1258 1260 1267
 1273 1289 1322 1347 1378 138

[   0   51  252  265  307  336  376  434  471  480  509  577  704  752
  812  828  883  892  894  907  935  942  951  954  986 1044 1075 1099
 1229 1299 1321 1423 1482 1663 1666 1903 1923 1960 2012 2025 2029 2030
 2031 2033 2043 2050 2054 2057 2060 2062 2065 2072 2078 2080 2083 2089
 2091 2092 2094 2098 2103 2129 2132 2141 2145 2177 2179 2180 2188 2189
 2195 2198 2199 2204 2208 2214 2218 2232 2234 2236 2238 2242 2245 2248
 2249 2256 2265 2268 2269 2274 2285 2286 2287 2298 2301 2321 2323 2327
 2330 2334 2336 2337 2342 2345 2347 2349 2351 2352 2355 2358 2366 2371
 2376 2381 2382 2386 2406 2415 2422 2423 2424 2426 2434 2444 2446 2448
 2451 2453 2455 2464 2466 2474 2475 2489 2490 2495 2498 2500 2502 2504
 2513 2519 2523 2527 2532 2540 2543 2546 2550 2562 2574 2575 2579 2584
 2585 2587 2595 2597 2599 2600 2665 2674 2691 2696 2698 2701 2705 2710
 2711 2712 2716 2741 2744 2745 2757 2761 2763 2768 2772 2782 2786 2798
 2803 2812 2815 2817 2818 2837 2854 2884 2890 2904 2906 2907 2911 2918
 2925 

[   0  116  125  157  160  206  212  227  269  273  274  303  323  343
  348  364  512  529  540  559  646  650  658  690  698  717  719  728
  732  759  764  769  847  897  899  910  944  976  977  984  990 1001
 1017 1046 1116 1187 1188 1199 1216 1248 1258 1260 1267 1273 1347 1378
 1380 1395 1401 1413 1422 1431 1465 1469 1515 1527 1547 1553 1607 1624
 1652 1660 1665 1673 1696 1707 1765 1782 1808 1812 1865 1871 1902 1914
 1916 1918 1932 1968 1970 2022 2038 2058 2168 2244 2333 2339 2370 2411
 2414 2545 2553 2609 2611 2618 2649 2652 2683 2724 2734 2737 2752 2793
 2795 2810 2821 2873 2901 2924 2951 2964 2983 3049 3055 3142 3157 3193
 3248 3282 3315 3334 3344 3362 3364 3399 3453 3463 3487 3546 3613 3665
 3725 3740 3757 3824 3869 3985 4001] [   0    9  131  197  203  366  476  613  684  779  814  816  874  900
  949  972 1139 1234 1237 1289 1456 1485 1526 1543 1574 1921 1961 2035
 2479 2521 2581 2790 2866 2880 3004 3155 3162 3321 3421 3446 3511 3575
 3632 3695 3706 3761 3790 3832 3969 4001

[   0    6   46   58   74   78   84  110  112  128  141  156  179  208
  223  233  237  249  281  288  322  346  421  427  429  438  439  460
  563  585  605  606  610  649  654  664  683  685  689  700  702  709
  731  734  783  784  901  912  952  980 1053 1078 1134 1159 1202 1211
 1217 1222 1223 1281 1309 1324 1325 1360 1371 1385 1387 1406 1415 1416
 1421 1430 1432 1433 1443 1452 1484 1503 1552 1554 1584 1589 1606 1610
 1611 1621 1623 1645 1656 1659 1661 1683 1768 1771 1801 1820 1835 1860
 1882 1886 1888 1906 1917 1984 1985 2003 2037 2044 2052 2066 2069 2088
 2097 2102 2108 2109 2123 2138 2146 2148 2150 2152 2159 2160 2163 2169
 2176 2181 2182 2197 2201 2207 2212 2215 2221 2222 2230 2237 2246 2247
 2261 2277 2281 2294 2300 2308 2318 2319 2338 2346 2350 2357 2359 2363
 2369 2373 2377 2408 2419 2435 2452 2458 2461 2465 2473 2477 2487 2506
 2507 2514 2517 2530 2534 2547 2549 2557 2564 2578 2580 2591 2602 2606
 2607 2624 2627 2632 2633 2636 2637 2642 2643 2666 2672 2675 2687 2689
 2702 

[   0  202  886  889  903  906 1193 2001 2005 2019 2023 2026 2028 2034
 2036 2040 2053 2074 2076 2077 2082 2095 2099 2101 2111 2115 2127 2143
 2144 2156 2164 2166 2171 2173 2174 2178 2196 2206 2217 2220 2235 2243
 2253 2257 2258 2260 2266 2271 2272 2278 2280 2297 2299 2303 2313 2315
 2320 2340 2361 2364 2365 2379 2413 2450 2454 2460 2482 2486 2505 2510
 2537 2565 2567 2570 2573 2577 2586 2588 2590 2596 2604 2608 2610 2612
 2613 2614 2615 2619 2621 2622 2625 2626 2634 2635 2640 2641 2645 2646
 2656 2659 2663 2669 2670 2676 2677 2682 2684 2685 2692 2733 2746 2748
 2751 2755 2767 2774 2783 2788 2791 2813 2816 2823 2828 2829 2840 2843
 2845 2847 2848 2850 2857 2859 2863 2872 2886 2894 2895 2897 2898 2899
 2905 2929 2945 2946 2965 2972 2996 3003 3006 3008 3017 3026 3032 3045
 3056 3058 3074 3100 3107 3112 3114 3117 3120 3121 3133 3138 3147 3151
 3152 3153 3156 3158 3165 3170 3177 3195 3205 3208 3215 3219 3246 3254
 3262 3264 3267 3274 3303 3304 3318 3320 3325 3333 3336 3338 3340 3347
 3376 

[   0  142  205  220  235  292  431  467  497  538  573  657  682  686
  712  721  795  820  829  841  855  908  930  948  956 1047 1083 1175
 1181 1220 1252 1269 1270 1277 1341 1424 1458 1476 1479 1487 1497 1506
 1579 1585 1604 1609 1626 1638 1662 1743 1746 1757 1769 1800 1824 1899
 2016 2046 2064 2079 2122 2131 2140 2151 2167 2202 2209 2227 2267 2312
 2324 2410 2428 2436 2576 2655 2664 2722 2769 2780 2801 2915 3059 3099
 3163 3181 3188 3222 3266 3302 3317 3337 3398 3435 3492 3496 3520 3543
 3587 3593 3614 3621 3682 3696 3720 3799 3809 3873 3909 3974 3987 4001] [   0   64  127  151  241  267  326  377  440  477  502  569  694  893
  929  963 1106 1183 1197 1207 1311 1509 1557 1760 1813 1867 1941 1953
 1964 1982 2112 2161 2251 2295 2356 2397 2525 2603 2650 2678 2731 2822
 2867 2871 2883 2923 2932 2984 2986 2992 3082 3160 3211 3247 3249 3268
 3419 3449 3451 3458 3469 3498 3518 3548 3551 3554 3648 3657 3838 3906
 3968 4001] 7404
17
16
[   0 2000 2000 2000 2000 2000 2000 2000 2000 2000 20

[   0   58   84  112  128  141  156  233  249  281  606  654  664  683
  685  702  709  901  912  952  980 1211 1217 1223 1281 1309 1371 1421
 1432 1443 1589 1656 1659 1768 1835 1888 2037 2052 2109 2146 2148 2150
 2152 2163 2176 2182 2201 2207 2212 2215 2237 2277 2294 2318 2346 2359
 2452 2517 2534 2549 2578 2607 2624 2636 2637 2643 2687 2689 2736 2804
 2826 2833 2849 2855 2860 2862 2878 2882 2888 2917 2927 2942 2953 2977
 2988 2990 2991 3013 3086 3095 3096 3104 3105 3109 3135 3141 3171 3175
 3202 3240 3275 3283 3285 3290 3339 3351 3386 3390 3402 3414 3466 3468
 3502 3521 3592 3598 3606 3627 3630 3634 3641 3711 3716 3804 3807 3834
 3835 3900 3923 3938 3942 3943 3945 3946 3952 3983 4001] [   0    6   46   74   78  110  179  208  223  237  288  322  346  421
  427  429  438  439  460  563  585  605  610  649  689  700  731  734
  783  784 1053 1078 1134 1159 1202 1222 1324 1325 1360 1385 1387 1406
 1415 1416 1430 1433 1452 1484 1503 1552 1554 1584 1606 1610 1611 1621
 1623 1645 1661 1683

[   0   51  434  480  509  883  942  954 2062 2132 2234 2238 2242 2265
 2274 2287 2336 2347 2349 2444 2474 2475 2489 2513 2519 2532 2540 2562
 2579 2595 2696 2701 2711 2768 2837 2904 2907 2925 2933 2934 2963 3044
 3051 3053 3068 3070 3084 3139 3159 3169 3179 3194 3200 3207 3220 3259
 3289 3296 3298 3545 3563 3564 3584 3590 3599 3617 3623 3654 3712 3729
 3753 3791 3868 3883 3913 3920 3948 3977 3995 4001] [   0  265  336  752  812  892  907  951 1044 2030 2031 2043 2054 2057
 2060 2065 2072 2078 2091 2092 2094 2103 2129 2145 2177 2179 2180 2208
 2214 2218 2269 2285 2286 2298 2323 2327 2330 2342 2351 2355 2371 2381
 2382 2386 2415 2423 2424 2426 2446 2448 2451 2464 2495 2502 2543 2550
 2575 2585 2600 2674 2761 2812 2815 2817 2918 2935 2952 2954 2955 2966
 2994 3042 3060 3103 3144 3167 3172 3176 3180 3182 3308 3341 3342 3346
 3348 3415 3437 3532 3544 3616 3626 3637 3639 3646 3647 3690 3691 3692
 3699 3747 3760 3797 3810 3833 3841 3842 3858 3871 3890 3894 3902 3904
 3912 3917 3949 3962 3971

None None 2400
22
21
[   0 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000
 2000    0    0    0    0    0    0    0    0    0    0    0    0]
[ 0 11 11 11 11 11 11 11 11 11 11 11 11 11 11  0  0  0  0  0  0  0  0  0
  0  0  0]
[   0  197  613 1234 1921 2790 3321 3632 4001] [   0  203  476  684  779  949 1139 1237 1289 1574 1961 2521 2866 2880
 3004 3155 3162 3421 3790 4001] 152
23
22
[   0 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000    0
    0    0    0    0    0    0    0    0    0    0    0]
[ 0 11 11 11 11 11 11 11 11 11 11 11 11  0  0  0  0  0  0  0  0  0  0  0
  0]
None None 132
22
21
[   0 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000
 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000
 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000
 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0 

[   0    6   46   74   78  110  179  208  237  288  322  346  421  429
  438  439  585  605  610  649  689  700  731  734  783  784 1053 1078
 1134 1159 1202 1222 1324 1325 1360 1385 1387 1406 1415 1416 1430 1433
 1452 1484 1503 1552 1554 1584 1606 1610 1611 1621 1623 1645 1661 1683
 1771 1801 1820 1860 1882 1906 1917 1984 1985 2003 2044 2066 2069 2088
 2097 2102 2108 2123 2159 2160 2169 2181 2197 2221 2222 2230 2247 2261
 2281 2300 2308 2319 2338 2350 2357 2363 2369 2373 2377 2408 2419 2435
 2458 2461 2465 2473 2477 2487 2506 2507 2514 2530 2547 2557 2564 2580
 2591 2602 2606 2632 2633 2666 2672 2675 2702 2706 2719 2725 2728 2747
 2749 2771 2777 2779 2794 2796 2799 2802 2805 2806 2807 2808 2809 2814
 2824 2856 2868 2875 2877 2881 2885 2908 2941 2950 2968 2973 2978 2980
 2981 2989 3021 3024 3025 3027 3031 3039 3040 3050 3065 3111 3113 3128
 3134 3137 3150 3184 3192 3196 3197 3216 3217 3229 3230 3245 3252 3255
 3272 3273 3278 3280 3281 3286 3288 3295 3297 3299 3311 3352 3353 3354
 3360 

[   0  336  752  907  951 1044 2030 2057 2060 2065 2078 2092 2103 2145
 2177 2180 2208 2269 2286 2327 2342 2351 2371 2381 2382 2415 2424 2426
 2495 2543 2550 2575 2761 2812 2815 2817 2918 2935 2954 2955 2966 2994
 3042 3060 3103 3144 3167 3176 3180 3346 3532 3544 3616 3637 3639 3646
 3647 3691 3747 3760 3842 3858 3894 3902 3904 3971 3982 4001] [   0  265  812  892 2031 2043 2054 2072 2091 2094 2129 2179 2214 2218
 2285 2298 2323 2330 2355 2386 2423 2446 2448 2451 2464 2502 2585 2600
 2674 2952 3172 3182 3308 3341 3342 3348 3415 3437 3626 3690 3692 3699
 3797 3810 3833 3841 3871 3890 3912 3917 3949 3962 4001] 815
32
31
[   0 2000 2000 2000    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    

None None 3471
22
21
[   0 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000
 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0]
[ 0 67 67 67 67 67 67 67 67 67 67 67 67 67 67 67 67 67 67 67 67 67 67 67
 67  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
[   0  156  683 1211 1217 1223 1656 1835 2148 2150 2152 2163 2182 2215
 2534 2804 2862 3096 3105 3175 3240 3283 3285 3386 3390 3711 3804 4001] [   0   58   84  128  141  249  281  606  654  709  952 1281 1371 1421
 1432 15

[   0  336  907 1044 2057 2065 2078 2145 2177 2180 2208 2286 2327 2342
 2382 2495 2575 2812 2817 2918 2954 2955 2994 3042 3060 3103 3144 3346
 3544 3616 3639 3647 3691 3747 3760 3842 3894 3904 3971 3982 4001] [   0  752  951 2030 2060 2092 2103 2269 2351 2371 2381 2415 2424 2426
 2543 2550 2761 2815 2935 2966 3167 3176 3180 3532 3637 3646 3858 3902
 4001] 302
19
18
[   0 2000 2000 2000    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0]
[ 0 48 48 48  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0]
[   0  892 2054 2072 2179 2298 2585 2600 3341 3342 3348 3415 3437 3690
 3699 3841 3949 4001] [   0  265  812 2031 2043 2091 2094 2129 2214 2218 2285 2323 2330 2355
 2386 2423 2446 2448 2451 2464 2502

In [39]:
classifier_probs

array([[1.        , 0.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       ...,
       [0.33333333, 0.66666667],
       [0.025     , 0.975     ],
       [0.01886792, 0.98113208]])

In [40]:
loss = 0.0
for i in range(len(classifier_probs)):
    if i<n_1:
        loss+=np.log(classifier_probs[i][0])
    elif i>=n_1:
        loss+=np.log(classifier_probs[i][1])
loss=loss/len(classifier_probs)
print(loss)

-0.2498621512686806
