# Zero Shot Classification Function Section

### Generic Zero Shot Classification Function + Lock Version

In [None]:
def classify_sentence(classifier, candidate_labels, sequence_to_classify, multi_label = True):
    """Run a zero-shot classifier and return its scores keyed by input sequence.

    Args:
        classifier: Callable invoked as
            ``classifier(sequence_to_classify, candidate_labels, multi_label=...)``.
            It must return either one result mapping or a list of them, each
            exposing ``"sequence"``, ``"labels"`` and ``"scores"`` entries.
        candidate_labels: Labels forwarded unchanged to ``classifier``.
        sequence_to_classify: Input forwarded unchanged to ``classifier``.
        multi_label: Forwarded to ``classifier`` as the ``multi_label`` keyword.

    Returns:
        dict: ``{sequence: {label: score}}``. Labels keep the order in which the
        classifier reported them. If two results carry the same sequence text,
        the later one replaces the earlier one.
    """
    raw_output = classifier(sequence_to_classify, candidate_labels, multi_label=multi_label)
    # A lone result is wrapped so both return shapes share the loop below.
    outputs = raw_output if type(raw_output) is list else [raw_output]

    scores_by_sequence = {}
    for output in outputs:
        label_scores = dict(zip(output["labels"], output["scores"]))
        scores_by_sequence[output["sequence"]] = label_scores
    return scores_by_sequence

In [None]:
def lock_classify_sentence(classifier):
    """Bind ``classifier`` and return a ``classify_sentence`` wrapper that no longer needs it.

    Args:
        classifier: Object passed as ``classifier`` to every ``classify_sentence`` call
            made through the returned wrapper.

    Returns:
        A callable ``(candidate_labels, sequence_to_classify, multi_label=True)``
        returning whatever ``classify_sentence`` returns for those arguments.
    """
    def locked_classify_sentence(candidate_labels, sequence_to_classify, multi_label = True):
        return classify_sentence(
            classifier=classifier,
            candidate_labels=candidate_labels,
            sequence_to_classify=sequence_to_classify,
            multi_label=multi_label,
        )

    return locked_classify_sentence

### Categories Classification Function (+ Resort Format Function) + Lock Version

In [None]:
def categories_classification_function(classification_model_function, categories_candidate_labels, texts, multi_label = True, sort_output = 0):
    """Classify ``texts`` against the candidate categories and arrange each text's label scores.

    Args:
        classification_model_function: Classifier handed to ``classify_sentence``.
        categories_candidate_labels: Candidate category labels.
        texts: One string, or a collection of strings, to classify.
        multi_label: Forwarded to the classifier through ``classify_sentence``.
        sort_output: How the labels of each text are ordered in the result:
            -1 returns the ``classify_sentence`` result untouched (classifier order);
             0 orders labels as they appear in ``categories_candidate_labels``;
             1 orders labels by ascending score.

    Returns:
        dict: ``{text: {label: score}}``.

    Raises:
        ValueError: If ``sort_output`` is not -1, 0 or 1. This is checked before
            the classifier runs so a bad mode does not cost an inference pass.
    """
    if sort_output not in (-1, 0, 1):
        raise ValueError(f"sort_output must be -1, 0 or 1, got {sort_output!r}")

    classification_results = classify_sentence(
        classifier=classification_model_function,
        candidate_labels=categories_candidate_labels,
        sequence_to_classify=texts,
        multi_label=multi_label,
    )
    if sort_output == -1:
        return classification_results

    # A bare string is one text, not an iterable of characters; without this,
    # the loop below would look up each character as a key and raise KeyError.
    sequences = [texts] if isinstance(texts, str) else texts

    final_classified_dict = {}
    for seq in sequences:
        label_scores = {label: classification_results[seq][label] for label in categories_candidate_labels}
        if sort_output == 1:
            label_scores = dict(sorted(label_scores.items(), key=lambda label_score: label_score[1]))
        final_classified_dict[seq] = label_scores
    return final_classified_dict

## Resort Format Function
def categories_classification_additional_resort_function(seq_classified_dictionary, categories_candidate_labels, sort_output = 0, top_many = 5, limit_value = 0.5):
    """Regroup ``{sequence: {label: score}}`` into ``{label: {sequence: score}}``.

    Args:
        seq_classified_dictionary: Mapping of sequence -> ``{label: score}``.
        categories_candidate_labels: Labels to build groups for. Each one becomes
            a key of the result, even when no sequence is kept for it.
        sort_output: Selects filtering and ordering per label:
            -1 keeps scores >= ``limit_value`` and sorts highest first;
             0 keeps every score in input order (``limit_value`` is ignored);
             1 keeps scores <= ``limit_value`` and sorts lowest first.
            Any other value leaves every group empty.
        top_many: When >= 0, keep at most this many sequences per label (applied
            in every mode, after sorting). A negative value disables the cut.
        limit_value: Score threshold. A negative value is replaced by 0 in mode -1
            and by 1 in mode 1, which keeps everything when scores lie in [0, 1].

    Returns:
        dict: ``{label: {sequence: score}}``.
    """
    # Negative threshold -> mode-specific default.
    if limit_value < 0:
        if sort_output == -1:
            limit_value = 0
        elif sort_output == 0:
            limit_value = None
        elif sort_output == 1:
            limit_value = 1

    # One predicate per mode; None marks an unknown mode where nothing is kept.
    if sort_output == -1:
        keeps = lambda score: score >= limit_value
    elif sort_output == 0:
        keeps = lambda score: True
    elif sort_output == 1:
        keeps = lambda score: score <= limit_value
    else:
        keeps = None

    grouped = {label: {} for label in categories_candidate_labels}
    for seq, scores_by_label in seq_classified_dictionary.items():
        for label in categories_candidate_labels:
            # Short-circuit keeps unknown modes from touching the scores at all.
            if keeps is not None and keeps(scores_by_label[label]):
                grouped[label][seq] = scores_by_label[label]

    if sort_output == -1 or sort_output == 1:
        highest_first = True if sort_output == -1 else False
        for label in categories_candidate_labels:
            ranked = sorted(grouped[label].items(), key=lambda pair: pair[1], reverse=highest_first)
            grouped[label] = dict(ranked)

    if top_many >= 0:
        for label in categories_candidate_labels:
            leading = list(grouped[label].items())[:top_many]
            grouped[label] = dict(leading)
    return grouped

In [None]:
def lock_categories_classification_function(classification_model_function):
    """Bind a classifier and return a ``categories_classification_function`` wrapper without it.

    Args:
        classification_model_function: Classifier passed as
            ``classification_model_function`` on every call made through the wrapper.

    Returns:
        A callable ``(categories_candidate_labels, texts, multi_label=True, sort_output=0)``
        returning whatever ``categories_classification_function`` returns.
    """
    def locked_categories_classification(categories_candidate_labels, texts, multi_label = True, sort_output = 0):
        return categories_classification_function(
            classification_model_function=classification_model_function,
            categories_candidate_labels=categories_candidate_labels,
            texts=texts,
            multi_label=multi_label,
            sort_output=sort_output,
        )

    return locked_categories_classification


#### Classification Result Display

In [None]:
def categories_classification_resorted_result_display(classification_resorted_dictionary_result, sort_display = 0, top_many = 5, limit_value = 0.5):
    """Print every label followed by its ``sequence: score`` rows.

    Args:
        classification_resorted_dictionary_result: ``{label: {sequence: score}}``.
        sort_display: Row order and threshold direction:
            -1 sorts highest first and prints rows with score >= ``limit_value``;
             0 prints all rows in stored order (``top_many`` and ``limit_value``
               are ignored because they are meaningless without sorting);
             1 sorts lowest first and prints rows with score <= ``limit_value``.
            Any other value prints only the category headers.
        top_many: When > 0, only the first ``top_many`` sorted rows are considered
            (the cut happens before the threshold test); otherwise all rows are.
        limit_value: Score threshold. A negative value is replaced by 0 in mode -1
            and by 1 in mode 1.

    Returns:
        None. Output goes to stdout; sequences are padded/truncated to 65
        characters and scores are shown with 5 significant digits.
    """
    # Negative threshold -> mode-specific default.
    if limit_value < 0:
        if sort_display == -1:
            limit_value = 0
        elif sort_display == 0:
            limit_value = None
        elif sort_display == 1:
            limit_value = 1

    truncate_rows = top_many > 0
    for label, seq_pred_dict in classification_resorted_dictionary_result.items():
        print(f"Category: {label}")
        if sort_display == 0:
            for seq, pred in seq_pred_dict.items():
                print(f"{seq:65.65}: {pred:.5}")
        elif sort_display == -1 or sort_display == 1:
            highest_first = True if sort_display == -1 else False
            rows = sorted(seq_pred_dict.items(), key=lambda row: row[1], reverse=highest_first)
            if truncate_rows:
                rows = rows[:top_many]
            for seq, pred in rows:
                within_limit = pred >= limit_value if highest_first else pred <= limit_value
                if within_limit:
                    print(f"{seq:65.65}: {pred:.5}")
        # Blank line separates categories.
        print()

##### Example Demo

In [None]:
# Classify the whole list in one call; the result is keyed by sequence text.
# (bart_mnli_classifier, candidate_possible_labels and sequences_list are defined outside this section.)
classify_sentence(bart_mnli_classifier, candidate_possible_labels, sequences_list, multi_label=True)

{'one day i will see the world': {'travel': 0.994157612323761,
  'exploration': 0.92877596616745,
  'dancing': 0.005361784249544144,
  'cooking': 0.0016605753917247057},
 'i will explore sweden next semester': {'travel': 0.9911717772483826,
  'exploration': 0.9684410691261292,
  'dancing': 0.0032393524888902903,
  'cooking': 0.00020078456145711243},
 'I love popping and locking!': {'exploration': 0.7612733244895935,
  'dancing': 0.22573687136173248,
  'cooking': 0.17265444993972778,
  'travel': 0.0074744801968336105}}

In [None]:
# Single-sequence input with multi_label=False: the result is still a dict keyed by the sequence.
classify_sentence(bart_mnli_classifier, candidate_possible_labels, sequences_list[0], multi_label=False)

{'one day i will see the world': {'travel': 0.8104696869850159,
  'exploration': 0.1847233921289444,
  'dancing': 0.0025745946913957596,
  'cooking': 0.0022323287557810545}}

In [None]:
# sort_output=-1 returns the classify_sentence result unchanged (labels in classifier order).
categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=-1)

{'one day i will see the world': {'travel': 0.994157612323761,
  'exploration': 0.92877596616745,
  'dancing': 0.005361784249544144,
  'cooking': 0.0016605753917247057},
 'i will explore sweden next semester': {'travel': 0.9911717772483826,
  'exploration': 0.9684410691261292,
  'dancing': 0.0032393524888902903,
  'cooking': 0.00020078456145711243},
 'I love popping and locking!': {'exploration': 0.7612733244895935,
  'dancing': 0.22573687136173248,
  'cooking': 0.17265444993972778,
  'travel': 0.0074744801968336105}}

In [None]:
# Regroup by label: keep scores >= 0.5, sort each label's sequences highest first, at most 5 per label.
categories_classification_additional_resort_function(categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=-1), candidate_possible_labels, sort_output=-1, top_many=5, limit_value=0.5)

{'travel': {'one day i will see the world': 0.994157612323761,
  'i will explore sweden next semester': 0.9911717772483826},
 'cooking': {},
 'dancing': {},
 'exploration': {'i will explore sweden next semester': 0.9684410691261292,
  'one day i will see the world': 0.92877596616745,
  'I love popping and locking!': 0.7612733244895935}}

In [None]:
# Print the regrouped result highest first. top_many=-1 disables the display-side cut and
# limit_value=-1 becomes 0 in this mode, so the filtering comes from the resort step (>= 0.5).
categories_classification_resorted_result_display(categories_classification_additional_resort_function(categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=-1), candidate_possible_labels, sort_output=-1, top_many=5, limit_value=0.5),sort_display=-1, top_many=-1, limit_value=-1)

Category: travel
one day i will see the world                                     : 0.99416
i will explore sweden next semester                              : 0.99117

Category: cooking

Category: dancing

Category: exploration
i will explore sweden next semester                              : 0.96844
one day i will see the world                                     : 0.92878
I love popping and locking!                                      : 0.76127



In [None]:
# sort_output=0 orders each sequence's labels as they are listed in candidate_possible_labels.
categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=0)

{'one day i will see the world': {'travel': 0.994157612323761,
  'cooking': 0.0016605753917247057,
  'dancing': 0.005361784249544144,
  'exploration': 0.92877596616745},
 'i will explore sweden next semester': {'travel': 0.9911717772483826,
  'cooking': 0.00020078456145711243,
  'dancing': 0.0032393524888902903,
  'exploration': 0.9684410691261292},
 'I love popping and locking!': {'travel': 0.0074744801968336105,
  'cooking': 0.17265444993972778,
  'dancing': 0.22573687136173248,
  'exploration': 0.7612733244895935}}

In [None]:
# sort_output=0 regroups by label without sorting or thresholding; only top_many is applied.
# NOTE(review): the recorded output below is filtered to scores >= 0.5, which the current
# function does not do in this mode -- the output looks stale; re-run the cell to refresh it.
categories_classification_additional_resort_function(categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=0), candidate_possible_labels, sort_output=0, top_many=5, limit_value=0.5)

{'travel': {'one day i will see the world': 0.994157612323761,
  'i will explore sweden next semester': 0.9911717772483826},
 'cooking': {},
 'dancing': {},
 'exploration': {'one day i will see the world': 0.92877596616745,
  'i will explore sweden next semester': 0.9684410691261292,
  'I love popping and locking!': 0.7612733244895935}}

In [None]:
# The inner classification uses sort_output=-1, but the resort and display steps run in mode 0,
# so every score is printed, unsorted and unfiltered, in sequence order.
categories_classification_resorted_result_display(categories_classification_additional_resort_function(categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=-1), candidate_possible_labels, sort_output=0, top_many=5, limit_value=0.5),sort_display=0, top_many=-1, limit_value=-1)

Category: travel
one day i will see the world                                     : 0.99416
i will explore sweden next semester                              : 0.99117
I love popping and locking!                                      : 0.0074745

Category: cooking
one day i will see the world                                     : 0.0016606
i will explore sweden next semester                              : 0.00020078
I love popping and locking!                                      : 0.17265

Category: dancing
one day i will see the world                                     : 0.0053618
i will explore sweden next semester                              : 0.0032394
I love popping and locking!                                      : 0.22574

Category: exploration
one day i will see the world                                     : 0.92878
i will explore sweden next semester                              : 0.96844
I love popping and locking!                                      : 0.76127



In [None]:
# sort_output=1 sorts each sequence's labels by ascending score.
categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=1)

{'one day i will see the world': {'cooking': 0.0016605753917247057,
  'dancing': 0.005361784249544144,
  'exploration': 0.92877596616745,
  'travel': 0.994157612323761},
 'i will explore sweden next semester': {'cooking': 0.00020078456145711243,
  'dancing': 0.0032393524888902903,
  'exploration': 0.9684410691261292,
  'travel': 0.9911717772483826},
 'I love popping and locking!': {'travel': 0.0074744801968336105,
  'cooking': 0.17265444993972778,
  'dancing': 0.22573687136173248,
  'exploration': 0.7612733244895935}}

In [None]:
# sort_output=1 keeps scores <= 0.5 and sorts each label's sequences lowest first.
# NOTE(review): the recorded output below lists scores >= 0.5 instead, which does not match
# the current function -- the output looks stale; re-run the cell to refresh it.
categories_classification_additional_resort_function(categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=1), candidate_possible_labels, sort_output=1, top_many=5, limit_value=0.5)

{'travel': {'i will explore sweden next semester': 0.9911717772483826,
  'one day i will see the world': 0.994157612323761},
 'cooking': {},
 'dancing': {},
 'exploration': {'I love popping and locking!': 0.7612733244895935,
  'one day i will see the world': 0.92877596616745,
  'i will explore sweden next semester': 0.9684410691261292}}

In [None]:
# Mode 1 for resort and display: scores <= 0.5 are kept per label and printed lowest first.
# top_many=-1 disables the display-side cut and limit_value=-1 becomes 1 in this mode.
categories_classification_resorted_result_display(categories_classification_additional_resort_function(categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=-1), candidate_possible_labels, sort_output=1, top_many=5, limit_value=0.5),sort_display=1, top_many=-1, limit_value=-1)

Category: travel
I love popping and locking!                                      : 0.0074745

Category: cooking
i will explore sweden next semester                              : 0.00020078
one day i will see the world                                     : 0.0016606
I love popping and locking!                                      : 0.17265

Category: dancing
i will explore sweden next semester                              : 0.0032394
one day i will see the world                                     : 0.0053618
I love popping and locking!                                      : 0.22574

Category: exploration



#### Additional Cleaning Up Function, Into Different Format + Display 

In [None]:
def categories_classification_additional_resort_cleaning_function(classification_resorted_dictionary_result, get_list = False, top_many_cat = 3, limit_value = 0.5):
    """Reduce each label to its best ``(sequence, score)`` pair and rank the labels by that score.

    Args:
        classification_resorted_dictionary_result: ``{label: {sequence: score}}``.
        get_list: When truthy, return a list of ``(label, (sequence, score))``
            tuples instead of a dict.
        top_many_cat: Maximum number of labels to return. ``None`` or a negative
            value returns every label that has at least one sequence.
        limit_value: Not used by this function; kept so existing calls that pass
            it keep working.

    Returns:
        dict or list: Labels ordered by descending best score, each mapped to
        (or paired with) its ``(sequence, score)`` tuple. Labels whose inner dict
        is empty are omitted.
    """
    best_per_label = {}
    for label, seq_pred_dict in classification_resorted_dictionary_result.items():
        # Labels with no surviving sequence are left out of the result.
        if not seq_pred_dict:
            continue
        # Pick the highest-scoring pair explicitly so the result does not depend
        # on the inner dict holding exactly one entry or on its ordering.
        best_per_label[label] = max(seq_pred_dict.items(), key=lambda seq_pred: seq_pred[1])

    ranked = sorted(best_per_label.items(), key=lambda label_entry: label_entry[1][1], reverse=True)
    # Only a non-negative count truncates; a negative slice bound would silently
    # drop labels from the end instead of meaning "no limit".
    if top_many_cat is not None and top_many_cat >= 0:
        ranked = ranked[:top_many_cat]

    if get_list:
        return ranked
    return dict(ranked)


def cleaned_categories_classification_resorted_result_display(cleaned_classification_resorted_result, get_list):
    """Print each label followed by its single ``sequence: score`` row.

    Args:
        cleaned_classification_resorted_result: Either a list of
            ``(label, (sequence, score))`` pairs (when ``get_list`` is truthy) or a
            mapping ``label -> (sequence, score)`` (otherwise).
        get_list: Tells the function which of the two shapes it was given.

    Returns:
        None. Output goes to stdout; the sequence is padded/truncated to 65
        characters and the score is shown with 5 significant digits.
    """
    # Both shapes reduce to an iterable of (label, (sequence, score)) pairs.
    label_entries = (
        cleaned_classification_resorted_result
        if get_list
        else cleaned_classification_resorted_result.items()
    )
    for label, seq_pred_tuple in label_entries:
        print(f"Category: {label}")
        sequence_text, score = seq_pred_tuple[0], seq_pred_tuple[1]
        print(f"{sequence_text:65.65}: {score:.5}")
        print()

##### Example Demo

In [None]:
## The upstream resort call uses sort_output=-1 and top_many=1 so that each label's dict
## holds only its single highest-scoring sequence before the cleaning step runs.

# Arguments to use for function
limit_value = 0.1
top_many_cat = 3
get_list = False

categories_classification_additional_resort_cleaning_function(classification_resorted_dictionary_result=categories_classification_additional_resort_function(categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=-1), candidate_possible_labels, sort_output=-1, top_many=1, limit_value=limit_value), get_list=get_list, top_many_cat=top_many_cat, limit_value=limit_value)

{'travel': ('one day i will see the world', 0.994157612323761),
 'exploration': ('i will explore sweden next semester', 0.9684410691261292),
 'dancing': ('I love popping and locking!', 0.22573687136173248)}

In [None]:
# Arguments to use for function
# get_list=True makes the cleaning step return a list of (label, (sequence, score)) pairs;
# the same flag is passed to the display function so it iterates that list directly.
limit_value = 0.1
top_many_cat = 2
get_list = True

cleaned_categories_classification_resorted_result_display(cleaned_classification_resorted_result=categories_classification_additional_resort_cleaning_function(classification_resorted_dictionary_result=categories_classification_additional_resort_function(categories_classification_function(bart_mnli_classifier, candidate_possible_labels, sequences_list, sort_output=-1), candidate_possible_labels, sort_output=-1, top_many=1, limit_value=limit_value), get_list=get_list, top_many_cat=top_many_cat, limit_value=limit_value), get_list=get_list)


Category: travel
one day i will see the world                                     : 0.99416

Category: exploration
i will explore sweden next semester                              : 0.96844

