In [20]:
import sys
sys.path.append("clean/RecSys/code/")  # for lightning env
sys.path.append("../RecSys/code/")  # locally

import warnings
warnings.filterwarnings("ignore")  # suppresses every warning category (intended for deprecation warnings); fine as long as the provided conda env works

In [21]:
from transformers import AutoTokenizer, MT5ForConditionalGeneration
import torch

# tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
# model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small", output_attentions=True)
# model

In [22]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
from dataclasses import dataclass

@dataclass
class ArgsVal:
    """Configuration object handed to `get_model` and `EkstraBladetDataset` for the validation run.

    Every field has a default, so `ArgsVal()` yields a ready-to-use config.
    """
    dataset: str = "demo"  # dataset name; presumably selects which dataset variant is loaded — TODO confirm in the dataloader
    batch_size: int = 1  # batch size setting
    num_workers: int = 4  # worker-count setting; presumably for data loading — TODO confirm where it is consumed
    T: int = 4  # number of previous articles to consider in the prompt
    datafraction: float = 1.0  # fraction of the entire dataset to use (1.0 = all of it)
    old: bool = True  # load the model from a checkpoint instead of via the model wrapper
    from_checkpoint: str = "checkpoints/model_0.0001.pth"  # local path to the checkpoint
    model: str = "CG"  # "CG" = conditional generation
    backbone: str = "google/mt5-small"  # backbone identifier (same string as `tokenizer`)
    tokenizer: str = "google/mt5-small"  # tokenizer identifier
    evaltrain: bool = False  # allows for evaluation on the training set
    use_QA_model: bool = False  # presumably switches to a QA-style model variant — TODO confirm in get_model

@dataclass
class ArgsTrain(ArgsVal):
    """Variant of `ArgsVal` whose only overridden default is `evaltrain=True`."""
    evaltrain: bool = True  # allows for evaluation on the training set

# Instantiate both configs with their defaults; they differ only in `evaltrain`.
argsval = ArgsVal()
argstrain = ArgsTrain()

In [24]:

# Take the tokenizer name from the shared config instead of repeating the
# literal, so it cannot drift from the other settings in ArgsVal
# (argsval.tokenizer defaults to "google/mt5-small").
tokenizer = AutoTokenizer.from_pretrained(argsval.tokenizer)
# model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small", output_attentions=True)
# model

In [25]:
%cd ../code

c:\Users\maart\Documents\RecSys\RecSys\code


In [26]:
# load model
from models.get_model import get_model

# Build one model per config through the project's `get_model` factory. The
# cell output prints "<All keys matched successfully>" once per call —
# presumably the result of loading the `from_checkpoint` weights (TODO confirm).
model_val = get_model(argsval)
model_train = get_model(argstrain)

<All keys matched successfully>
<All keys matched successfully>


In [27]:
model_val

CG_model(
  (model): MT5ForConditionalGeneration(
    (shared): Embedding(250112, 512)
    (encoder): MT5Stack(
      (embed_tokens): Embedding(250112, 512)
      (block): ModuleList(
        (0): MT5Block(
          (layer): ModuleList(
            (0): MT5LayerSelfAttention(
              (SelfAttention): MT5Attention(
                (q): Linear(in_features=512, out_features=384, bias=False)
                (k): Linear(in_features=512, out_features=384, bias=False)
                (v): Linear(in_features=512, out_features=384, bias=False)
                (o): Linear(in_features=384, out_features=512, bias=False)
                (relative_attention_bias): Embedding(32, 6)
              )
              (layer_norm): MT5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): MT5LayerFF(
              (DenseReluDense): MT5DenseGatedActDense(
                (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                (wi_1): Li

In [28]:
# choose the prompt to use
# the provided checkpoint was trained using subtitles
from utils.prompt_templates import create_prompt_subtitles

In [29]:
# %cd clean/RecSys/code


In [30]:
# create the dataset
from data.dataloader import EkstraBladetDataset

# `split` sets the output form the dataset produces for the model.
# NOTE(review): `data_train` is also built with split="validation"; the only
# visible difference from `data_val` is `argstrain.evaltrain=True`. Presumably
# that flag makes the dataset read the training set — confirm in
# EkstraBladetDataset, otherwise both datasets contain the same samples.
data_val = EkstraBladetDataset(argsval, create_prompt_subtitles, split="validation")
data_train = EkstraBladetDataset(argstrain, create_prompt_subtitles, split="validation")

In [31]:
# the model cannot process lists of strings
# the collator tokenizes these inputs and creates tensors
from collators import CollatorUnderstand
from torch.utils.data import DataLoader

# tokenizer = AutoTokenizer.from_pretrained(argsval.tokenizer)
# The collator tokenizes the string prompts with the shared tokenizer.
# Batch sizes come from the configs instead of a repeated literal, so the
# loaders follow ArgsVal/ArgsTrain. shuffle=False keeps the iteration order
# fixed between runs.
collator = CollatorUnderstand(tokenizer)
dl_val = DataLoader(data_val, batch_size=argsval.batch_size, collate_fn=collator, shuffle=False)
dl_train = DataLoader(data_train, batch_size=argstrain.batch_size, collate_fn=collator, shuffle=False)

In [32]:
# Keys of one collated batch:
# prompt_input_ids: tokenized prompts
# prompt_attention_mask: every prompt is padded to the longest tokenized prompt in the batch; padding tokens are masked
# decoder_start: tokenized 'ja / nej' for the CG model
# Although the loader's batch size is 1, the printed batch holds several rows —
# presumably one row per in-view article of a single impression (TODO confirm).
example = next(iter(dl_val))
example

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


{'prompt_input_ids': tensor([[  642, 36428,   588,  ...,     0,     0,     0],
         [  642, 36428,   588,  ...,     0,     0,     0],
         [  642, 36428,   588,  ...,     0,     0,     0],
         [  642, 36428,   588,  ...,     0,     0,     0],
         [  642, 36428,   588,  ...,  2794,   271,     1],
         [  642, 36428,   588,  ...,     0,     0,     0]]),
 'prompt_attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'decoder_start': tensor([[ 432,  259,  275, 3810,    1],
         [ 432,  259,  275, 3810,    1],
         [ 432,  259,  275, 3810,    1],
         [ 432,  259,  275, 3810,    1],
         [ 432,  259,  275, 3810,    1],
         [ 432,  259,  275, 3810,    1]]),
 'targets': [0, 0, 0, 0, 0, 1],
 'categories': ['nyheder',
  'nationen',
  'sport',
  'nyheder',
  'nyheder',
  'forbrug',

In [33]:
import numpy as np

# Hypothesis under analysis: an in-view article that already appears in the prompt history (i.e. the user clicked it before) is more likely to be chosen by the model.

In [34]:
def extract_text_with_prefix(input_string):
    """Split a prompt around the question "vil brugeren læse artiklen ... ? (ja/nej)".

    Args:
        input_string: the full prompt text.

    Returns:
        On success, a tuple ``(prefix_text, extracted_text)``: everything
        before the first occurrence of the question phrase, and the text
        between that phrase and the next "? (ja/nej)", both stripped of
        surrounding whitespace.
        On failure, a plain message string instead of a tuple:
        "Start phrase not found" or "End phrase not found". Callers must
        check the return type before unpacking.
    """
    question_marker = "vil brugeren læse artiklen"
    closing_marker = "? (ja/nej)"

    # partition() splits at the first occurrence; an empty separator slot
    # means the marker does not occur at all.
    before, found_question, after = input_string.partition(question_marker)
    if not found_question:
        return "Start phrase not found"

    # Only text after the question phrase is searched for the closing marker.
    candidate, found_closing, _ = after.partition(closing_marker)
    if not found_closing:
        return "End phrase not found"

    return before.strip(), candidate.strip()

def check_recurrence(prefix_text, extracted_text):
    """Tell whether the candidate text already occurs in the preceding prompt text.

    Args:
        prefix_text: the part of the prompt that precedes the question.
        extracted_text: the candidate text taken from the question.

    Returns:
        True when `extracted_text` is a non-empty substring of `prefix_text`,
        otherwise False.
    """
    # "" is a substring of every string, so an empty candidate would otherwise
    # always be reported as recurring.
    if not extracted_text:
        return False
    return extracted_text in prefix_text

def process_batch(example):
    """Compute one 0/1 recurrence flag per prompt in a batch.

    Args:
        example: mapping whose 'prompts' entry is an iterable of prompt strings.

    Returns:
        A list with 1 where `check_recurrence` reports that the extracted text
        occurs in the prompt prefix, and 0 otherwise, in prompt order.

    Raises:
        ValueError: if `extract_text_with_prefix` cannot parse a prompt. The
            message carries the prompt position and the parser's own message.
    """
    flags = []
    for position, prompt in enumerate(example['prompts']):
        parsed = extract_text_with_prefix(prompt)
        # On failure the parser returns a plain message string, not a tuple.
        # Star-unpacking that string would pass its characters as arguments
        # and fail with an unrelated TypeError, so report it explicitly.
        if isinstance(parsed, str):
            raise ValueError(f"cannot parse prompt {position}: {parsed}")
        prefix_text, extracted_text = parsed
        flags.append(1 if check_recurrence(prefix_text, extracted_text) else 0)
    return flags


## Run and calculate correlation

In [36]:
from scipy.stats import pearsonr

def forward_baseline_untrained(baseline_model, batch, device='cuda:0'):
    """Score a batch with a baseline model and return the 'ja' probability per row.

    Args:
        baseline_model: callable accepting the keyword arguments `input_ids`,
            `attention_mask` and `decoder_input_ids`, and returning an object
            with a `.logits` attribute.
        batch: mapping holding the tensors "prompt_input_ids",
            "prompt_attention_mask" and "decoder_start".
        device: device the three batch tensors are moved to before the call.
            The model itself is not moved by this function.

    Returns:
        A tensor with one value per row: the softmax probability assigned to
        the 'ja' token at the first decoder position.
    """
    model_inputs = {
        "input_ids": batch["prompt_input_ids"].to(device),
        "attention_mask": batch["prompt_attention_mask"].to(device),
        "decoder_input_ids": batch["decoder_start"].to(device),
    }

    # Inference only, so no autograd graph is recorded.
    with torch.no_grad():
        outputs = baseline_model(**model_inputs)

    # The id is looked up through the module-level `tokenizer`, not hard-coded.
    ja_id = tokenizer.convert_tokens_to_ids('ja')

    # Keep only the first decoder position (expected to be 'ja' or 'nej'):
    # (B, T, V) -> (B, V).
    first_step_logits = outputs.logits[:, 0, :]

    # Normalize over the vocabulary and select the 'ja' column: (B, V) -> (B,).
    return first_step_logits.softmax(dim=-1)[:, ja_id]


def run_CG_and_recurrence(model_val, dl_val, i=100, baseline=False, threshold=0.5):
    """Correlate prompt recurrence with binarized model predictions.

    For each batch, `process_batch` yields one 0/1 recurrence flag per prompt
    and `model_val.validation_step` yields one probability per prompt, which
    is binarized at `threshold`. The flags and predictions of all processed
    batches are then compared with `pearsonr`.

    Args:
        model_val: object whose `validation_step(batch)` returns a tensor of
            probabilities, one per prompt.
        dl_val: iterable of batches.
        i: maximum number of batches to process. If the loader is exhausted
            first, the correlation is computed on the batches seen so far.
            `None` or a non-positive value processes the whole loader.
        baseline: either a false value (no baseline comparison) or the
            baseline model itself, which is passed to
            `forward_baseline_untrained`. A bare `True` is rejected because
            there would be no model to call.
        threshold: probabilities greater than or equal to this value count as
            a positive prediction.

    Returns:
        The `pearsonr` result for (recurrences, predictions). When a baseline
        model is given, a tuple of that result and the `pearsonr` result for
        (recurrences, baseline predictions).

    Raises:
        ValueError: if `baseline` is `True`, or if fewer than two samples
            were collected.
    """
    if baseline is True:
        raise ValueError(
            "`baseline` must be the baseline model itself (or False), not True"
        )
    use_baseline = bool(baseline)
    max_batches = i if (i is not None and i > 0) else None

    recurrences = []
    predictions = []
    baseline_predictions = []

    for batch_number, example in enumerate(dl_val, start=1):
        recurrences += process_batch(example)

        # Evaluation only: skip autograd bookkeeping, as the baseline forward does.
        with torch.no_grad():
            probs = model_val.validation_step(example)
        predictions += (probs >= threshold).long().tolist()

        if use_baseline:
            baseline_probs = forward_baseline_untrained(baseline, example)
            baseline_predictions += (baseline_probs >= threshold).long().tolist()

        # Stop after the requested number of batches without fetching another one.
        if max_batches is not None and batch_number >= max_batches:
            break

    # The result is computed after the loop so that a loader shorter than `i`
    # still produces a correlation instead of an implicit None.
    if len(recurrences) < 2:
        raise ValueError("need at least two samples to compute a correlation")

    model_correlation = pearsonr(recurrences, predictions)
    if use_baseline:
        return model_correlation, pearsonr(recurrences, baseline_predictions)
    return model_correlation
        
# model = MT5ForConditionalGeneration.from_pretrained('google/mt5-small').to('cuda:0')
# Correlate the recurrence flags with the checkpointed model's binary
# predictions on the validation loader over i=10000 batches; the baseline
# comparison stays disabled (`baseline` keeps its default of False).
run_CG_and_recurrence(model_val, dl_val, i=10000)

PearsonRResult(statistic=0.0033834010976930115, pvalue=0.24080274385337955)