In [217]:
import numpy as np
import pandas as pd
from heapq import heappush, heappop, heapify
from itertools import combinations as comb


In [218]:
def rewardFunction(current_situation, split_rate, total_sum_label ,noise = False):
    """Score how far ``current_situation`` is from the per-label target.

    The target for each label is ``total_sum_label * split_rate``. The
    per-label shortfall (target minus current) is optionally perturbed with
    Gumbel noise, every positive shortfall is multiplied by 5, and the
    entries are summed.

    Args:
        current_situation: per-label counts of the candidate selection.
        split_rate: fraction of each label total that should be reached.
        total_sum_label: per-label totals over all groups.
        noise: when equal to True, Gumbel noise is added through
            ``makeNoise`` with a scale of ``|mean(shortfall)| / 10``.

    Returns:
        The sum of the weighted shortfall entries.
    """
    shortfall = total_sum_label * split_rate - current_situation
    if noise == True:
        noise_scale = np.abs(np.mean(shortfall) / 10)
        shortfall = makeNoise(shortfall, scale=noise_scale)
    # Entries above zero (label still below its target) count five times;
    # entries at or below zero keep their value.
    below_target = shortfall > 0
    shortfall[below_target] *= 5
    return np.sum(shortfall)

def makeNoise(score, scale):
    """Return ``score`` with Gumbel noise added to every entry.

    Args:
        score: array whose ``shape`` determines how many samples are drawn.
        scale: scale parameter of the Gumbel distribution (location is 0).

    Returns:
        A new array ``noise + score``; ``score`` itself is not modified.
    """
    perturbation = np.random.gumbel(loc=0, scale=scale, size=score.shape)
    return perturbation + score



In [219]:
def rewardFunction2(current_situation, split_rate, total_sum_label ,noise = False):
    """Score a candidate selection using a normalized, exponential penalty.

    The per-label gap ``total_sum_label * split_rate - current_situation`` is
    divided by ``total_sum_label``. Positive entries ``x`` are then replaced
    by ``exp(x)``, negative entries by ``-exp(x)``, zeros are left alone, and
    the entries are summed.

    Args:
        current_situation: per-label counts of the candidate selection.
        split_rate: fraction of each label total that should be reached.
        total_sum_label: per-label totals over all groups.
        noise: accepted for signature parity with ``rewardFunction`` but
            overwritten with False below, so no noise is ever added.

    Returns:
        The sum of the transformed, normalized gaps.
    """
    # The caller's flag is discarded here; the guarded branch below is kept
    # so noise can be re-enabled by deleting this single line.
    noise = False
    gap = total_sum_label * split_rate - current_situation
    if noise == True:
        gap = makeNoise(gap, scale=np.abs(np.mean(gap) / 10))

    # Express each gap as a fraction of that label's total.
    relative_gap = np.divide(gap, total_sum_label)

    under_target = relative_gap > 0
    relative_gap[under_target] = np.exp(relative_gap[under_target])
    # Mask is computed after the positive entries were rewritten; those are
    # now > 0 and therefore not touched again.
    over_target = relative_gap < 0
    relative_gap[over_target] = -np.exp(relative_gap[over_target])
    return np.sum(relative_gap)

In [220]:
class RandomGroupSplitter:
    """Search for subsets of groups whose label totals approach a split rate.

    Each row of ``label_matrix`` is one group and each column one label. The
    searches look for row subsets whose summed label counts come close to
    ``split_rate`` times the overall per-label totals. Scoring is delegated
    to the module-level ``rewardFunction`` (greedy search) and
    ``rewardFunction2`` (pairwise-merge search); in both searches the
    candidate with the smallest score is treated as the best one.

    Every combination returned by a search is also appended to
    ``tried_combination``.
    """

    def __init__(self, noise, force_bound, tolerance, split_rate, label_matrix, verbose = True):
        """
        Args:
            noise: forwarded as ``noise=`` to the scoring functions.
            force_bound: stored only; no method of this class reads it.
            tolerance: stored only; no method of this class reads it.
            split_rate: fraction of each label total the subset should reach.
            label_matrix: 2-D array-like (rows = groups, columns = labels);
                converted to an ``int64`` array.
            verbose: when True (default) the searches print their progress;
                pass False to silence them.
        """
        self.noise = noise
        self.force_bound = force_bound
        self.tolerance = tolerance
        self.split_rate = split_rate
        # asarray lets lists / DataFrames through as well as ndarrays.
        self.label_matrix = np.asarray(label_matrix).astype(np.int64)
        self.total_sum_label = np.sum(self.label_matrix, axis = 0)
        self.tried_combination = []
        self.verbose = verbose

    def _log(self, *args):
        """Print progress output unless the instance was built with verbose=False."""
        # getattr keeps instances created without __init__ working.
        if getattr(self, "verbose", True):
            print(*args)

    def get_one_combination(self):
        """Greedily add one group at a time until the best score turns negative.

        In every step each remaining group is scored with ``rewardFunction``
        as if it were added to the current selection; the group with the
        smallest score is taken. The search stops when that smallest score
        is negative or when no groups are left.

        Returns:
            ``(current_result, result_idx)``: the summed label counts of the
            selection and the sorted list of selected row indices.
        """
        current_result = np.zeros(self.label_matrix.shape[1]).astype(np.int64)
        remaining = list(range(self.label_matrix.shape[0]))
        result_idx = []

        # Looping on `remaining` (instead of `while True`) avoids indexing an
        # empty candidate list once every group has been selected.
        while remaining:
            # Only the minimum is needed, so a full heap is unnecessary.
            # Scores are evaluated in index order, one call per remaining row.
            best_result = min(
                (
                    rewardFunction(
                        current_result + self.label_matrix[i],
                        self.split_rate,
                        self.total_sum_label,
                        noise=self.noise,
                    ),
                    i,
                )
                for i in remaining
            )
            if best_result[0] < 0:
                break

            chosen = best_result[1]
            current_result += self.label_matrix[chosen]
            remaining.remove(chosen)
            self._log(f"best_result: {best_result}")
            self._log(f"label matrix: {self.label_matrix[chosen]}")
            self._log(f"current_result: {current_result}")
            result_idx.append(chosen)
            self._log(result_idx)

        result_idx = sorted(result_idx)
        self.tried_combination.append(tuple(result_idx))

        return current_result, result_idx

    def get_multi_combination(self, noOfCombination, top_k = 1000, max_iter = 10, min_candidates = 50):
        """Build combinations by repeatedly merging the best-scoring candidates.

        The search starts from all pairs of rows. Each round scores every
        candidate with ``rewardFunction2`` and ranks them ascending. If the
        best score is negative, or fewer than ``min_candidates`` candidates
        exist, the round's ranking is final. Otherwise the ``top_k`` best
        candidates are merged pairwise (as de-duplicated sorted index
        tuples) to form the next round. After ``max_iter`` rounds the search
        gives up and uses the ranking of the last scored round.

        Args:
            noOfCombination: number of best combinations to return.
            top_k: how many of the best candidates are merged per round.
            max_iter: maximum number of merge rounds.
            min_candidates: stop as soon as a round has fewer candidates.

        Returns:
            ``(top_k_distribution, top_k_group_combination)``: a float array
            with one row of summed label counts per returned combination,
            and the list of index tuples. Fewer than ``noOfCombination``
            entries are returned when fewer candidates exist.

        Raises:
            ValueError: if ``noOfCombination`` is negative.
        """
        if noOfCombination < 0:
            raise ValueError("noOfCombination must be non-negative")

        candidates = list(comb(range(self.label_matrix.shape[0]), 2))
        counter = 0
        ranked = []  # (score, index_tuple) of the last scored round, ascending

        # An empty candidate list (fewer than two rows, or nothing left to
        # merge) ends the search with whatever was ranked last.
        while candidates:
            counter += 1
            # Sorting keeps the full ranking intact, so the best entries are
            # still available when the search stops without converging.
            ranked = sorted(
                (
                    rewardFunction2(
                        np.sum(self.label_matrix[list(cand), :], axis=0),
                        self.split_rate,
                        self.total_sum_label,
                        noise=self.noise,
                    ),
                    cand,
                )
                for cand in candidates
            )
            self._log(len(ranked))
            self._log(f"best result: {ranked[0]}")
            if ranked[0][0] < 0 or len(ranked) < min_candidates:
                break

            k_top_index = [cand for _, cand in ranked[:max(top_k, 0)]]
            # Merge index tuples pairwise, e.g. (1, 2) + (2, 3) -> (1, 2, 3),
            # and drop duplicates.
            candidates = list({tuple(sorted(set(a + b))) for a, b in comb(k_top_index, 2)})
            self._log(f"{counter}th iteration, len of combination_list: {len(candidates)}")

            if counter >= max_iter:
                self._log(f"cannot converge")
                break

        selected = ranked[:noOfCombination]

        top_k_distribution = np.zeros((len(selected), self.label_matrix.shape[1]))
        for row, (_, members) in enumerate(selected):
            top_k_distribution[row] = np.sum(self.label_matrix[list(members), :], axis=0)

        self._log(top_k_distribution)
        self._log(selected)
        top_k_group_combination = [members for _, members in selected]

        self.tried_combination.extend(tuple(members) for members in top_k_group_combination)

        return top_k_distribution, top_k_group_combination

    def get_tried_combination(self):
        """Return the distinct combinations recorded so far, in first-seen order."""
        self._log(self.tried_combination)
        # dict.fromkeys de-duplicates while keeping insertion order.
        return list(dict.fromkeys(self.tried_combination))

    def get_remain_test_combination(self, train_combination):
        """Return the sorted row indices that are not in ``train_combination``."""
        all_rows = set(range(self.label_matrix.shape[0]))
        return sorted(all_rows - set(train_combination))

In [221]:
# Load the distribution table. The slice keeps every row and drops the first
# and the last column (presumably an identifier and a total column — TODO confirm
# against distribution.csv).
df = pd.read_csv('distribution.csv')
label_matrix = np.array(df.values[:, 1:-1])

In [222]:
# Aim for 75% of every label total. force_bound and tolerance are stored on
# the instance; the class defined in this notebook never reads them.
RGS = RandomGroupSplitter(
    noise=True,
    force_bound=False,
    tolerance=0.1,
    split_rate=0.75,
    label_matrix=label_matrix,
)

In [223]:
# Greedy single split: returns the summed per-label counts of the selection
# and the sorted indices of the selected rows.
# NOTE(review): RGS was built with noise=True and no random seed is set, so
# the Gumbel noise in rewardFunction makes this result vary between runs.
current_result, result_idx = RGS.get_one_combination()

best_result: (11388.609462494289, 36)
label matrix: [  0   0   0 725]
current_result: [  0   0   0 725]
[36]
best_result: (9594.519546137497, 30)
label matrix: [ 9  0  0 28]
current_result: [  9   0   0 753]
[36, 30]
best_result: (9167.736981086995, 35)
label matrix: [ 47   0 199  30]
current_result: [ 56   0 199 783]
[36, 30, 35]
best_result: (7626.887218225989, 40)
label matrix: [130   0   0 229]
current_result: [ 186    0  199 1012]
[36, 30, 35, 40]
best_result: (6266.292321841546, 34)
label matrix: [ 90   0   0 117]
current_result: [ 276    0  199 1129]
[36, 30, 35, 40, 34]
best_result: (5748.912008085197, 13)
label matrix: [81 52  0  0]
current_result: [ 357   52  199 1129]
[36, 30, 35, 40, 34, 13]
best_result: (4826.781094861635, 22)
label matrix: [22  0  0 79]
current_result: [ 379   52  199 1208]
[36, 30, 35, 40, 34, 13, 22]
best_result: (4512.577374131799, 5)
label matrix: [12  0 24  0]
current_result: [ 391   52  223 1208]
[36, 30, 35, 40, 34, 13, 22, 5]
best_result: (4139.05

In [224]:
# Fraction of each label's overall total captured by the greedy selection.
print(current_result / label_matrix.sum(axis=0))

print(result_idx)

[0.7522603978300181 0.6428571428571429 0.8766066838046273
 0.8231155778894472]
[2, 5, 9, 10, 13, 17, 21, 22, 25, 28, 29, 30, 32, 33, 34, 35, 36, 37, 40]


In [225]:
# Request the 10 lowest-scoring merged combinations, merging at most the
# 1000 best candidates per round (top_k). multi_result holds the summed label
# counts per combination, multi_result2 the row-index tuples.
multi_result, multi_result2 = RGS.get_multi_combination(10, 1000)

820
best result: (6.86378694107514, (35, 36))


1th iteration, len of combination_list: 111930
111930
best result: (4.139780910671986, (2, 35, 36, 37))
2th iteration, len of combination_list: 174041
174041
best result: (3.005722101272507, (2, 8, 13, 34, 35, 36, 37, 40))
3th iteration, len of combination_list: 66928
66928
best result: (0.5614655933254309, (2, 8, 9, 13, 21, 27, 28, 34, 35, 36, 37, 40))
4th iteration, len of combination_list: 51320
51320
best result: (-1.90318487753063, (2, 8, 9, 13, 17, 21, 22, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40))
[[ 821.  271.  317. 1497.]
 [ 835.  271.  317. 1437.]
 [ 807.  271.  317. 1523.]
 [ 833.  271.  317. 1412.]
 [ 799.  271.  317. 1531.]
 [ 835.  271.  317. 1403.]
 [ 789.  271.  317. 1531.]
 [ 779.  271.  317. 1514.]
 [ 843.  271.  317. 1395.]
 [ 781.  271.  317. 1539.]]
[(-1.90318487753063, (2, 8, 9, 13, 17, 21, 22, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40)), (-1.879916060308184, (2, 8, 9, 11, 13, 17, 21, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40)), (-1.877396890391712, (2, 8, 9, 13, 17, 21, 

In [226]:
# Per-label fraction of the overall totals covered by each returned combination.
print(multi_result / label_matrix.sum(axis=0))

[[0.7423146473779385 0.7742857142857142 0.8149100257069408
  0.7522613065326633]
 [0.7549728752260397 0.7742857142857142 0.8149100257069408
  0.7221105527638191]
 [0.7296564195298373 0.7742857142857142 0.8149100257069408
  0.7653266331658292]
 [0.7531645569620253 0.7742857142857142 0.8149100257069408
  0.7095477386934673]
 [0.7224231464737794 0.7742857142857142 0.8149100257069408
  0.7693467336683417]
 [0.7549728752260397 0.7742857142857142 0.8149100257069408
  0.7050251256281407]
 [0.713381555153707 0.7742857142857142 0.8149100257069408
  0.7693467336683417]
 [0.7043399638336347 0.7742857142857142 0.8149100257069408
  0.7608040201005025]
 [0.7622061482820977 0.7742857142857142 0.8149100257069408
  0.7010050251256281]
 [0.7061482820976492 0.7742857142857142 0.8149100257069408
  0.7733668341708543]]


In [232]:
# For every returned combination, show it next to its complement: the row
# indices that are not part of the combination, and how many there are.
# NOTE(review): this cell's execution count (232) is higher than that of the
# cells below it, so the notebook was not run top to bottom.
print(multi_result2)
print(len(multi_result2))
for i in multi_result2:
    print(i)
    remain = RGS.get_remain_test_combination(i)
    print(remain, len(remain))

[(2, 8, 9, 13, 17, 21, 22, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 11, 13, 17, 21, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 23, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 22, 25, 28, 29, 32, 34, 35, 36, 37, 39, 40), (2, 8, 9, 10, 13, 17, 21, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 10, 11, 13, 17, 21, 25, 28, 29, 32, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 22, 23, 25, 28, 29, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 23, 25, 28, 29, 33, 34, 35, 36, 37, 39, 40), (2, 8, 9, 11, 13, 17, 21, 23, 25, 28, 29, 32, 34, 35, 36, 37, 40), (2, 8, 9, 10, 13, 17, 21, 22, 25, 28, 29, 33, 34, 35, 36, 37, 40)]
10
(2, 8, 9, 13, 17, 21, 22, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40)
[0, 1, 3, 4, 5, 6, 7, 10, 11, 12, 14, 15, 16, 18, 19, 20, 23, 24, 26, 27, 30, 31, 38, 39] 24
(2, 8, 9, 11, 13, 17, 21, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40)
[0, 1, 3, 4, 5, 6, 7, 10, 12, 14, 15, 16, 18, 19, 20, 22, 23, 24, 26, 27, 30, 31, 38, 39] 24
(2, 8, 9

In [228]:
# Distinct combinations recorded by the searches so far; the call also prints
# the raw (non-deduplicated) history.
my_trial = RGS.get_tried_combination()

[(2, 5, 9, 10, 13, 17, 21, 22, 25, 28, 29, 30, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 22, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 11, 13, 17, 21, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 23, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 22, 25, 28, 29, 32, 34, 35, 36, 37, 39, 40), (2, 8, 9, 10, 13, 17, 21, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 10, 11, 13, 17, 21, 25, 28, 29, 32, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 22, 23, 25, 28, 29, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 23, 25, 28, 29, 33, 34, 35, 36, 37, 39, 40), (2, 8, 9, 11, 13, 17, 21, 23, 25, 28, 29, 32, 34, 35, 36, 37, 40), (2, 8, 9, 10, 13, 17, 21, 22, 25, 28, 29, 33, 34, 35, 36, 37, 40)]


In [229]:
# Number of distinct combinations tried.
print(len(my_trial))    

11


In [233]:
# NOTE(review): this rebinds RandomGroupSplitter to the class from the
# RandomGroupSpliter module, shadowing the class defined earlier in this
# notebook for every cell that runs afterwards. Imports normally belong in the
# first cell; confirm that the two implementations are meant to coexist.
from RandomGroupSpliter import RandomGroupSplitter

In [234]:
# Same configuration as the earlier instance, now built from whatever class
# the name RandomGroupSplitter currently refers to; RGS is rebound.
RGS = RandomGroupSplitter(
    noise=True,
    force_bound=False,
    tolerance=0.1,
    split_rate=0.75,
    label_matrix=label_matrix,
)

In [235]:
# Re-run the greedy split on the new RGS instance; this overwrites the
# current_result / result_idx produced by the earlier cell.
# NOTE(review): no random seed is set, so with noise=True the result is
# presumably not reproducible — confirm against the module's implementation.
current_result, result_idx = RGS.get_one_combination()

In [236]:
# Summed label counts and selected row indices of the greedy split.
print(current_result, result_idx)

[ 797  229  317 1636] [2, 8, 10, 13, 17, 21, 23, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40]


In [237]:
# Same request as the earlier multi-combination cell (10 results, top_k=1000);
# a receives the first returned value and b the second.
a, b = RGS.get_multi_combination(10, 1000)

1th iteration, len of combination_list: 111930
2th iteration, len of combination_list: 174041
3th iteration, len of combination_list: 66928
4th iteration, len of combination_list: 51320


In [238]:
# Show both values returned by get_multi_combination.
print(a, b)

[[ 821.  271.  317. 1497.]
 [ 835.  271.  317. 1437.]
 [ 807.  271.  317. 1523.]
 [ 833.  271.  317. 1412.]
 [ 799.  271.  317. 1531.]
 [ 835.  271.  317. 1403.]
 [ 789.  271.  317. 1531.]
 [ 779.  271.  317. 1514.]
 [ 843.  271.  317. 1395.]
 [ 781.  271.  317. 1539.]] [(2, 8, 9, 13, 17, 21, 22, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 11, 13, 17, 21, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 23, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 22, 25, 28, 29, 32, 34, 35, 36, 37, 39, 40), (2, 8, 9, 10, 13, 17, 21, 25, 28, 29, 32, 33, 34, 35, 36, 37, 40), (2, 8, 9, 10, 11, 13, 17, 21, 25, 28, 29, 32, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 22, 23, 25, 28, 29, 33, 34, 35, 36, 37, 40), (2, 8, 9, 13, 17, 21, 23, 25, 28, 29, 33, 34, 35, 36, 37, 39, 40), (2, 8, 9, 11, 13, 17, 21, 23, 25, 28, 29, 32, 34, 35, 36, 37, 40), (2, 8, 9, 10, 13, 17, 21, 22, 25, 28, 29, 33, 34, 35, 36, 37, 40)]
