In [35]:
import pandas as pd
def data_to_dict(name):
    """
   Принимает имя датасета в котором есть колонки "question","final_decision"

    """
    data = pd.read_csv(name)
    data = data[["question","final_decision"]]
    data = data.rename(columns = {"final_decision":"answer"})
    return data.to_dict(orient='records')


In [36]:
data = data_to_dict("PubMedQATrain.csv")


In [3]:
class BaseModel:
    def predict(self, question: str, context: str) -> str:
        """
        Возвращает один из: 'yes', 'no', 'maybe'
        """
        raise NotImplementedError
class GPTPromptModel(BaseModel):
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def predict(self, question):
        prompt = f"Question: {question} Answer (yes/no/maybe):"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        output = self.model.generate(**inputs, max_new_tokens=2)
        full_output = self.tokenizer.decode(output[0], skip_special_tokens=True).strip()
        if full_output.lower().startswith(prompt.lower()):
            answer = full_output[len(prompt):].strip()
        else:
            answer = full_output.strip()
        if len(answer)>1 and answer[-1] == '.':
            answer = answer[:-1]
        return answer.lower()


In [None]:
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM
def get_lora_model(path):
    peft_config = PeftConfig.from_pretrained(path)
    base_model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path)
    model = PeftModel.from_pretrained(base_model, path)
    model.eval()
    return model


In [None]:

model_ft_low = get_lora_model("D:\Machine Learning\Labs and practice\Keggle\PortFolio\PubMedGPT\gpt2-medical-lora\checkpoint-624")

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-11): 12 x GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2Attention(
              (c_attn): lora.Linear(
                (base_layer): Conv1D(nf=2304, nx=768)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2304, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
           

In [None]:

from transformers import GPT2Tokenizer
Reports = []
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token
model_base = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2-xl",
    device_map="auto"
)

model_ft_low_init = GPTPromptModel(model_ft_low, tokenizer)



In [None]:
Models = []
Models.append(GPTPromptModel(model_base, tokenizer))
Models.append(GPTPromptModel(model_ft_low_init, tokenizer))


In [None]:
from evaluator import Evaluator
i = 0
for model in Models:
    evaluator = Evaluator(model_int)
    report = evaluator.evaluate(data)
    Reports.append({"gpt2-" + str(i):report})
    i+=1

{'gpt2': {'yes': {'precision': 0.6,
   'recall': 0.02717391304347826,
   'f1-score': 0.05199306759098787,
   'support': 552.0},
  'no': {'precision': 0.4606741573033708,
   'recall': 0.12130177514792899,
   'f1-score': 0.1920374707259953,
   'support': 338.0},
  'maybe': {'precision': 0.0,
   'recall': 0.0,
   'f1-score': 0.0,
   'support': 110.0},
  'micro avg': {'precision': 0.49122807017543857,
   'recall': 0.056,
   'f1-score': 0.10053859964093358,
   'support': 1000.0},
  'macro avg': {'precision': 0.353558052434457,
   'recall': 0.049491896063802415,
   'f1-score': 0.08134351277232772,
   'support': 1000.0},
  'weighted avg': {'precision': 0.48690786516853934,
   'recall': 0.056,
   'f1-score': 0.09360883841561173,
   'support': 1000.0}}}