# zad2 Jakub Iliński

In [15]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Tuple
import torch
import numpy as np
from torch.nn import functional as F
from typing import List

PAPUGA = 'flax-community/papuGaPT2'
DEVICE = "cuda"

tokenizer = AutoTokenizer.from_pretrained(PAPUGA)
model = AutoModelForCausalLM.from_pretrained(PAPUGA).to(DEVICE)
model.device

device(type='cuda', index=0)

In [16]:
_words = set()
with open('data/zagadki/plwiktionary_definitions_clean.txt') as file:
    for line in file:
        word = line.split("###")[0].strip()
        if word not in _words:
            _words.add(word)
words = list(_words)

questions, anss = [], []
with open("data/zagadki/zagadki_do_testow_clean.txt") as file:
    for line in file:
        questions.append(line.split(";;")[1].strip())
        anss.append(line.split(";;")[0].strip())

anss, questions, words

(['manuskrypt',
  'wesołość',
  'legenda',
  'antysemityzm',
  'filmowanie',
  'nurkowanie',
  'nonsens',
  'weto',
  'mówca',
  'sądownictwo',
  'osąd',
  'start',
  'synonim',
  'repertuar',
  'celuloza',
  'kropla',
  'autyzm',
  'citroen',
  'komunikatywność',
  'keyboard',
  'desant',
  'problem',
  'instruktaż',
  'językoznawca',
  'skład',
  'klawisz',
  'chodnik',
  'sygnaturka',
  'szyna',
  'stres',
  'obcowanie',
  'narodowość',
  'przejazd',
  'odprawa',
  'uzależnienie',
  'sola',
  'potwierdzenie',
  'osocze',
  'anarchia',
  'burza',
  'geodezja',
  'napis',
  'przedmieście',
  'okulary',
  'ranczo',
  'delegat',
  'harmonijka',
  'kołek',
  'obligatoryjność',
  'ukrwienie',
  'ukąszenie',
  'siemię',
  'bańka',
  'uposażenie',
  'jogging',
  'gospodarowanie',
  'owies',
  'transparent',
  'felietonista',
  'niepokój',
  'kropka',
  'cabernet',
  'emir',
  'obuwie',
  'bombardowanie',
  'stoisko',
  'uczynek',
  'tron',
  'atrakcja',
  'kombajn',
  'dziedziczka',
  'zadł

In [17]:
class Node:
    def __init__(self, val: str, end: bool):
        self.val = val 
        self.next: Dict[str, Node] = {}
        self.end = end
        
    def add(self, val: str, end: bool) -> "Node":
        
        if val in self.next:
            self.next[val].end = self.next[val].end or end
        else:
            self.next[val] = Node(val, end)
        
        return self.next[val]
        
class Trie:
    def __init__(self, words: List[List[str]]):
        self.root = Node("", False)
        for word in words:
            self.add_word(word)
    
    def add_word(self, word: List[str]):
        node = self.root 
        
        for s in word[:-1]:
            node = node.add(s, end=False)
        node = node.add(word[-1], end=True)

In [None]:
class Generator:
    def __init__(self, words: List[str]) -> None:
        
        self.tokenized_words = []
        for i in range(len(words)):
            ids = tokenizer(words[i], return_tensors='pt')['input_ids'].to(DEVICE)
            self.tokenized_words.append([tokenizer.decode(id) for id in ids[0]])
            
        self.tree = Trie(self.tokenized_words)
            
    def generate(self, question: str, k = 20) -> List[str]:
        prompt = f"definicja {question[:-1]} to"
        nodes: List[Tuple[float, str, Node]] = [(0.0, "", self.tree.root)]
        ends = []
        
        # generate ends 
        while len(nodes): 
            candidates = [] 
            
            for score, ans, node in nodes:
                for token, next_node in node.next.items():
                    candidates.append((Generator._compute_prob(f"{prompt} {ans}{token}"), f"{ans}{token}", next_node))
    
            candidates.sort(reverse=True)
            nodes = candidates[:k]
            
            for node in nodes:
                if node[2].end: ends.append(node[1])
                
        # return most propable end
        ends = [(Generator._compute_prob(f"{prompt} {end}"), end) for end in ends]
        ends.sort(reverse=True)
        return [end[1] for end in ends[:k]]
    
    @staticmethod 
    def _compute_prob(prompt: str) -> float:
        
        def log_probs_from_logits(logits, labels):
            logp = F.log_softmax(logits, dim=-1)
            logp_label = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
            return logp_label
        
        input_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].to(DEVICE)
        
        with torch.no_grad():
            output = model(input_ids=input_ids)
            log_probs = log_probs_from_logits(output.logits[:, :-1, :], input_ids[:, 1:])
            seq_log_probs = torch.mean(log_probs)
            
        return float(seq_log_probs)         

In [19]:
generator = Generator(words)
IDX = 20
print(questions[IDX]) 
print(anss[IDX])
print(generator.generate(questions[IDX]))


operacja wojskowa polegająca na desantowaniu z powietrza, morza lub lądu sił zbrojnych celem szybkiego i zaskakującego ataku na nieprzyjaciela.
desant
['operacja', 'novum', 'zalążek', 'fiasko', 'zamiana', 'zmora', 'zastąpienie', 'zabicie', 'pocisk', 'zgrupowanie', 'zadanie', 'działanie', 'opera', 'potyczka', 'zabieg', 'forma', 'coś', 'bez', 'nic', 'pościg']


In [20]:
from tqdm import tqdm

def mean_reciprocal_rank(real_answers, computed_answers, K=20):
    positions = []

    for real_answer, computed_answer in zip(real_answers, computed_answers):
        if real_answer in computed_answer[:K]:
            pos = computed_answer.index(real_answer) + 1
            positions.append(1/pos)
    
    mrr = sum(positions) / len(real_answers)
    print ('Mean Reciprocal Rank =', mrr)
    
    return mrr

computed_anss = []
for question, true_ans in tqdm(zip(questions[:100], anss[:100])):
    computed_anss.append(generator.generate(question))

score = mean_reciprocal_rank(anss, computed_anss, 20)
score

100it [19:20, 11.60s/it]

Mean Reciprocal Rank = 0.005127669063946034





0.005127669063946034