# In [None]:
# Selects which prompt-masking variant combine_masked_generator uses when
# generate_prompts builds relation prompts:
#   1 -> generate_masked_prompts_remove_sent
#   2 -> generate_masked_prompts_part_1
#   3 -> generate_masked_prompts_part_2
#   any other value (such as the 4 set here) -> one prompt per relation in
#   which only the relation slot is masked.
global_flag = 4
"""#### Function to get train, test data (50/50 split currently)"""

def get_datasets(train_sz=100, test_sz=0):
    """Load the train, validation and test datasets.

    Thin wrapper around ``load_dataset`` that always passes the module-level
    ``tokenizer``, enables shuffling and uses a fixed ``max_len``.

    Args:
        train_sz: Forwarded unchanged to ``load_dataset`` as ``train_sz``.
        test_sz: Forwarded unchanged to ``load_dataset`` as ``test_sz``.

    Returns:
        The ``(train_dataset, valid_dataset, test_dataset)`` triple produced
        by ``load_dataset``.
    """
    # NOTE(review): 4096 looks like the model's sequence limit, with
    # 337 + 12 tokens held back for the prompt appended later -- confirm.
    loader_kwargs = dict(
        tokenizer=tokenizer,
        max_len=4096 - 337 - 12,
        train_sz=train_sz,
        test_sz=test_sz,
        shuffle=True,
    )
    train_dataset, valid_dataset, test_dataset = load_dataset(**loader_kwargs)
    return train_dataset, valid_dataset, test_dataset

"""### Define linear layer for a relation type prediction"""

def get_rel_head():
    """Create the linear classification head for relation-type prediction.

    Returns:
        An ``nn.Linear`` layer, moved to ``device``, that maps
        ``num_mask_pos * transformer_model.config.hidden_size`` input
        features to ``len(rel_type_dict)`` outputs.
    """
    in_features = num_mask_pos * transformer_model.config.hidden_size
    out_features = len(rel_type_dict)
    head = nn.Linear(in_features, out_features)
    return head.to(device)

"""### Global Attention Mask Utility for Longformer"""

def get_global_attention_mask(tokenized_threads: np.ndarray) -> np.ndarray:
    """Build a boolean global-attention mask for tokenized threads.

    A position is True when its token id equals one of the tokenizer's
    sep, cls, bos or eos token ids, and False everywhere else.

    Args:
        tokenized_threads: Array of token ids.

    Returns:
        Boolean array with the same shape as ``tokenized_threads``.
    """
    special_ids = (
        tokenizer.sep_token_id,
        tokenizer.cls_token_id,
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
    )
    is_special = np.zeros_like(tokenized_threads, dtype=bool)
    for special_id in special_ids:
        # OR the per-token matches together, one special id at a time.
        is_special = is_special | (tokenized_threads == special_id)
    return is_special

def get_spans(comp_type_labels, length):
    """Return the token spans of every component in a label sequence.

    Scans the first ``length`` entries of ``comp_type_labels``, whose values
    are compared against the ids stored in the module-level ``ac_dict`` under
    the keys "O", "B-BC"/"I-BC", "B-OC"/"I-OC" and "B-D"/"I-D".

    Args:
        comp_type_labels: Indexable sequence of label ids.
        length: Number of leading positions to scan.

    Returns:
        A list of half-open ``(start_idx, end_idx)`` tuples in order of
        appearance. Each span starts at a "B-*" label and extends over the
        run of matching "I-*" labels that immediately follows it.

    Raises:
        AssertionError: If an "I-*" label is found where a new span should
            begin, i.e. it does not continue a span opened by the matching
            "B-*" label.
        ValueError: If a label matches none of the known ids.
    """
    # (begin tag, inside tag) pairs, checked in this order.
    component_tags = (("B-BC", "I-BC"), ("B-OC", "I-OC"), ("B-D", "I-D"))

    def run_end(start_idx, tag):
        """Return the index just past the run of ``tag`` labels at ``start_idx``."""
        tag_id = ac_dict[tag]
        j = start_idx
        while j < length and comp_type_labels[j] == tag_id:
            j += 1
        return j

    spans_lis = []
    i = 0
    while i < length:
        label = comp_type_labels[i]

        if label == ac_dict["O"]:
            # Skip the whole run of outside tokens in one step.
            i = run_end(i, "O")
            continue

        for begin_tag, inside_tag in component_tags:
            if label == ac_dict[begin_tag]:
                end_idx = run_end(i + 1, inside_tag)
                spans_lis.append((i, end_idx))
                i = end_idx
                break
        else:
            # Not "O" and not a begin tag: either a stray inside tag or an
            # id that is not known at all.
            if any(label == ac_dict[inside_tag]
                   for _, inside_tag in component_tags):
                raise AssertionError(
                    "Span detection not working properly, or intermediate "
                    f"token without a begin token at position {i} in "
                    f"{comp_type_labels}"
                )
            raise ValueError(
                f"Unknown component type: {label}. Known types are: {ac_dict}"
            )

    return spans_lis
def generate_masked_prompts_remove_sent(tokenized_thread, length, from_end_idx, from_start_idx, to_start_idx, to_end_idx, rel_type):
    """Yield two relation prompts: thread hidden first, then thread visible.

    Both prompts end with the same tail: the encoded intro text, the "from"
    span, a closing quote, ``num_mask_pos`` mask tokens, and the "to" span
    wrapped in quotes. They differ only in the leading context:

    1. ``length`` mask tokens in place of the thread.
    2. The first ``length`` tokens of ``tokenized_thread``.

    Args:
        tokenized_thread: Array of token ids for the thread.
        length: Number of leading thread tokens used as context.
        from_end_idx: End (exclusive) of the "from" span. Note that it
            precedes ``from_start_idx`` in the signature.
        from_start_idx: Start of the "from" span.
        to_start_idx: Start of the "to" span.
        to_end_idx: End (exclusive) of the "to" span.
        rel_type: Relation label; yielded as ``int(rel_type)``.

    Yields:
        ``(prompt, int(rel_type))`` tuples, masked-thread prompt first.
    """
    def encode_snippet(text):
        # Drop the first and last ids of the encoding (presumably the
        # special tokens the tokenizer adds automatically -- confirm).
        return np.array(tokenizer.encode(text))[1:-1]

    mask_id = tokenizer.mask_token_id
    intro = encode_snippet("Explaination: We said: \"")
    quote = encode_snippet("\"")
    relation_slot = np.array([mask_id] * num_mask_pos)

    # Tail shared by both prompts.
    tail = [
        intro,
        tokenized_thread[from_start_idx:from_end_idx],
        quote,
        relation_slot,
        quote,
        tokenized_thread[to_start_idx:to_end_idx],
        quote,
    ]

    hidden_thread = np.concatenate([np.array([mask_id] * length), *tail])
    visible_thread = np.concatenate([tokenized_thread[:length], *tail])

    label = int(rel_type)
    yield (hidden_thread, label)
    yield (visible_thread, label)
  
def generate_masked_prompts_part_1(tokenized_thread, length, from_end_idx, from_start_idx, to_start_idx, to_end_idx, rel_type):
    """Yield two relation prompts: "from" span masked first, then intact.

    Each prompt is the first ``length`` thread tokens, the encoded intro
    text, a "from" part, a closing quote, ``num_mask_pos`` mask tokens, and
    the "to" span wrapped in quotes. The "from" part is:

    1. As many mask tokens as the "from" span is wide.
    2. The actual "from" span tokens.

    Args:
        tokenized_thread: Array of token ids for the thread.
        length: Number of leading thread tokens used as context.
        from_end_idx: End (exclusive) of the "from" span. Note that it
            precedes ``from_start_idx`` in the signature.
        from_start_idx: Start of the "from" span.
        to_start_idx: Start of the "to" span.
        to_end_idx: End (exclusive) of the "to" span.
        rel_type: Relation label; yielded as ``int(rel_type)``.

    Yields:
        ``(prompt, int(rel_type))`` tuples, masked-"from" prompt first.
    """
    mask_id = tokenizer.mask_token_id
    # encode() output is trimmed by one id at each end (presumably the
    # tokenizer's automatically added special tokens -- confirm).
    intro = np.array(tokenizer.encode("Explaination: We said: \""))[1:-1]
    quote = np.array(tokenizer.encode("\""))[1:-1]
    relation_slot = np.array([mask_id] * num_mask_pos)

    context = tokenized_thread[:length]
    to_tokens = tokenized_thread[to_start_idx:to_end_idx]
    from_tokens = tokenized_thread[from_start_idx:from_end_idx]
    from_hidden = np.array([mask_id] * (from_end_idx - from_start_idx))

    def assemble(from_part):
        return np.concatenate([context, intro, from_part, quote,
                               relation_slot, quote, to_tokens, quote])

    masked_prompt = assemble(from_hidden)
    plain_prompt = assemble(from_tokens)

    label = int(rel_type)
    yield (masked_prompt, label)
    yield (plain_prompt, label)

def generate_masked_prompts_part_2(tokenized_thread, length, from_end_idx, from_start_idx, to_start_idx, to_end_idx, rel_type):
    """Yield two relation prompts: intact first, then with the "to" span masked.

    Each prompt is the first ``length`` thread tokens, the encoded intro
    text, the "from" span, a closing quote, ``num_mask_pos`` mask tokens, and
    a "to" part wrapped in quotes. The "to" part is:

    1. The actual "to" span tokens.
    2. As many mask tokens as the "to" span is wide.

    Args:
        tokenized_thread: Array of token ids for the thread.
        length: Number of leading thread tokens used as context.
        from_end_idx: End (exclusive) of the "from" span. Note that it
            precedes ``from_start_idx`` in the signature.
        from_start_idx: Start of the "from" span.
        to_start_idx: Start of the "to" span.
        to_end_idx: End (exclusive) of the "to" span.
        rel_type: Relation label; yielded as ``int(rel_type)``.

    Yields:
        ``(prompt, int(rel_type))`` tuples, unmasked prompt first.
    """
    mask_id = tokenizer.mask_token_id
    # encode() output is trimmed by one id at each end (presumably the
    # tokenizer's automatically added special tokens -- confirm).
    intro = np.array(tokenizer.encode("Explaination: We said: \""))[1:-1]
    quote = np.array(tokenizer.encode("\""))[1:-1]
    relation_slot = np.array([mask_id] * num_mask_pos)

    # Everything up to and including the quote that opens the "to" part.
    head = [
        tokenized_thread[:length],
        intro,
        tokenized_thread[from_start_idx:from_end_idx],
        quote,
        relation_slot,
        quote,
    ]
    to_tokens = tokenized_thread[to_start_idx:to_end_idx]
    to_hidden = np.array([mask_id] * (to_end_idx - to_start_idx))

    plain_prompt = np.concatenate(head + [to_tokens, quote])
    masked_prompt = np.concatenate(head + [to_hidden, quote])

    label = int(rel_type)
    yield (plain_prompt, label)
    yield (masked_prompt, label)
 
def combine_masked_generator(flag, tokenized_thread, length, from_end_idx, from_start_idx, to_start_idx, to_end_idx, rel_type):
    """Yield relation prompts using the masking variant chosen by ``flag``.

    Args:
        flag: 1 delegates to ``generate_masked_prompts_remove_sent``, 2 to
            ``generate_masked_prompts_part_1``, 3 to
            ``generate_masked_prompts_part_2``. Any other value produces a
            single prompt in which only the relation slot is masked.
        tokenized_thread: Array of token ids for the thread.
        length: Number of leading thread tokens used as context.
        from_end_idx: End (exclusive) of the "from" span. Note that it
            precedes ``from_start_idx`` in the signature.
        from_start_idx: Start of the "from" span.
        to_start_idx: Start of the "to" span.
        to_end_idx: End (exclusive) of the "to" span.
        rel_type: Relation label; yielded as an ``int``.

    Yields:
        ``(prompt, int(rel_type))`` tuples.
    """
    variant = None
    if flag == 1:
        variant = generate_masked_prompts_remove_sent
    elif flag == 2:
        variant = generate_masked_prompts_part_1
    elif flag == 3:
        variant = generate_masked_prompts_part_2

    if variant is not None:
        # The variants already yield (prompt, int) tuples.
        yield from variant(tokenized_thread, length, from_end_idx,
                           from_start_idx, to_start_idx, to_end_idx, rel_type)
        return

    # Default: thread context, intro, "from" span, relation masks, "to" span.
    # encode() output is trimmed by one id at each end (presumably the
    # tokenizer's automatically added special tokens -- confirm).
    intro = np.array(tokenizer.encode("Explaination: We said: \""))[1:-1]
    quote = np.array(tokenizer.encode("\""))[1:-1]
    relation_slot = np.array([tokenizer.mask_token_id] * num_mask_pos)
    pieces = [
        tokenized_thread[:length],
        intro,
        tokenized_thread[from_start_idx:from_end_idx],
        quote,
        relation_slot,
        quote,
        tokenized_thread[to_start_idx:to_end_idx],
        quote,
    ]
    yield (np.concatenate(pieces), int(rel_type))
def generate_prompts(tokenized_thread, comp_type_labels, refers_to_and_type):
    """Yield a relation prompt for every linked pair of components.

    The thread length is the number of tokens that differ from the
    tokenizer's pad token id. Component spans come from ``get_spans``;
    ``link_from`` and ``link_to`` index into them after subtracting one.
    Relations whose ``link_to`` equals 0 are skipped. Prompts are built by
    ``combine_masked_generator`` using the module-level ``global_flag``.

    Args:
        tokenized_thread: Array of token ids for the thread.
        comp_type_labels: Per-token component labels passed to ``get_spans``.
        refers_to_and_type: Iterable of ``(rel_type, link_from, link_to)``
            triples.

    Yields:
        ``(prompt, int(rel_type))`` tuples.

    Raises:
        AssertionError: If a prompt is longer than
            ``tokenizer.model_max_length``.
    """
    thread_len = np.sum(tokenized_thread != tokenizer.pad_token_id)
    component_spans = get_spans(comp_type_labels, thread_len)

    for relation in refers_to_and_type:
        rel_type, link_from, link_to = relation
        if link_to == 0:
            continue

        src_start, src_end = component_spans[link_from - 1]
        dst_start, dst_end = component_spans[link_to - 1]

        # combine_masked_generator takes the "from" end before its start.
        candidates = combine_masked_generator(
            global_flag, tokenized_thread, thread_len,
            src_end, src_start, dst_start, dst_end, rel_type,
        )
        for prompt, rel in candidates:
            prompt_len = prompt.shape[0]
            if prompt_len > tokenizer.model_max_length:
                raise AssertionError(
                    "Please set max_len in load_dataset so that the sequence length:",
                    thread_len,
                    "doesn't execeed the maximum length:",
                    tokenizer.model_max_length,
                    "after adding prompt of length:",
                    prompt_len - thread_len,
                )
            yield (prompt, int(rel))

