In [None]:
import numpy as np
import pandas as pd

In [None]:
import itertools

In [None]:
from collections import namedtuple

In [None]:
from utils.inference_s1_stem_bb import Predictor

In [None]:
from utils.util_global_struct import process_bb_old_to_new

In [None]:
# Keyword arguments for Predictor (imported from utils.inference_s1_stem_bb):
# a checkpoint path plus layer-size settings.
# NOTE(review): presumably these sizes must match the architecture the
# checkpoint was trained with — confirm before changing any of them.
predictor_config = dict(
    model_ckpt='../2021_05_18/s1_training/result/run_32/model_ckpt_ep_79.pth',
    num_filters=[128, 128, 256, 256, 512, 512],
    filter_width=[3, 3, 5, 5, 7, 7],
    hid_shared=[128, 128, 128, 256, 256, 256],
    hid_output=[64],
    dropout=0,
)
predictor = Predictor(**predictor_config)

In [None]:
# Load the example dataset from a gzipped pickle.
# NOTE(review): unpickling can execute arbitrary code, so only load files from a trusted source.
df_data = pd.read_pickle('data/s2_train_len20_200_2000_pred_stem_0p5.pkl.gz')

In [None]:
# Print the index label and number of bounding boxes for every example whose
# sequence has at most 40 characters.
for row_label, row_seq, row_bbs in zip(df_data.index, df_data['seq'], df_data['bounding_boxes']):
    if len(row_seq) > 40:
        continue
    print(row_label, len(row_bbs))

In [None]:
# Pick one example row (by position) to work with; the commented-out lines are
# other rows that were tried and are kept for reference.
# data_row = df_data.iloc[55]

# data_row = df_data.iloc[60]

data_row = df_data.iloc[230]


In [None]:
# Use the selected example's sequence; the hard-coded sequence is a disabled alternative.
# seq = 'TGAAAAGGGAGGAATCTGTTTGCCCTCAGATATTTAGTTA'
seq = data_row.seq

In [None]:
# Display the sequence being analysed.
seq

In [None]:
# Display the bounding boxes stored in the dataset for this example, after
# conversion by process_bb_old_to_new.
process_bb_old_to_new(data_row.bounding_boxes)

In [None]:
# Run the predictor on the sequence, passing threshold=0.5; the result is kept as df_stem.
df_stem = predictor.predict_bb(seq=seq, threshold=0.5)

In [None]:
# Renumber the rows 0..n-1: later cells use the index as a position into the conflict matrix.
df_stem = df_stem.reset_index(drop=True)

In [None]:
# Display the predicted stem bounding boxes.
df_stem

In [None]:
# compute pairwise bb conflict

# Lightweight record for one bounding box: a reference position (bb_x, bb_y)
# and a size along each axis (siz_x, siz_y).
BoundingBox = namedtuple("BoundingBox", ['bb_x', 'bb_y', 'siz_x', 'siz_y'])

# The df index doubles as the row/column position in the conflict matrix, so it
# has to be exactly 0..n-1. Checking the whole index (not only the last label)
# also catches gaps or duplicates and does not crash on an empty frame.
assert list(df_stem.index) == list(range(len(df_stem))), \
    "df_stem index must be 0..n-1 (call reset_index(drop=True) first)"

# Read the four columns directly instead of going through iterrows(): iterrows
# builds one Series per row, which can upcast integer coordinates to float when
# the frame also holds float columns.
bbs = {
    idx: BoundingBox(bb_x=bb_x, bb_y=bb_y, siz_x=siz_x, siz_y=siz_y)
    for idx, bb_x, bb_y, siz_x, siz_y in zip(
        df_stem.index,
        df_stem['bb_x'],
        df_stem['bb_y'],
        df_stem['siz_x'],
        df_stem['siz_y'],
    )
}
    
def range_overlap(r1, r2):
    """Report whether two ``(start, end)`` ranges overlap.

    The ranges count as overlapping when the end of one lies inside the other,
    where "inside" means strictly greater than the other's start and no greater
    than its end. Only elements 0 and 1 of each argument are read.

    Returns a plain ``bool``.
    """
    start1, end1 = r1[0], r1[1]
    start2, end2 = r2[0], r2[1]
    # end of r1 falls inside r2
    if start2 < end1 <= end2:
        return True
    # otherwise the only remaining way to overlap is the end of r2 falling inside r1
    return bool(start1 < end2 <= end1)
    

def bb_conflict(bb1, bb2):
    """Report whether two bounding boxes conflict.

    Each box contributes two ranges: an x span ``(bb_x, bb_x + siz_x)`` and a
    y span ``(bb_y - siz_y, bb_y)``. The boxes conflict when any span of
    ``bb1`` overlaps any span of ``bb2`` according to ``range_overlap``.

    NOTE(review): the x span starts at ``bb_x`` while the y span ends at
    ``bb_y``. If ``bb_y`` is the inclusive last position on the y side, the two
    spans use different endpoint conventions and the x-vs-y comparisons would be
    off by one — TODO confirm against the bounding-box convention.

    Returns a plain ``bool``.
    """
    def _spans(bb):
        # (x span, y span) of one bounding box
        return (bb.bb_x, bb.bb_x + bb.siz_x), (bb.bb_y - bb.siz_y, bb.bb_y)

    spans_1 = _spans(bb1)
    spans_2 = _spans(bb2)
    # checks x1-x2, x1-y2, y1-x2, y1-y2 in that order and stops at the first overlap
    return any(range_overlap(s1, s2) for s1 in spans_1 for s2 in spans_2)


In [None]:
# Pairwise conflict matrix: entry [i, j] is 1.0 when bounding boxes i and j
# conflict and 0.0 otherwise. The keys of bbs are used directly as positions.
bb_conf_arr = np.zeros((len(bbs), len(bbs)))

# bb_conflict compares every span of one box with every span of the other using
# an order-independent overlap test, so the result is the same for (i, j) and
# (j, i). Each unordered pair (diagonal included) is therefore evaluated once
# and written to both mirrored entries.
for i, j in itertools.combinations_with_replacement(bbs.keys(), 2):
    bb_conf_arr[i, j] = bb_conf_arr[j, i] = bb_conflict(bbs[i], bbs[j])


In [None]:
# Show the bounding boxes again, for reading alongside the conflict matrix.
df_stem

In [None]:
# Display the conflict matrix as integers (the array itself is float).
bb_conf_arr.astype(int)

In [None]:
# Sanity check: the conflict matrix must equal its transpose, because the
# enumeration of valid combinations looks up only one direction per pair.
assert np.all(bb_conf_arr.T == bb_conf_arr)

In [None]:
# brute force way
# Enumerate every 0/1 inclusion vector over the bounding boxes and keep the ones
# in which no two included boxes conflict.
# NOTE: all 2**n_bbs vectors are materialised in a list, so this only scales to
# a small number of bounding boxes.
n_bbs = len(bbs)
all_combos = list(itertools.product([0, 1], repeat=n_bbs))

# For each vector, np.where gives the indices of the included boxes; the vector
# is valid when none of their unordered pairs is flagged in bb_conf_arr
# (one lookup per pair is enough since the matrix is symmetric).
valid_combos = [
    combo
    for combo in all_combos
    if not any(
        bb_conf_arr[first, second]
        for first, second in itertools.combinations(np.where(combo)[0], 2)
    )
]
    


In [None]:
# One row per valid combination; each 'combo' entry is a tuple of 0/1 inclusion
# flags, one flag per bounding box.
df_valid_combo = pd.DataFrame({'combo': valid_combos})

In [None]:
def get_total_bp(c):
    """Sum ``siz_x`` over the bounding boxes switched on in combination ``c``.

    ``c`` is a sequence of inclusion flags, one per bounding box. The positions
    of its non-zero entries are used as keys into the module-level ``bbs`` dict.
    Returns 0 when no box is included.
    """
    return sum(bbs[bb_idx].siz_x for bb_idx in np.where(c)[0])

# Annotate each valid combination with the indices of its included bounding
# boxes (bb_inc) and with the summed siz_x of those boxes (total_bps).
df_valid_combo = df_valid_combo.assign(
    bb_inc=lambda df: df['combo'].apply(lambda c: list(np.where(c)[0])),
    total_bps=lambda df: df['combo'].apply(get_total_bp),
)

In [None]:
# Rank the combinations by total_bps, largest first.
df_valid_combo = df_valid_combo.sort_values(by=['total_bps'], ascending=False)

In [None]:
# Number of valid combinations found.
len(df_valid_combo)

In [None]:
# Display the ranked valid combinations.
df_valid_combo

In [None]:
# def check_validity(current_combo, to_add):
#     for i in current_combo:
#         if bb_conf_arr[i, to_add] == 1 or bb_conf_arr[to_add, i] == 1:  # TODO no need to check both, should be symmetric
#             return False
#     return True

In [None]:
# # enumerate valid stem bb combinations
# # FIXME hard-coded for one case for experimenting with ideas
# assert len(df_stem) == 7

# valid_stem_combos = []
# current_combo = []

# for x0 in [0, 1]:
#     if x0 == 1:
#         if check_validity(current_combo, 0):
#             current_combo.append(0)
#         else:
#             valid_stem_combos.append(current_combo)
#             current_combo = []
#             break
#     else:
#         pass
    
#     for x1 in [0, 1]:
#         if x1 == 1:
#             if check_validity(current_combo, 1):
#                 current_combo.append(1)
#             else:
#                 valid_stem_combos.append(current_combo)
#                 current_combo = []
#                 break
#         else:
#             pass
    
#         for x2 in [0, 1]:
#             if x2 == 1:
#                 if check_validity(current_combo, 2):
#                     current_combo.append(2)
#                 else:
#                     valid_stem_combos.append(current_combo)
#                     current_combo = []
#                     break
#             else:
#                 pass

#             for x3 in [0, 1]:
#                 if x3 == 1:
#                     if check_validity(current_combo, 3):
#                         current_combo.append(3)
#                     else:
#                         valid_stem_combos.append(current_combo)
#                         current_combo = []
#                         break
#                 else:
#                     pass
                
#                 for x4 in [0, 1]:
#                     if x4 == 1:
#                         if check_validity(current_combo, 4):
#                             current_combo.append(4)
#                         else:
#                             valid_stem_combos.append(current_combo)
#                             current_combo = []
#                             break
#                     else:
#                         pass
                    
#                     for x5 in [0, 1]:
#                         if x5 == 1:
#                             if check_validity(current_combo, 5):
#                                 current_combo.append(5)
#                             else:
#                                 valid_stem_combos.append(current_combo)
#                                 current_combo = []
#                                 break
#                         else:
#                             pass
                        
#                         for x6 in [0, 1]:
#                             if x6 == 1:
#                                 if check_validity(current_combo, 6):
#                                     current_combo.append(6)
#                                 else:
#                                     valid_stem_combos.append(current_combo)
#                                     current_combo = []
#                                     break
#                             else:
#                                 pass

In [None]:
# check_validity(current_combo, 0)

In [None]:
# x = Bool("x")
# y = Bool("y")
# x_or_y = Or([x,y]) # disjunction
# s = Solver() # create a solver s
# s.add(x_or_y) # add the clause: x or y
# z = Bool("z")
# s.add(Or([x,y,Not(z)])) # add another clause: x or y or !z

In [None]:
# s.check()

In [None]:
# s.model()