In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.nn.functional import one_hot
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
# Function to compute the KL divergence
def kl_divergence(p, q):
    """Compute KL Divergence between two probability distributions."""
    return torch.sum(p * (p / q).log(), dim=-1)

In [3]:
def find_accur(sorted_labels, label_num):
    total_labels = len(sorted_labels)
    five_perc = int(total_labels * 0.05)
    granular_range = int(five_perc * 0.01)  # For the granular plot
    one_percent_range = int(total_labels * 0.01)  # For the table, 1% of total labels
    perc_accuracy = []
    table_data = []  # Data for table using 1% increments
    plot_data = []  # Data for plot using granular range

    # # Loop for the granular plot data
    # for i in range(0, five_perc, granular_range):
    #     segment = sorted_labels[i:i + granular_range]
    #     correct_predictions = sum(1 for label in segment if label == label_num)
    #     accuracy = correct_predictions / len(segment) if segment else 0
    #     perc_accuracy.append(accuracy)
    #     plot_percentile = (i / five_perc) * 5
    #     plot_data.append([plot_percentile, accuracy])

    # Loop for the table data using 1% increments
    for i in range(0, five_perc, one_percent_range):
        segment = sorted_labels[i:i + one_percent_range]
        correct_predictions = sum(1 for label in segment if label == label_num)
        accuracy = correct_predictions / len(segment) if segment else 0
        table_percentile = ((i + one_percent_range) / total_labels) * 100
        table_data.append([table_percentile, accuracy])
    
    # # Plotting using granular range data
    # plt.figure(figsize=(10, 6))
    # sns.lineplot(x=[x[0] for x in plot_data], y=perc_accuracy)
    # plt.title("Model Accuracy in the First 5%")
    # plt.xlabel("Percentile (up to 5%)")
    # plt.ylabel("Accuracy")
    # plt.show()

    # Creating and returning the DataFrame using 1% increment data
    df = pd.DataFrame(table_data, columns=["Percentile", "Accuracy"])
    return df

In [4]:
# return a prob for the masked position given the text parsed in (parse text contains mask)
def masked_logits(text):

    # Tokenize the input text
    text_tokens = [tokenizer.cls_token] + tokenizer.tokenize(text) + [tokenizer.sep_token]

    # Convert tokens to input tensor format
    input_ids = tokenizer.convert_tokens_to_ids(text_tokens)
    input_tensor = torch.tensor([input_ids]).to(device)
    
    # Get the output logits from the model
    # Gradients not needed since this is an inference not a training
    with torch.no_grad():
        outputs = model(input_tensor)
        predictions = outputs.logits

    # Extract the logits for the masked position
    masked_position = input_ids.index(tokenizer.mask_token_id)
    masked_logits = predictions[0, masked_position, :]

    # Calculate probabilities of masked_logits
    probabilities = F.softmax(masked_logits, dim=-1)

    return probabilities

In [5]:
# Function to generate the sequence of prompts
def generate_prompts(class_name, seed_word):
    text_p = ""
    for i in range(len(class_name) - 1):
        # Prepare the user prompt text
        user_prompt = f"<start_of_turn>user\nGive me a {class_name[i].lower()} seed word.<end_of_turn>\n\n"
        
        # Prepare the model response text
        model_response = f"<start_of_turn>model\n{seed_word[i]}<end_of_turn>\n\n"
        
        # Combine user prompt and model response
        text_p += user_prompt + model_response

    last_user_prompt = f"<start_of_turn>user\nGive me a {class_name[-1].lower()} seed word.<end_of_turn>\n\n" + "<start_of_turn>model\n"
    text_p = text_p + last_user_prompt
    
    return text_p

In [6]:
#GPT version
# Load dataset
dataset_test = load_dataset("ag_news", split="test") #

# Allocate GPU
device = torch.device("cuda:7")

token = "*"

# Initialize the tokenizer and model for the primary task
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2b-it', token=token) #7b
model = AutoModelForCausalLM.from_pretrained('google/gemma-2b-it', token=token).to(device) #7b

# Initialize the tokenizer and model for calculating the probability of text_q
tokenizer_1 = AutoTokenizer.from_pretrained('google/gemma-2b-it', token=token)
model_1 = AutoModelForCausalLM.from_pretrained('google/gemma-2b-it', token=token).to(device)

# Create vocab
vocab = tokenizer.get_vocab()
vocab = {vocab[key]:key for key in vocab}

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
class_name = ['Sports', 'Politics', 'Technology', 'Business']
seed_word = ['soccer', 'government', 'computer', 'trading']
type_name = ['Business', 'Sports', 'Politics', 'Technology']
oh_list = ['trading', 'soccer', 'government', 'computer']

class_label = [1, 0, 3, 2]
weights = [i for i in range(16)]

for weight in weights:
    results = [] # Initialize the results list for each weight
    for name, label, seed, type, oh in zip(class_name, class_label, seed_word, type_name, oh_list):
        print(name + ' with label ' + str(label))
    
        data = {'Class Name': [], 'Tokenized Word': [], 'Original Text': [], 'KL Divergence': [], 'Label': []}
        
        # Compute probability at the end of text_p
        # Rotate class_name and seed_word lists to start from the current index
        rotated_class_names = class_name[class_name.index(name):] + class_name[:class_name.index(name)]
        rotated_seed_words = seed_word[seed_word.index(seed):] + seed_word[:seed_word.index(seed)]
        
        # Generate prompts for the current rotation
        text_p = generate_prompts(rotated_class_names, rotated_seed_words)
        print(text_p)
        inputs_p = tokenizer(text_p, truncation=True, padding=True, return_tensors="pt").to(device)
        logits_p = model(**inputs_p).logits
        # print(logits_p)
        # print(logits_p[0, -1])
        
        # #Adding one-hot
        # print(oh)
        seed_one_hot = tokenizer(oh, truncation=True, padding=True, return_tensors="pt").to(device)
        indices_p = seed_one_hot['input_ids'][0, -1]
        # print(seed_one_hot)
        # print(indices_p)
        num_classes = 256000 #vocab size
        one_hot_p = one_hot(indices_p, num_classes=num_classes) * weight
        
        probs_p = F.softmax(logits_p[0, -1] + one_hot_p, dim=-1)
        # print(probs_p)
    
        kls = []
    
        with torch.no_grad():
            for example in tqdm(dataset_test):
                # Set up text_q
                user = f"<start_of_turn>user\nWrite me a news passage.<end_of_turn>\n\n"
                model_prompt = f"<start_of_turn>model\n"
                text_q = user + model_prompt + example['text']
        
                # Getting user + model_prompt length
                length_temp = tokenizer_1(user + model_prompt, truncation=True, padding=True, return_tensors="pt")
                # tokens = [tokenizer_1.decode([id]) for id in length_temp["input_ids"][0]]
                length = len(length_temp["input_ids"][0])
        
                inputs_q = tokenizer_1(text_q, truncation=True, padding=True, return_tensors="pt").to(device)
                logits_q = model_1(**inputs_q).logits
                logits_q = logits_q[0][length:-1]
    
                # #Adding one-hot
                indices_q = inputs_q['input_ids'][0][length+1:]
                # num_classes = 256000 #vocab size
                one_hot_q = one_hot(indices_q, num_classes=num_classes) * weight
                logits_q += one_hot_q
                
                # probs_q = probs_q[0][length:]#
                probs_q = F.softmax(logits_q, dim=-1)
                probs_q = probs_q[..., :probs_p.shape[-1]]
        
                kl_div = kl_divergence(probs_p, probs_q)
                # tokenized_word = tokenizer_1.decode(inputs_q.input_ids[0, torch.argmin(kl_div, -1).item()])
                idx = torch.argmin(kl_div)
                token_id = inputs_q["input_ids"][0][idx]
                tokenized_word = vocab[token_id.item()]
                label_name = class_name[class_label.index(example['label'])]
    
                kls.append(kl_div.amin().item()) 
    
                data['Class Name'].append(type)
                data['Tokenized Word'].append(tokenized_word)
                data['Original Text'].append(text_q)
                data['KL Divergence'].append(kls[-1])
                data['Label'].append(label_name)

        df = pd.DataFrame(data)
        sorted_df = df.sort_values(by='KL Divergence')
        top_257_df = sorted_df.head(257)
        results.append(top_257_df)
    
    # Combine all results into a single DataFrame
    final_df = pd.concat(results)
    final_df.to_csv("Gamma_agnews_test_accuracy_hotized" + str(weight) + "_pq.csv", index=False)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Sports with label 1
<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Politics with label 0
<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Technology with label 3
<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model
computer<end_of_turn>

<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

Business with label 2
<start_of_turn>user
Give me a business seed word.<end_of_turn>

<start_of_turn>model
trading<end_of_turn>

<start_of_turn>user
Give me a sports seed word.<end_of_turn>

<start_of_turn>model
soccer<end_of_turn>

<start_of_turn>user
Give me a politics seed word.<end_of_turn>

<start_of_turn>model
government<end_of_turn>

<start_of_turn>user
Give me a technology seed word.<end_of_turn>

<start_of_turn>model



  0%|          | 0/7600 [00:00<?, ?it/s]

In [8]:
# Build a per-(one-hot weight, class) accuracy summary from the per-weight
# result files and save it as a single CSV.
#
# class_name and ex_size are paired position-by-position: for each class only
# the first `size` rows of a results file whose 'Class Name' equals that class
# are scored.
# NOTE(review): the origin of the ex_size counts is not visible in this cell —
# confirm they still match the evaluation setup before re-running.
class_name = ['Sports', 'Politics', 'Technology', 'Business']
ex_size = [38, 257, 106, 36]
weights = list(range(16))
df_acc = {'One-hot weight': [], 'Class Name': [], 'Accuracy': []}

for weight in weights:
    df = pd.read_csv("Gamma_agnews_test_accuracy_hotized" + str(weight) + "_pq.csv")
    for name, size in zip(class_name, ex_size):
        # First `size` rows listed under this class name, in file order.
        temp_df = df[df['Class Name'] == name].head(size)

        # Compare as a standalone Series instead of assigning a 'Correct'
        # column into the filtered slice: writing into a slice of `df` raises
        # pandas' SettingWithCopyWarning and the column was never used again.
        is_correct = temp_df['Class Name'] == temp_df['Label']

        # Fraction of the selected rows whose 'Label' equals the class name.
        # An empty selection has no defined accuracy; record NaN explicitly
        # rather than dividing by zero.
        accuracy = is_correct.sum() / len(is_correct) if len(is_correct) else float('nan')

        df_acc['One-hot weight'].append(weight)
        df_acc['Class Name'].append(name)
        df_acc['Accuracy'].append(accuracy)

# One row per (weight, class) pair; persisted for the follow-up analysis cell.
acc_df = pd.DataFrame(df_acc)
acc_df.to_csv("Gamma_agnews_test_accuracy_summary.csv", index=False)

In [3]:
# Report every one-hot weight (1..15) whose accuracy is strictly higher than
# the weight-0 baseline for all four required classes.

# Load the per-(weight, class) accuracy summary from a CSV file
df = pd.read_csv("Gamma_agnews_test_accuracy_summary.csv")

# Rows where 'One-hot weight' equals 0 form the baseline
filtered_df = df[df['One-hot weight'] == 0]

# Baseline accuracy per class name. If a class appears on several baseline
# rows, the last row wins.
accuracy_dict = dict(zip(filtered_df['Class Name'], filtered_df['Accuracy']))

# Classes that must all improve for a weight to be reported. Kept under its
# own name (and iterated inside a generator expression below) so that the
# notebook-level `class_name` list is not rebound to a string by a loop variable.
required_classes = ['Sports', 'Politics', 'Technology', 'Business']

for i in range(1, 16):
    # Rows of the summary recorded for one-hot weight i
    hp_df = df[df['One-hot weight'] == i]

    # Weights with no recorded rows are skipped silently
    if hp_df.empty:
        continue

    # Best recorded accuracy per class for this weight
    best_per_class = hp_df.groupby('Class Name')['Accuracy'].max()

    # Require a strict improvement over the baseline for every class.
    # Written as `>` (not "not <=") on purpose: any comparison with NaN is
    # False, so a class that is missing from this weight's rows or from the
    # baseline counts as NOT improved instead of slipping through or raising
    # KeyError.
    conditions_met = all(
        best_per_class.get(required, float('nan')) > accuracy_dict.get(required, float('nan'))
        for required in required_classes
    )

    if conditions_met:
        print("Hyperparameter: " + str(i))

In [9]:
# logits = model(**tokenizer("I love basketball", truncation=True, padding=True, return_tensors="pt").to(device)).logits

In [10]:
# logits.softmax(-1).shape

In [11]:
# tokenizer_opt = AutoTokenizer.from_pretrained('facebook/opt-350m')

In [12]:
# Precompute logits for the baseline text (text_p)
# text_p = "Sports term:"
# logits_p = model(**tokenizer(text_p, truncation=True, padding=True, return_tensors="pt").to(device)).logits
# logits_p = logits_p[0, -1, :]
# probs_p = F.softmax(logits_p, dim=-1) 
# probs_p
# logits_p = masked_logits(text_p).unsqueeze(0)

In [13]:
# Compute KL divergence for every text in the dataset
# kls = []
# labels = []

# with torch.no_grad():
#     for example in tqdm(dataset_test):
#         text_q = example['text']
#         labels.append(example['label'])
#         logits_q = model(**tokenizer(text_q, truncation=True, padding=True, return_tensors="pt").to(device)).logits
#         probs_q = F.softmax(logits_q, dim=-1)
#         kls.append(kl_divergence(probs_p, probs_q).amin().item()) 

In [14]:
# kl_divergence(probs_p, probs_q).amin()

In [15]:
# logits_p.argsort(-1).flip(-1)

In [16]:
# vocab = tokenizer.get_vocab()
# vocab = {vocab[v]:v for v in vocab}
# for idx in logits_p.argsort(-1).flip(-1)[0, :20]:
#     print(vocab[idx.item()])

In [17]:
# np.argsort(kls)

In [18]:
# sorted_labels = [labels[idx] for idx in np.argsort(kls)]
# sorted_labels

In [19]:
# find_accur(1)

In [20]:
# # Precompute logits for the baseline text (World)
# text_p = "Political term:" + tokenizer.mask_token + "."
# logits_p = masked_logits(text_p).unsqueeze(0)

In [21]:
# # Compute KL divergence for every text in the dataset
# kls = []
# labels = []

# with torch.no_grad():
#     for example in tqdm(dataset_test):
#         text_q = example['text']
#         labels.append(example['label'])
#         logits_q = get_masked_logits(text_q, batch_size=512)
#         kls.append(kl_divergence(logits_p, logits_q).amin(0).item()) 

In [22]:
# sorted_labels = [labels[idx] for idx in np.argsort(kls)]
# sorted_labels

In [23]:
# find_accur(0)

In [24]:
# # Precompute logits for the baseline text (Business)
# text_p = "Business term:" + tokenizer.mask_token + "."
# logits_p = masked_logits(text_p).unsqueeze(0)

In [25]:
# # Compute KL divergence for every text in the dataset
# kls = []
# labels = []

# with torch.no_grad():
#     for example in tqdm(dataset_test):
#         text_q = example['text']
#         labels.append(example['label'])
#         logits_q = get_masked_logits(text_q, batch_size=512)
#         kls.append(kl_divergence(logits_p, logits_q).amin(0).item()) 

In [26]:
# sorted_labels = [labels[idx] for idx in np.argsort(kls)]
# sorted_labels

In [27]:
# find_accur(2)

In [28]:
# # Precompute logits for the baseline text (Sci/Tech)
# text_p = "Technology term:" + tokenizer.mask_token + "."
# logits_p = masked_logits(text_p).unsqueeze(0)

In [29]:
# # Compute KL divergence for every text in the dataset
# kls = []
# labels = []

# with torch.no_grad():
#     for example in tqdm(dataset_test):
#         text_q = example['text']
#         labels.append(example['label'])
#         logits_q = get_masked_logits(text_q, batch_size=512)
#         kls.append(kl_divergence(logits_p, logits_q).amin(0).item()) 

In [30]:
# sorted_labels = [labels[idx] for idx in np.argsort(kls)]
# sorted_labels

In [31]:
# find_accur(3)

In [32]:
# kls

In [33]:
# lengths = [len(tokenizer.tokenize(data["text"])) for data in dataset_test]

In [34]:
# lengths

In [35]:
# plt.scatter(lengths, kls)