In [2]:
import itertools
import json
import os
import re

import torch
from datasets import load_dataset

from finlm.dataset import FinetuningDocumentDataset
from finlm.models import ElectraDocumentClassification

# Path to the JSON config written by the finetuning run. It can be overridden with the
# FINETUNING_CONFIG_PATH environment variable instead of editing this absolute path.
finetuning_model_path = os.environ.get(
    "FINETUNING_CONFIG_PATH",
    "/data/language_models/pretrained_models_downstreaming/stanford_imdb/electra_small_discriminator_document_predictions/finetuning_config.json",
)
with open(finetuning_model_path, "r", encoding = "utf-8") as config_file:
    finetuning_config = json.load(config_file)


def model_loader(model_path, num_labels, classifier_dropout):
    """Load an ElectraDocumentClassification checkpoint via ``from_pretrained``.

    Args:
        model_path: location of the saved model, passed straight to ``from_pretrained``.
        num_labels: number of target labels, forwarded as ``num_labels``.
        classifier_dropout: forwarded as ``classifier_dropout``.

    Returns:
        The loaded model. It is always built with ``num_sequence_attention_heads = 1``.
        NOTE(review): this presumably has to match the checkpoint's configuration; confirm.
    """
    return ElectraDocumentClassification.from_pretrained(
        model_path,
        num_labels = num_labels,
        classifier_dropout = classifier_dropout,
        num_sequence_attention_heads = 1,
    )

# Prefer the GPU, but fall back to the CPU so later cells that call `.to(device)` still run
# on machines without CUDA. Previously `device` was left undefined in that case.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    print("GPU seems to be unavailable.")
    device = torch.device("cpu")

# Load the IMDB dataset from the Hugging Face Hub
dataset = load_dataset("stanfordnlp/imdb")

# Split the dataset into training and test data
training_data = dataset["train"]
test_data = dataset["test"]

# The splits are sorted by label, so they must be shuffled (fixed seed for reproducibility)
training_data = training_data.shuffle(seed = 42)
test_data = test_data.shuffle(seed = 42)

# Naive sentence splitter: break after '.', '!' or '?' when followed by one or more spaces.
# NOTE(review): IMDB reviews contain "<br />" tags that this does not handle. Keep this in
# sync with the preprocessing used for finetuning before changing it.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?]) +')


def _documents_and_labels(split):
    """Return ``(documents, labels)`` for a dataset split.

    Each document is the review's ``"text"`` split into a list of sentences. ``labels`` is
    the split's ``"label"`` column as a list, in the same (shuffled) order.
    """
    # Column access reads whole columns at once instead of decoding one row dict at a time.
    documents = [_SENTENCE_BOUNDARY.split(text) for text in split["text"]]
    labels = list(split["label"])
    return documents, labels


training_documents, training_labels = _documents_and_labels(training_data)
test_documents, test_labels = _documents_and_labels(test_data)

# Both splits are tokenized with the tokenizer and per-sequence length stored in the config.
document_dataset_kwargs = {
    "tokenizer_path": finetuning_config["tokenizer_path"],
    "sequence_length": finetuning_config["max_sequence_length"],
}
training_dataset = FinetuningDocumentDataset(documents=training_documents, labels=training_labels, **document_dataset_kwargs)
test_dataset = FinetuningDocumentDataset(documents=test_documents, labels=test_labels, **document_dataset_kwargs)

# Reload the finetuned weights saved under "<save_path>/finetuned_model".
# classifier_dropout=0.0 presumably disables dropout in the classification head.
model = model_loader(
    model_path=os.path.join(finetuning_config["save_path"], "finetuned_model"),
    num_labels=finetuning_config["num_labels"],
    classifier_dropout=0.0,
)

In [21]:
from torch.utils.data import DataLoader
from finlm.dataset import collate_fn_fixed_sequences

def collate_fn(samples):
    """Collate a batch with collate_fn_fixed_sequences, capped at the config's ``max_sequences``.

    ``finetuning_config`` is read at call time, not when this function is defined.
    """
    return collate_fn_fixed_sequences(samples, max_sequences=finetuning_config["max_sequences"])


# NOTE(review): this rebinds `training_data` (previously the shuffled HF split) to a DataLoader.
# A distinct name such as `training_loader` would avoid the shadowing.
training_data = DataLoader(
    training_dataset,
    batch_size=1,
    shuffle=False,  # keep dataset order so batch k presumably maps to training_documents[k]
    collate_fn=collate_fn,
)

In [38]:
# Inspect a single training document. The DataLoader uses batch size 1 without shuffling,
# so batch `i` is fetched directly. Later cells index training_documents with the same `i`.
i = 1
batch = next(itertools.islice(training_data, i, None))

inputs = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["label"].to(device)
sequence_mask = batch["sequence_mask"].to(device)

# eval() disables dropout, so the attention weights inspected below are deterministic
model.to(device)
model.eval()

with torch.no_grad():
    model_output = model(input_ids = inputs, attention_mask = attention_mask, sequence_mask = sequence_mask, labels = labels)

In [39]:
# One aggregated attention score per sentence slot for the first example / first head.
# NOTE(review): assumes `attentions` is laid out as (batch, heads, query, key). TODO confirm.
# Under that layout, summing over dim 0 of the selected matrix gives the attention each
# sentence slot receives. Slots beyond the document's sentences come out as 0.
attention_aggregate = (
    model_output.attentions[0, 0]
    .sum(dim=0)
    .cpu()
    .numpy()
)
attention_aggregate

array([4.867501 , 2.8211923, 4.6331573, 3.7683778, 4.091405 , 1.4792961,
       2.217933 , 5.3355713, 2.7855694, 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       ], dtype=float32)

In [40]:
import numpy as np
# Sentence-slot indices ordered from highest to lowest aggregated attention
sorted_index = attention_aggregate.argsort()[::-1]
attention_aggregate[sorted_index]

array([5.3355713, 4.867501 , 4.6331573, 4.091405 , 3.7683778, 2.8211923,
       2.7855694, 2.217933 , 1.4792961, 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       ], dtype=float32)

In [41]:
# Sentences of the document that was run through the model above. The DataLoader uses
# batch_size 1 and shuffle=False, so batch `i` presumably corresponds to training_documents[i].
training_documents[i]

['This movie is a great.',
 'The plot is very true to the book which is a classic written by Mark Twain.',
 'The movie starts of with a scene where Hank sings a song with a bunch of kids called "when you stub your toe on the moon" It reminds me of Sinatra\'s song High Hopes, it is fun and inspirational.',
 'The Music is great throughout and my favorite song is sung by the King, Hank (bing Crosby) and Sir "Saggy" Sagamore.',
 'OVerall a great family movie or even a great Date movie.',
 'This is a movie you can watch over and over again.',
 'The princess played by Rhonda Fleming is gorgeous.',
 'I love this movie!!',
 'If you liked Danny Kaye in the Court Jester then you will definitely like this movie.']

In [42]:
# Sentences of the inspected document, ordered by aggregated attention (highest first).
# Filter on position rather than slicing the first len(doc) entries. Padded sentence slots
# can then never be selected, even if a real sentence ties with the padding's zero score
# (which previously could pick a padded index and raise IndexError).
[training_documents[i][idx] for idx in sorted_index if idx < len(training_documents[i])]

['I love this movie!!',
 'This movie is a great.',
 'The movie starts of with a scene where Hank sings a song with a bunch of kids called "when you stub your toe on the moon" It reminds me of Sinatra\'s song High Hopes, it is fun and inspirational.',
 'OVerall a great family movie or even a great Date movie.',
 'The Music is great throughout and my favorite song is sung by the King, Hank (bing Crosby) and Sir "Saggy" Sagamore.',
 'The plot is very true to the book which is a classic written by Mark Twain.',
 'If you liked Danny Kaye in the Court Jester then you will definitely like this movie.',
 'The princess played by Rhonda Fleming is gorgeous.',
 'This is a movie you can watch over and over again.']