In [None]:
import os

import pandas as pd
from train import ToxicClassifier
import torch

# Truncate long text cells in DataFrame reprs so result tables stay readable.
pd.options.display.max_colwidth = 80

In [None]:
def get_model(checkpoint_path, device):
    """Build a ToxicClassifier from a checkpoint and return it with its class names.

    Args:
        checkpoint_path: Path to a checkpoint file containing a ``"config"``
            entry; it is read with ``torch.load`` and also handed to
            ``ToxicClassifier``.
        device: Device string used as ``map_location`` when reading the
            checkpoint and passed to the ``ToxicClassifier`` constructor.

    Returns:
        A ``(model, class_names)`` tuple, where ``class_names`` comes from
        ``config["dataset"]["args"]["classes"]``.
    """
    # The checkpoint is read here only to recover the training config; the
    # classifier receives the same path and presumably loads the weights
    # itself (NOTE(review): confirm against ToxicClassifier.__init__).
    loaded_checkpoint = torch.load(checkpoint_path, map_location=device)
    config = loaded_checkpoint["config"]
    class_names = config["dataset"]["args"]["classes"]
    # Release the full checkpoint before the classifier loads it again, so
    # two copies of the weights are not held at once.
    del loaded_checkpoint

    model = ToxicClassifier(config=config, checkpoint_path=checkpoint_path, device=device)
    # This helper is used for inference only. nn.Module instances start in
    # training mode, so switch off dropout etc. when the classifier is a module.
    if isinstance(model, torch.nn.Module):
        model.eval()

    return model, class_names


In [None]:
def load_input_text(input_obj):
    """Normalise ``input_obj`` into text the classifier can consume.

    Args:
        input_obj: Either the path to an existing ``.txt`` file or a raw text
            string.

    Returns:
        A list with one string per line when ``input_obj`` names an existing
        file; otherwise ``input_obj`` itself, unchanged.

    Raises:
        ValueError: If ``input_obj`` is not a string, or names an existing
            file that does not end in ``.txt``.
    """
    if not isinstance(input_obj, str):
        raise ValueError(
            "Invalid input type: input type must be a string or a txt file.")

    if os.path.isfile(input_obj):
        if not input_obj.endswith(".txt"):
            raise ValueError("Invalid file type: only txt files supported.")
        # Use a context manager so the file handle is closed deterministically.
        with open(input_obj) as fh:
            return fh.read().splitlines()

    return input_obj

In [None]:
def run_single_input(model, class_names, input_obj):
    """Score ``input_obj`` with ``model`` and print per-class probabilities.

    Args:
        model: Callable classifier. ``model(text)`` is assumed to return raw
            logits with one row per example and one column per class
            (NOTE(review): inferred from how the output was indexed for
            single strings — confirm against ``ToxicClassifier.forward``).
        class_names: Column labels, matched positionally to the logit columns.
        input_obj: A text string, or the path to a ``.txt`` file with one
            example per line (see ``load_input_text``).

    Returns:
        A DataFrame indexed by input text with one column per class. The
        values are sigmoid probabilities rounded to 5 decimals. The frame is
        also printed.
    """
    text = load_input_text(input_obj)
    # Treat every input as a batch so single strings and file contents share
    # the same code path.
    texts = [text] if isinstance(text, str) else list(text)

    if not texts:
        # An empty .txt file has nothing to score; don't call the model on [].
        res_df = pd.DataFrame(columns=list(class_names))
        print(res_df)
        return res_df

    with torch.no_grad():
        logits = model(text)
        # Reshape to (batch, n_classes) instead of taking `[0]`, which would
        # keep only the first example of a multi-line input.
        probs = torch.sigmoid(logits).reshape(len(texts), -1).cpu().numpy()

    res_df = pd.DataFrame(
        {cla: probs[:, i] for i, cla in enumerate(class_names)},
        index=texts,
    ).round(5)
    print(res_df)

    return res_df

In [None]:
def _prompt_threshold():
    """Prompt on stdin until a number in [0, 1] is entered, then return it."""
    print("Select a classification threshold:")
    while True:
        try:
            threshold = float(input("> "))
        except ValueError:
            threshold = None
        # The chained comparison also rejects NaN.
        if threshold is not None and 0.0 <= threshold <= 1.0:
            return threshold
        print("Please enter a threshold - a number between 0 and 1")


def run_multiple(model, class_names, save_path):
    """Interactively score user-entered strings, then binarise the scores.

    Reads lines from stdin until ``quit`` (or end of input) and scores each
    one with ``run_single_input``. It then prompts for a threshold in
    ``[0, 1]`` and replaces every class probability with ``1`` if it is
    ``>= threshold``, else ``0``. Finally it prints the table and, if
    ``save_path`` is given, writes it with ``DataFrame.to_csv``.

    Args:
        model: Classifier passed through to ``run_single_input``.
        class_names: Class columns produced by ``run_single_input``; these are
            the columns that get thresholded.
        save_path: CSV destination; a falsy value skips saving.
    """
    print("Enter a new input to test:")
    print("Enter 'quit' to stop testing.")
    frames = []
    while True:
        try:
            input_string = input("> ")
        except EOFError:
            # Treat a closed stdin like 'quit' rather than crashing.
            break
        command = input_string.strip()
        if command == "--help":
            print("Enter a new string or type 'quit' to quit testing.")
            continue
        if command == "quit":
            break
        if not command:
            continue  # skip blank lines instead of scoring an empty string
        new_results = run_single_input(model, class_names, input_string)
        if not new_results.empty:
            frames.append(new_results)

    if not frames:
        print("No inputs were tested.")
        return

    # Concatenate once at the end; concat inside the loop is quadratic.
    results = pd.concat(frames)
    threshold = _prompt_threshold()

    # Threshold the checkpoint's own classes. A hard-coded label list raises
    # KeyError for models trained on different labels.
    columns = list(class_names)
    results[columns] = (results[columns] >= threshold).astype(int)

    print("All tests run:")
    print(results)
    if save_path:
        results.to_csv(save_path)

In [None]:
# Run to evaluate. The checkpoint and the saved results both live in this
# directory, so results cannot end up next to a different model's checkpoint.
RUN_DIR = '/vol/bitbucket/es1519/detecting-hidden-purpose-in-nlp-models/detoxify/saved/ALBERT-Topic-10/lightning_logs/blank-100-1/checkpoints/converted'
DEVICE = 'cpu'

ckpt_path = os.path.join(RUN_DIR, 'epoch=2.ckpt')
# run_multiple writes this with DataFrame.to_csv, so the content is CSV even
# though the file name ends in .txt.
save_to = os.path.join(RUN_DIR, 'manual.txt')
model, class_names = get_model(ckpt_path, DEVICE)

In [None]:
# Start the interactive scoring session for the loaded checkpoint.
run_multiple(model=model, class_names=class_names, save_path=save_to)