In [1]:
import os

import evaluate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.nn.functional import one_hot
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding

In [2]:
# Show the path of the Python interpreter that is running this kernel.
import sys
print(sys.executable)

/data/xiangjun/conda3/envs/env/bin/python


In [6]:
#GPT version
# Load the DBpedia-14 test split that every later cell scores against.
dataset_test = load_dataset("dbpedia_14", split="test")

# Allocate GPU
device = torch.device("cuda:2")

# SECURITY: never hard-code an access token in a notebook. The token is read
# from the HF_TOKEN environment variable instead; when it is unset, `None` is
# passed and the libraries fall back to the locally cached login (if any).
# The token that used to be committed here must be treated as leaked and revoked.
token = os.environ.get("HF_TOKEN")

# Single place to switch checkpoints (e.g. to the 7b variant).
MODEL_NAME = 'google/gemma-2b-it'

# Initialize the tokenizer and model for the primary task
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token) #7b
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=token).to(device) #7b

# Initialize the tokenizer and model for calculating the probability of text_q
# tokenizer_1 = AutoTokenizer.from_pretrained('google/gemma-2b-it', token=token)
# model_1 = AutoModelForCausalLM.from_pretrained('google/gemma-2b-it', token=token).to(device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# NOTE(review): the recorded output of this cell ends in
# KeyError: "Key 'educationalinstitution' not present in vocabulary".
# Elsewhere in this notebook `model` is assigned from AutoModelForCausalLM;
# `most_similar` presumably belongs to a different (word-embedding) object that
# was bound to `model` in an earlier kernel session -- TODO confirm, then either
# restore that object under its own name or delete this cell.
class_name = ['Company', 'EducationalInstitution', 'Artist', 'Athlete', 'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork']

# Print the first element of the top `most_similar` result for each class name.
for c_name in class_name:
    print(model.most_similar(c_name)[0][0])

sub-company


KeyError: "Key 'educationalinstitution' not present in vocabulary"

In [2]:
# # https://huggingface.co/docs/transformers/tasks/sequence_classification
# imdb = load_dataset("imdb")

# imdb["test"][0]
# tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")

# def preprocess_function(examples):
#     return tokenizer(examples["text"], truncation=True)

# tokenized_imdb = imdb.map(preprocess_function, batched=True)
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# accuracy = evaluate.load("accuracy")

# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     predictions = np.argmax(predictions, axis=1)
#     return accuracy.compute(predictions=predictions, references=labels)

In [3]:
# # https://huggingface.co/docs/transformers/tasks/sequence_classification
# id2label = {0: "NEGATIVE", 1: "POSITIVE"}
# label2id = {"NEGATIVE": 0, "POSITIVE": 1}

In [4]:
# Function to compute the KL divergence
def kl_divergence(p, q):
    """Compute the KL divergence KL(p || q) along the last dimension.

    Args:
        p: tensor of probabilities; broadcastable against ``q``.
        q: tensor of probabilities; broadcastable against ``p``.

    Returns:
        Tensor with the last dimension reduced, holding
        ``sum(p * log(p / q))`` for every leading index.

    Entries where ``p == 0`` contribute exactly 0 (the usual
    ``0 * log 0 = 0`` convention). Without that, a single zero in ``p``
    yields ``0 * -inf = nan`` and turns the whole sum into NaN.
    Entries where ``p > 0`` and ``q == 0`` still give ``inf``.
    """
    terms = p * (p / q).log()
    # Mask out the undefined 0 * log(0 / q) terms instead of letting NaN propagate.
    terms = torch.where(p > 0, terms, torch.zeros_like(terms))
    return torch.sum(terms, dim=-1)

In [5]:
def find_accur(sorted_labels, label_num):
    """Tabulate accuracy over the first 5% of ``sorted_labels`` in 1% buckets.

    Args:
        sorted_labels: sequence of labels, already ordered by the caller
            (anything supporting ``len`` and slicing, e.g. a list).
        label_num: the label value that counts as a correct prediction.

    Returns:
        ``pandas.DataFrame`` with columns ``"Percentile"`` (upper edge of the
        bucket, in percent of all labels) and ``"Accuracy"`` (fraction of
        labels in the bucket equal to ``label_num``).

    With fewer than 100 labels a 1% bucket would contain zero items, which
    used to make ``range()`` raise ``ValueError`` (step of 0). The bucket
    size is therefore clamped to at least one label; an empty input yields
    an empty table.
    """
    total_labels = len(sorted_labels)
    five_perc = int(total_labels * 0.05)
    # 1% of the labels, but never 0 so the range step stays valid.
    one_percent_range = max(1, int(total_labels * 0.01))

    table_data = []
    for start in range(0, five_perc, one_percent_range):
        segment = sorted_labels[start:start + one_percent_range]
        correct_predictions = sum(1 for label in segment if label == label_num)
        # len() instead of truthiness: array-like segments have no unambiguous bool().
        accuracy = correct_predictions / len(segment) if len(segment) else 0
        table_percentile = ((start + one_percent_range) / total_labels) * 100
        table_data.append([table_percentile, accuracy])

    return pd.DataFrame(table_data, columns=["Percentile", "Accuracy"])

In [6]:
# return a prob for the masked position given the text parsed in (parse text contains mask)
def masked_logits(text):
    """Return the vocabulary distribution predicted at the first mask token of ``text``.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: input string that must contain the tokenizer's mask token.

    Returns:
        1-D tensor of softmax probabilities over the vocabulary for the
        first masked position.

    Raises:
        ValueError: if the tokenizer defines no mask token, or ``text``
            contains none. Both are checked before the forward pass so no
            model call is wasted on unusable input.
    """
    mask_id = tokenizer.mask_token_id
    if mask_id is None:
        # NOTE(review): a tokenizer without a mask token (presumably the case for
        # causal-LM tokenizers -- confirm) cannot be used with this helper.
        raise ValueError("The current tokenizer does not define a mask token.")

    # Tokenize the input text and wrap it in the classifier/separator tokens
    text_tokens = [tokenizer.cls_token] + tokenizer.tokenize(text) + [tokenizer.sep_token]

    # Convert tokens to ids and locate the first mask
    input_ids = tokenizer.convert_tokens_to_ids(text_tokens)
    if mask_id not in input_ids:
        raise ValueError("The input text does not contain the tokenizer's mask token.")
    masked_position = input_ids.index(mask_id)

    input_tensor = torch.tensor([input_ids]).to(device)

    # Inference only, so gradients are not needed
    with torch.no_grad():
        predictions = model(input_tensor).logits

    # Logits at the masked position -> probabilities
    # (local renamed so it no longer shadows the function itself)
    position_logits = predictions[0, masked_position, :]
    return F.softmax(position_logits, dim=-1)

In [7]:
# Function to generate the sequence of prompts
def generate_prompts(class_name, seed_word):
    """Build a few-shot chat transcript that ends on an open model turn.

    Every class except the last one becomes a completed exchange: a user
    turn asking for a seed word of that (lower-cased) class, followed by a
    model turn answering with the seed word at the same index. The last
    class is only asked for; the returned text stops right after the
    opening of the model turn.

    Args:
        class_name: ordered class names; the final entry is the one queried.
        seed_word: seed words aligned with ``class_name`` (the final entry
            is not used).

    Returns:
        The concatenated prompt string.
    """
    def ask(cls):
        # User turn requesting a seed word for one class.
        return f"<start_of_turn>user\nGive me a {cls.lower()} seed word.<end_of_turn>\n\n"

    solved_turns = []
    for idx in range(len(class_name) - 1):
        reply = f"<start_of_turn>model\n{seed_word[idx]}<end_of_turn>\n\n"
        solved_turns.append(ask(class_name[idx]) + reply)

    open_turn = ask(class_name[-1]) + "<start_of_turn>model\n"
    return "".join(solved_turns) + open_turn

In [8]:
# Create vocab
# id -> token-string lookup (inverse of the tokenizer's token -> id mapping)
vocab = {token_id: token for token, token_id in tokenizer.get_vocab().items()}

class_name = ['Company', 'EducationalInstitution', 'Artist', 'Athlete', 'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork']
seed_word = ['corporation', 'university', 'painter', 'runner', 'president', 'automobile', 'skyscraper', 'mountain', 'village', 'mammal', 'flower', 'record', 'movie', 'novel']
# type_name[i] is the class actually queried by the i-th rotated prompt: the
# rotation starts at class_name[i], so its last (open) turn asks for class_name[i - 1].
type_name = [class_name[-1]] + class_name[:-1]
class_label = list(range(14))
results = []

# Number of lowest-KL rows kept per queried class in the output CSV.
TOP_K = 5427
# Value added to the logit of the token that actually follows each position.
ONE_HOT_BONUS = 1
# Fixed chat prefix placed in front of every dataset text.
PREFIX_Q = "<start_of_turn>user\nWrite me a description.<end_of_turn>\n\n" + "<start_of_turn>model\n"
COLUMNS = ['Class Name', 'Tokenized Word', 'Original Text', 'KL Divergence', 'Label']

# --- Step 1: next-token distribution at the end of each rotated few-shot prompt ---
# Done under no_grad (the original ran this forward pass with autograd enabled,
# which builds a graph that is never used).
prompt_probs = []  # list of (queried class name, probs_p)
with torch.no_grad():
    for start, (name, label) in enumerate(zip(class_name, class_label)):
        print(name + ' with label ' + str(label))

        # Rotate class_name and seed_word lists to start from the current index
        rotated_class_names = class_name[start:] + class_name[:start]
        rotated_seed_words = seed_word[start:] + seed_word[:start]

        # Generate prompts for the current rotation
        text_p = generate_prompts(rotated_class_names, rotated_seed_words)
        print(text_p)
        # truncation/padding flags dropped: for a single sequence with no max
        # length they do nothing except emit a warning.
        inputs_p = tokenizer(text_p, return_tensors="pt").to(device)
        probs_p = F.softmax(model(**inputs_p).logits[0, -1], dim=-1)
        prompt_probs.append((type_name[start], probs_p))

# --- Step 2: one pass over the dataset, scoring every text against all prompts ---
# The model output for a text does not depend on the class, so it is computed
# once per example instead of once per (class, example) pair.
prefix_len = len(tokenizer(PREFIX_Q)["input_ids"])  # loop-invariant, tokenised once
records = {queried: {column: [] for column in COLUMNS} for queried, _ in prompt_probs}

with torch.no_grad():
    for example in tqdm(dataset_test):
        text_q = PREFIX_Q + example['content']
        inputs_q = tokenizer(text_q, return_tensors="pt").to(device)

        # Distributions at the content positions (prefix removed); the last
        # position is dropped because no observed token follows it.
        logits_q = model(**inputs_q).logits[0, prefix_len:-1]
        # Token observed after each kept position (row t predicts next_tokens[t]).
        next_tokens = inputs_q["input_ids"][0][prefix_len + 1:]
        if next_tokens.numel() == 0:
            # Fewer than two content tokens: nothing to score.
            continue

        # Same effect as adding ONE_HOT_BONUS * one_hot(next_tokens), without
        # materialising a dense (positions x vocab) int64 tensor and without a
        # hard-coded vocabulary size.
        rows = torch.arange(next_tokens.numel(), device=logits_q.device)
        logits_q[rows, next_tokens] += ONE_HOT_BONUS
        probs_q = F.softmax(logits_q, dim=-1)

        label_name = class_name[example['label']]

        for queried, probs_p in prompt_probs:
            kl_div = kl_divergence(probs_p, probs_q)
            best_kl, best_pos = kl_div.min(dim=0)
            # best_pos is relative to the sliced positions, so it must index
            # next_tokens (the token whose logit was boosted in that row), not
            # the unsliced input ids, which would return a prompt-prefix token.
            token_id = next_tokens[best_pos].item()

            data = records[queried]
            data['Class Name'].append(queried)
            data['Tokenized Word'].append(vocab[token_id])
            data['Original Text'].append(text_q)
            data['KL Divergence'].append(best_kl.item())
            data['Label'].append(label_name)

# --- Step 3: keep the TOP_K lowest-KL rows per queried class ---
for queried, _ in prompt_probs:
    df = pd.DataFrame(records[queried])
    sorted_df = df.sort_values(by='KL Divergence')
    results.append(sorted_df.head(TOP_K))

# Combine all results into a single DataFrame
final_df = pd.concat(results)
final_df.to_csv("Gamma_dbpedia_test_accuracy.csv", index=False)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Company with label 0
<start_of_turn>user
Give me a company seed word.<end_of_turn>

<start_of_turn>model
corporation<end_of_turn>

<start_of_turn>user
Give me a educationalinstitution seed word.<end_of_turn>

<start_of_turn>model
university<end_of_turn>

<start_of_turn>user
Give me a artist seed word.<end_of_turn>

<start_of_turn>model
painter<end_of_turn>

<start_of_turn>user
Give me a athlete seed word.<end_of_turn>

<start_of_turn>model
runner<end_of_turn>

<start_of_turn>user
Give me a officeholder seed word.<end_of_turn>

<start_of_turn>model
president<end_of_turn>

<start_of_turn>user
Give me a meanoftransportation seed word.<end_of_turn>

<start_of_turn>model
automobile<end_of_turn>

<start_of_turn>user
Give me a building seed word.<end_of_turn>

<start_of_turn>model
skyscraper<end_of_turn>

<start_of_turn>user
Give me a naturalplace seed word.<end_of_turn>

<start_of_turn>model
mountain<end_of_turn>

<start_of_turn>user
Give me a village seed word.<end_of_turn>

<start_of_turn>

  0%|          | 0/70000 [00:00<?, ?it/s]

EducationalInstitution with label 1
<start_of_turn>user
Give me a educationalinstitution seed word.<end_of_turn>

<start_of_turn>model
university<end_of_turn>

<start_of_turn>user
Give me a artist seed word.<end_of_turn>

<start_of_turn>model
painter<end_of_turn>

<start_of_turn>user
Give me a athlete seed word.<end_of_turn>

<start_of_turn>model
runner<end_of_turn>

<start_of_turn>user
Give me a officeholder seed word.<end_of_turn>

<start_of_turn>model
president<end_of_turn>

<start_of_turn>user
Give me a meanoftransportation seed word.<end_of_turn>

<start_of_turn>model
automobile<end_of_turn>

<start_of_turn>user
Give me a building seed word.<end_of_turn>

<start_of_turn>model
skyscraper<end_of_turn>

<start_of_turn>user
Give me a naturalplace seed word.<end_of_turn>

<start_of_turn>model
mountain<end_of_turn>

<start_of_turn>user
Give me a village seed word.<end_of_turn>

<start_of_turn>model
village<end_of_turn>

<start_of_turn>user
Give me a animal seed word.<end_of_turn>

<star

  0%|          | 0/70000 [00:00<?, ?it/s]

Artist with label 2
<start_of_turn>user
Give me a artist seed word.<end_of_turn>

<start_of_turn>model
painter<end_of_turn>

<start_of_turn>user
Give me a athlete seed word.<end_of_turn>

<start_of_turn>model
runner<end_of_turn>

<start_of_turn>user
Give me a officeholder seed word.<end_of_turn>

<start_of_turn>model
president<end_of_turn>

<start_of_turn>user
Give me a meanoftransportation seed word.<end_of_turn>

<start_of_turn>model
automobile<end_of_turn>

<start_of_turn>user
Give me a building seed word.<end_of_turn>

<start_of_turn>model
skyscraper<end_of_turn>

<start_of_turn>user
Give me a naturalplace seed word.<end_of_turn>

<start_of_turn>model
mountain<end_of_turn>

<start_of_turn>user
Give me a village seed word.<end_of_turn>

<start_of_turn>model
village<end_of_turn>

<start_of_turn>user
Give me a animal seed word.<end_of_turn>

<start_of_turn>model
mammal<end_of_turn>

<start_of_turn>user
Give me a plant seed word.<end_of_turn>

<start_of_turn>model
flower<end_of_turn>



  0%|          | 0/70000 [00:00<?, ?it/s]

Athlete with label 3
<start_of_turn>user
Give me a athlete seed word.<end_of_turn>

<start_of_turn>model
runner<end_of_turn>

<start_of_turn>user
Give me a officeholder seed word.<end_of_turn>

<start_of_turn>model
president<end_of_turn>

<start_of_turn>user
Give me a meanoftransportation seed word.<end_of_turn>

<start_of_turn>model
automobile<end_of_turn>

<start_of_turn>user
Give me a building seed word.<end_of_turn>

<start_of_turn>model
skyscraper<end_of_turn>

<start_of_turn>user
Give me a naturalplace seed word.<end_of_turn>

<start_of_turn>model
mountain<end_of_turn>

<start_of_turn>user
Give me a village seed word.<end_of_turn>

<start_of_turn>model
village<end_of_turn>

<start_of_turn>user
Give me a animal seed word.<end_of_turn>

<start_of_turn>model
mammal<end_of_turn>

<start_of_turn>user
Give me a plant seed word.<end_of_turn>

<start_of_turn>model
flower<end_of_turn>

<start_of_turn>user
Give me a album seed word.<end_of_turn>

<start_of_turn>model
record<end_of_turn>

<

  0%|          | 0/70000 [00:00<?, ?it/s]

OfficeHolder with label 4
<start_of_turn>user
Give me a officeholder seed word.<end_of_turn>

<start_of_turn>model
president<end_of_turn>

<start_of_turn>user
Give me a meanoftransportation seed word.<end_of_turn>

<start_of_turn>model
automobile<end_of_turn>

<start_of_turn>user
Give me a building seed word.<end_of_turn>

<start_of_turn>model
skyscraper<end_of_turn>

<start_of_turn>user
Give me a naturalplace seed word.<end_of_turn>

<start_of_turn>model
mountain<end_of_turn>

<start_of_turn>user
Give me a village seed word.<end_of_turn>

<start_of_turn>model
village<end_of_turn>

<start_of_turn>user
Give me a animal seed word.<end_of_turn>

<start_of_turn>model
mammal<end_of_turn>

<start_of_turn>user
Give me a plant seed word.<end_of_turn>

<start_of_turn>model
flower<end_of_turn>

<start_of_turn>user
Give me a album seed word.<end_of_turn>

<start_of_turn>model
record<end_of_turn>

<start_of_turn>user
Give me a film seed word.<end_of_turn>

<start_of_turn>model
movie<end_of_turn>



  0%|          | 0/70000 [00:00<?, ?it/s]

MeanOfTransportation with label 5
<start_of_turn>user
Give me a meanoftransportation seed word.<end_of_turn>

<start_of_turn>model
automobile<end_of_turn>

<start_of_turn>user
Give me a building seed word.<end_of_turn>

<start_of_turn>model
skyscraper<end_of_turn>

<start_of_turn>user
Give me a naturalplace seed word.<end_of_turn>

<start_of_turn>model
mountain<end_of_turn>

<start_of_turn>user
Give me a village seed word.<end_of_turn>

<start_of_turn>model
village<end_of_turn>

<start_of_turn>user
Give me a animal seed word.<end_of_turn>

<start_of_turn>model
mammal<end_of_turn>

<start_of_turn>user
Give me a plant seed word.<end_of_turn>

<start_of_turn>model
flower<end_of_turn>

<start_of_turn>user
Give me a album seed word.<end_of_turn>

<start_of_turn>model
record<end_of_turn>

<start_of_turn>user
Give me a film seed word.<end_of_turn>

<start_of_turn>model
movie<end_of_turn>

<start_of_turn>user
Give me a writtenwork seed word.<end_of_turn>

<start_of_turn>model
novel<end_of_turn

  0%|          | 0/70000 [00:00<?, ?it/s]

Building with label 6
<start_of_turn>user
Give me a building seed word.<end_of_turn>

<start_of_turn>model
skyscraper<end_of_turn>

<start_of_turn>user
Give me a naturalplace seed word.<end_of_turn>

<start_of_turn>model
mountain<end_of_turn>

<start_of_turn>user
Give me a village seed word.<end_of_turn>

<start_of_turn>model
village<end_of_turn>

<start_of_turn>user
Give me a animal seed word.<end_of_turn>

<start_of_turn>model
mammal<end_of_turn>

<start_of_turn>user
Give me a plant seed word.<end_of_turn>

<start_of_turn>model
flower<end_of_turn>

<start_of_turn>user
Give me a album seed word.<end_of_turn>

<start_of_turn>model
record<end_of_turn>

<start_of_turn>user
Give me a film seed word.<end_of_turn>

<start_of_turn>model
movie<end_of_turn>

<start_of_turn>user
Give me a writtenwork seed word.<end_of_turn>

<start_of_turn>model
novel<end_of_turn>

<start_of_turn>user
Give me a company seed word.<end_of_turn>

<start_of_turn>model
corporation<end_of_turn>

<start_of_turn>user
G

  0%|          | 0/70000 [00:00<?, ?it/s]

NaturalPlace with label 7
<start_of_turn>user
Give me a naturalplace seed word.<end_of_turn>

<start_of_turn>model
mountain<end_of_turn>

<start_of_turn>user
Give me a village seed word.<end_of_turn>

<start_of_turn>model
village<end_of_turn>

<start_of_turn>user
Give me a animal seed word.<end_of_turn>

<start_of_turn>model
mammal<end_of_turn>

<start_of_turn>user
Give me a plant seed word.<end_of_turn>

<start_of_turn>model
flower<end_of_turn>

<start_of_turn>user
Give me a album seed word.<end_of_turn>

<start_of_turn>model
record<end_of_turn>

<start_of_turn>user
Give me a film seed word.<end_of_turn>

<start_of_turn>model
movie<end_of_turn>

<start_of_turn>user
Give me a writtenwork seed word.<end_of_turn>

<start_of_turn>model
novel<end_of_turn>

<start_of_turn>user
Give me a company seed word.<end_of_turn>

<start_of_turn>model
corporation<end_of_turn>

<start_of_turn>user
Give me a educationalinstitution seed word.<end_of_turn>

<start_of_turn>model
university<end_of_turn>

<st

  0%|          | 0/70000 [00:00<?, ?it/s]

Village with label 8
<start_of_turn>user
Give me a village seed word.<end_of_turn>

<start_of_turn>model
village<end_of_turn>

<start_of_turn>user
Give me a animal seed word.<end_of_turn>

<start_of_turn>model
mammal<end_of_turn>

<start_of_turn>user
Give me a plant seed word.<end_of_turn>

<start_of_turn>model
flower<end_of_turn>

<start_of_turn>user
Give me a album seed word.<end_of_turn>

<start_of_turn>model
record<end_of_turn>

<start_of_turn>user
Give me a film seed word.<end_of_turn>

<start_of_turn>model
movie<end_of_turn>

<start_of_turn>user
Give me a writtenwork seed word.<end_of_turn>

<start_of_turn>model
novel<end_of_turn>

<start_of_turn>user
Give me a company seed word.<end_of_turn>

<start_of_turn>model
corporation<end_of_turn>

<start_of_turn>user
Give me a educationalinstitution seed word.<end_of_turn>

<start_of_turn>model
university<end_of_turn>

<start_of_turn>user
Give me a artist seed word.<end_of_turn>

<start_of_turn>model
painter<end_of_turn>

<start_of_turn>

  0%|          | 0/70000 [00:00<?, ?it/s]

Animal with label 9
<start_of_turn>user
Give me a animal seed word.<end_of_turn>

<start_of_turn>model
mammal<end_of_turn>

<start_of_turn>user
Give me a plant seed word.<end_of_turn>

<start_of_turn>model
flower<end_of_turn>

<start_of_turn>user
Give me a album seed word.<end_of_turn>

<start_of_turn>model
record<end_of_turn>

<start_of_turn>user
Give me a film seed word.<end_of_turn>

<start_of_turn>model
movie<end_of_turn>

<start_of_turn>user
Give me a writtenwork seed word.<end_of_turn>

<start_of_turn>model
novel<end_of_turn>

<start_of_turn>user
Give me a company seed word.<end_of_turn>

<start_of_turn>model
corporation<end_of_turn>

<start_of_turn>user
Give me a educationalinstitution seed word.<end_of_turn>

<start_of_turn>model
university<end_of_turn>

<start_of_turn>user
Give me a artist seed word.<end_of_turn>

<start_of_turn>model
painter<end_of_turn>

<start_of_turn>user
Give me a athlete seed word.<end_of_turn>

<start_of_turn>model
runner<end_of_turn>

<start_of_turn>us

  0%|          | 0/70000 [00:00<?, ?it/s]

Plant with label 10
<start_of_turn>user
Give me a plant seed word.<end_of_turn>

<start_of_turn>model
flower<end_of_turn>

<start_of_turn>user
Give me a album seed word.<end_of_turn>

<start_of_turn>model
record<end_of_turn>

<start_of_turn>user
Give me a film seed word.<end_of_turn>

<start_of_turn>model
movie<end_of_turn>

<start_of_turn>user
Give me a writtenwork seed word.<end_of_turn>

<start_of_turn>model
novel<end_of_turn>

<start_of_turn>user
Give me a company seed word.<end_of_turn>

<start_of_turn>model
corporation<end_of_turn>

<start_of_turn>user
Give me a educationalinstitution seed word.<end_of_turn>

<start_of_turn>model
university<end_of_turn>

<start_of_turn>user
Give me a artist seed word.<end_of_turn>

<start_of_turn>model
painter<end_of_turn>

<start_of_turn>user
Give me a athlete seed word.<end_of_turn>

<start_of_turn>model
runner<end_of_turn>

<start_of_turn>user
Give me a officeholder seed word.<end_of_turn>

<start_of_turn>model
president<end_of_turn>

<start_o

  0%|          | 0/70000 [00:00<?, ?it/s]

Album with label 11
<start_of_turn>user
Give me a album seed word.<end_of_turn>

<start_of_turn>model
record<end_of_turn>

<start_of_turn>user
Give me a film seed word.<end_of_turn>

<start_of_turn>model
movie<end_of_turn>

<start_of_turn>user
Give me a writtenwork seed word.<end_of_turn>

<start_of_turn>model
novel<end_of_turn>

<start_of_turn>user
Give me a company seed word.<end_of_turn>

<start_of_turn>model
corporation<end_of_turn>

<start_of_turn>user
Give me a educationalinstitution seed word.<end_of_turn>

<start_of_turn>model
university<end_of_turn>

<start_of_turn>user
Give me a artist seed word.<end_of_turn>

<start_of_turn>model
painter<end_of_turn>

<start_of_turn>user
Give me a athlete seed word.<end_of_turn>

<start_of_turn>model
runner<end_of_turn>

<start_of_turn>user
Give me a officeholder seed word.<end_of_turn>

<start_of_turn>model
president<end_of_turn>

<start_of_turn>user
Give me a meanoftransportation seed word.<end_of_turn>

<start_of_turn>model
automobile<end

  0%|          | 0/70000 [00:00<?, ?it/s]

Film with label 12
<start_of_turn>user
Give me a film seed word.<end_of_turn>

<start_of_turn>model
movie<end_of_turn>

<start_of_turn>user
Give me a writtenwork seed word.<end_of_turn>

<start_of_turn>model
novel<end_of_turn>

<start_of_turn>user
Give me a company seed word.<end_of_turn>

<start_of_turn>model
corporation<end_of_turn>

<start_of_turn>user
Give me a educationalinstitution seed word.<end_of_turn>

<start_of_turn>model
university<end_of_turn>

<start_of_turn>user
Give me a artist seed word.<end_of_turn>

<start_of_turn>model
painter<end_of_turn>

<start_of_turn>user
Give me a athlete seed word.<end_of_turn>

<start_of_turn>model
runner<end_of_turn>

<start_of_turn>user
Give me a officeholder seed word.<end_of_turn>

<start_of_turn>model
president<end_of_turn>

<start_of_turn>user
Give me a meanoftransportation seed word.<end_of_turn>

<start_of_turn>model
automobile<end_of_turn>

<start_of_turn>user
Give me a building seed word.<end_of_turn>

<start_of_turn>model
skyscrap

  0%|          | 0/70000 [00:00<?, ?it/s]

WrittenWork with label 13
<start_of_turn>user
Give me a writtenwork seed word.<end_of_turn>

<start_of_turn>model
novel<end_of_turn>

<start_of_turn>user
Give me a company seed word.<end_of_turn>

<start_of_turn>model
corporation<end_of_turn>

<start_of_turn>user
Give me a educationalinstitution seed word.<end_of_turn>

<start_of_turn>model
university<end_of_turn>

<start_of_turn>user
Give me a artist seed word.<end_of_turn>

<start_of_turn>model
painter<end_of_turn>

<start_of_turn>user
Give me a athlete seed word.<end_of_turn>

<start_of_turn>model
runner<end_of_turn>

<start_of_turn>user
Give me a officeholder seed word.<end_of_turn>

<start_of_turn>model
president<end_of_turn>

<start_of_turn>user
Give me a meanoftransportation seed word.<end_of_turn>

<start_of_turn>model
automobile<end_of_turn>

<start_of_turn>user
Give me a building seed word.<end_of_turn>

<start_of_turn>model
skyscraper<end_of_turn>

<start_of_turn>user
Give me a naturalplace seed word.<end_of_turn>

<start_of

  0%|          | 0/70000 [00:00<?, ?it/s]

In [8]:
# Debug cell: walk one test example through the text_q pipeline and print the
# intermediate tensors and their shapes.

# Load dataset
dataset_test = load_dataset("dbpedia_14", split="test")

# Allocate GPU
device = torch.device("cuda:2")

# SECURITY: the access token is read from the HF_TOKEN environment variable
# instead of being hard-coded; `None` falls back to the cached login (if any).
# The token that used to be committed here must be treated as leaked and revoked.
token = os.environ.get("HF_TOKEN")

# Initialize the tokenizer and model for the primary task
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2b-it', token=token) #7b
model = AutoModelForCausalLM.from_pretrained('google/gemma-2b-it', token=token).to(device) #7b

with torch.no_grad():
    example = dataset_test[0]
    # Set up text_q
    user = "<start_of_turn>user\nWrite me a description.<end_of_turn>\n\n"
    model_prompt = "<start_of_turn>model\n"
    text_q = user + model_prompt + example['content']

    # Number of tokens taken by the fixed user + model_prompt prefix
    length = len(tokenizer(user + model_prompt)["input_ids"])

    # truncation/padding flags dropped: for a single sequence with no max
    # length they do nothing except emit a warning.
    inputs_q = tokenizer(text_q, return_tensors="pt").to(device)
    # Keep only the content positions; the last one has no observed next token.
    logits_q = model(**inputs_q).logits[0][length:-1]
    print(logits_q)
    print(logits_q.shape)
    print('-'*100)

    # Token that actually follows each kept position
    indices = inputs_q['input_ids'][0][length+1:]
    print(indices)
    print(indices.shape)
    print('-'*100)

    # Adding one-hot: +1 on the logit of the observed next token.
    # The width is taken from the logits rather than a hard-coded vocabulary size.
    num_classes = logits_q.shape[-1]
    one_hot_encoded = one_hot(indices, num_classes=num_classes)
    logits_q += one_hot_encoded
    print(logits_q)
    print(logits_q.shape)

    probs_q = F.softmax(logits_q, dim=-1)
    # probs_q = probs_q[..., :probs_p.shape[-1]]
    print(probs_q.shape)
    

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


tensor([[-22.7077,   1.2150, -12.2372,  ..., -12.1486, -14.1577, -22.5342],
        [-26.1910,  -6.2329, -13.1135,  ..., -18.2410, -14.0988, -25.5583],
        [-21.1609, -11.5136, -22.6313,  ..., -24.3752, -20.4762, -20.3290],
        ...,
        [-32.5805,  -9.5264, -25.6995,  ..., -26.3809, -24.0654, -31.9149],
        [-36.1663,  -9.8588, -15.9377,  ..., -28.2943, -29.2979, -35.3154],
        [-38.8389,  -8.7317, -25.6884,  ..., -30.9254, -33.7016, -37.7778]],
       device='cuda:2')
torch.Size([89, 256000])
----------------------------------------------------------------------------------------------------
tensor([ 82724,   1148,    516, 239519,   2719, 240342, 235283,    603,    671,
          3725,  52986,  51877,   3277,    674,  82044,    575,  22324,    578,
          1156,  27274, 235265,    714,  52522, 235290,  22783,   3277,    729,
         18942,    575, 235248, 235284, 235276, 235276, 235310,    578,    603,
        132447,    575,   1622,   3459,   3922,   1622,   34

In [10]:
# class_name = ['Company', 'EducationalInstitution', 'Artist', 'Athlete', 'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork']
# seed_word = ['Company', 'Educational Institution', 'Artist', 'Athlete', 'Office Holder', 'Mean Of Transportation', 'Building', 'Natural Place', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'Written Work']

# for name, seed in zip(class_name, seed_word):
#     rotated_class_names = class_name[class_name.index(name):] + class_name[:class_name.index(name)]
#     rotated_seed_words = seed_word[seed_word.index(seed):] + seed_word[:seed_word.index(seed)]
    
#     # Generate prompts for the current rotation
#     text_p = generate_prompts(rotated_class_names, rotated_seed_words)
#     print(text_p)
#     print('-'*100)

# writtenwork, Company, EducationalInstitution

In [11]:
class_name = ['Company', 'EducationalInstitution', 'Artist', 'Athlete', 'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork']
# Number of leading rows evaluated per class.
# NOTE(review): the provenance of these sizes is not visible in this notebook -- TODO document.
ex_size = [134, 763, 353, 54, 275, 120, 144, 1240, 5427, 24, 121, 915, 552, 2101]
df = pd.read_csv('Gamma_dbpedia_test_accuracy.csv')

print("Accuracy for icl:")
for name, size in zip(class_name, ex_size):
    # First `size` rows stored for this class, in file order.
    temp_df = df[df['Class Name'] == name].head(size)
    # .assign returns a new frame, so no column is written onto a filtered
    # slice of `df` (which triggers pandas' SettingWithCopyWarning).
    temp_df = temp_df.assign(Correct=temp_df['Class Name'] == temp_df['Label'])
    # A class with no rows has no defined accuracy; report NaN instead of dividing by zero.
    accuracy = temp_df['Correct'].sum() / len(temp_df) if len(temp_df) else float('nan')
    print(len(temp_df))
    print('Accuracy for class ' + name + ': ' + str(accuracy))
    print('-'*100)
# business_df

Accuracy for icl:
134
Accuracy for class Company: 0.5970149253731343
----------------------------------------------------------------------------------------------------
763
Accuracy for class EducationalInstitution: 0.6671035386631717
----------------------------------------------------------------------------------------------------
353
Accuracy for class Artist: 0.2804532577903683
----------------------------------------------------------------------------------------------------
54
Accuracy for class Athlete: 0.6851851851851852
----------------------------------------------------------------------------------------------------
275
Accuracy for class OfficeHolder: 0.45454545454545453
----------------------------------------------------------------------------------------------------
120
Accuracy for class MeanOfTransportation: 0.6833333333333333
----------------------------------------------------------------------------------------------------
144
Accuracy for class Building: 0.083

In [12]:
import nltk
from nltk.tokenize import word_tokenize

dataset_test = load_dataset("dbpedia_14", split="test")

# Keyword baseline: a text is assigned to a class when it contains one of the
# tokens of that class's seed word; accuracy is the share of such texts whose
# gold label is that class.
class_label = list(range(14))
class_name = ['Company', 'EducationalInstitution', 'Artist', 'Athlete', 'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork']
seed_word = ['corporation', 'university', 'painter', 'runner', 'president', 'automobile', 'skyscraper', 'mountain', 'village', 'mammal', 'flower', 'record', 'movie', 'novel']

# Initialize dictionaries to store results for each class label
correct_predictions = {label: 0 for label in class_label}
total_examples = {label: 0 for label in class_label}

# Tokenise every seed word once.
seed_tokens = {label: set(word_tokenize(seed)) for seed, label in zip(seed_word, class_label)}

print("Accuracy of nltk:")
# Single pass over the dataset: each text is tokenised once and checked against
# all seed words, instead of re-tokenising the whole dataset for every class.
# (No torch.no_grad() here: this cell performs no tensor computation.)
for example in tqdm(dataset_test):
    tokens_q = set(word_tokenize(example['content']))
    text_label = example['label']

    for label, tokens_seed in seed_tokens.items():
        if not tokens_seed.isdisjoint(tokens_q):
            total_examples[label] += 1
            if label == text_label:
                correct_predictions[label] += 1

for name, label in zip(class_name, class_label):
    local_total = total_examples[label]
    local_correct = correct_predictions[label]

    # Printing details about this particular class
    print('Total examples for ' + name + ': ' + str(local_total))
    print('Total correct labelings for ' + name + ': ' + str(local_correct))
    print("-" * 120)
    if local_total > 0:
        accuracy = local_correct / local_total
    else:
        accuracy = 0  # Handling division by zero
    print(f'Accuracy for class {name}: {accuracy:.2f}')
    print("-" * 120)

Accuracy of nltk:


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for Company: 134
Total correct labelings for Company: 115
------------------------------------------------------------------------------------------------------------------------
Accuracy for class Company: 0.86
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for EducationalInstitution: 763
Total correct labelings for EducationalInstitution: 675
------------------------------------------------------------------------------------------------------------------------
Accuracy for class EducationalInstitution: 0.88
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for Artist: 353
Total correct labelings for Artist: 332
------------------------------------------------------------------------------------------------------------------------
Accuracy for class Artist: 0.94
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for Athlete: 54
Total correct labelings for Athlete: 40
------------------------------------------------------------------------------------------------------------------------
Accuracy for class Athlete: 0.74
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for OfficeHolder: 275
Total correct labelings for OfficeHolder: 157
------------------------------------------------------------------------------------------------------------------------
Accuracy for class OfficeHolder: 0.57
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for MeanOfTransportation: 120
Total correct labelings for MeanOfTransportation: 76
------------------------------------------------------------------------------------------------------------------------
Accuracy for class MeanOfTransportation: 0.63
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for Building: 144
Total correct labelings for Building: 143
------------------------------------------------------------------------------------------------------------------------
Accuracy for class Building: 0.99
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for NaturalPlace: 1240
Total correct labelings for NaturalPlace: 1073
------------------------------------------------------------------------------------------------------------------------
Accuracy for class NaturalPlace: 0.87
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for Village: 5427
Total correct labelings for Village: 4934
------------------------------------------------------------------------------------------------------------------------
Accuracy for class Village: 0.91
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for Animal: 24
Total correct labelings for Animal: 20
------------------------------------------------------------------------------------------------------------------------
Accuracy for class Animal: 0.83
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for Plant: 121
Total correct labelings for Plant: 106
------------------------------------------------------------------------------------------------------------------------
Accuracy for class Plant: 0.88
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for Album: 915
Total correct labelings for Album: 250
------------------------------------------------------------------------------------------------------------------------
Accuracy for class Album: 0.27
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for Film: 552
Total correct labelings for Film: 437
------------------------------------------------------------------------------------------------------------------------
Accuracy for class Film: 0.79
------------------------------------------------------------------------------------------------------------------------


  0%|          | 0/70000 [00:00<?, ?it/s]

Total examples for WrittenWork: 2101
Total correct labelings for WrittenWork: 1537
------------------------------------------------------------------------------------------------------------------------
Accuracy for class WrittenWork: 0.73
------------------------------------------------------------------------------------------------------------------------
