In [2]:
import os, sys

project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
os.chdir(project_root)
sys.path.append(project_root)

In [4]:
import pandas as pd
def data_to_dict(name):
    """
   Принимает имя датасета в котором есть колонки "question","final_decision"

    """
    data = pd.read_csv(name)
    data = data[["question","final_decision"]]
    data = data.rename(columns = {"final_decision":"answer"})
    return data.to_dict(orient='records')


In [5]:
data = data_to_dict("data/PubMedQATrain.csv")


In [6]:
class BaseModel:
    def predict(self, question: str, context: str) -> str:
        """
        Возвращает один из: 'yes', 'no', 'maybe'
        """
        raise NotImplementedError
class GPTPromptModel(BaseModel):
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def predict(self, question):
        prompt = f"Question: {question} Answer (yes/no/maybe):"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output = self.model.generate(**inputs, max_new_tokens=2)
        full_output = self.tokenizer.decode(output[0], skip_special_tokens=True).strip()
        if full_output.lower().startswith(prompt.lower()):
            answer = full_output[len(prompt):].strip()
        else:
            answer = full_output.strip()
        if len(answer)>1 and answer[-1] == '.':
            answer = answer[:-1]
        return answer.lower()


In [7]:
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM
def get_lora_model(path):
    peft_config = PeftConfig.from_pretrained(path)
    base_model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path)
    model = PeftModel.from_pretrained(base_model, path)
    model.eval()
    return model


In [13]:

model_ft_low = get_lora_model("D:\Machine Learning\Labs and practice\Keggle\PortFolio\PubMedGPT\gpt2-medical-lora\checkpoint-1000")

In [14]:

from transformers import GPT2Tokenizer
Reports = []
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token

model_ft_low_init = GPTPromptModel(model_ft_low, tokenizer)



In [15]:
from src.evaluator import Evaluator

evaluator = Evaluator(model_ft_low_init)
report = evaluator.evaluate(data)
print(report)


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end gene

{'yes': {'precision': 0.5, 'recall': 0.007246376811594203, 'f1-score': 0.014285714285714285, 'support': 552.0}, 'no': {'precision': 0.5714285714285714, 'recall': 0.011834319526627219, 'f1-score': 0.02318840579710145, 'support': 338.0}, 'maybe': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 110.0}, 'micro avg': {'precision': 0.5333333333333333, 'recall': 0.008, 'f1-score': 0.015763546798029555, 'support': 1000.0}, 'macro avg': {'precision': 0.35714285714285715, 'recall': 0.006360232112740474, 'f1-score': 0.012491373360938578, 'support': 1000.0}, 'weighted avg': {'precision': 0.4691428571428571, 'recall': 0.008, 'f1-score': 0.015723395445134576, 'support': 1000.0}}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
