In [1]:
import pickle
import socket
import sys
from pathlib import Path

import numpy as np
import torch
from datasets import Dataset
from peft import TaskType, LoftQConfig, LoraConfig, get_peft_model, AutoPeftModelForSequenceClassification
from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import roc_auc_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import Trainer, TrainingArguments, EvalPrediction

# Number of output labels for the classifier; passed as `num_labels` when the
# model is loaded further down.
MAX_LINES = 20

# Directory holding the locally stored base models. Every host other than
# "ZAEK" uses the shared /data location; on "ZAEK" the same Windows folder is
# addressed through its /mnt/c mount when running under Linux, and through the
# native drive path otherwise.
if socket.gethostname() != "ZAEK":
    basepath = Path("/data/tmp/models")
elif sys.platform.startswith("linux"):
    basepath = Path("/mnt/c/Users/Ricardo/repos/text-generation-webui/models")
else:
    basepath = Path("C:\\Users\\Ricardo\\repos\\text-generation-webui\\models")

In [4]:
# Load the fine-tuned PEFT adapter from the local "model" directory on top of
# its base checkpoint, configured as a multi-label classifier with one output
# per label.
# NOTE(review): the recorded output of this cell warns that `score.weight` was
# newly initialized - confirm the adapter checkpoint actually contains the
# trained classification head before trusting the predictions.
model = AutoPeftModelForSequenceClassification.from_pretrained(
    "model",
    num_labels=MAX_LINES,
    problem_type="multi_label_classification",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# The tokenizer is read from the base model directory, not from the adapter.
tokenizer = AutoTokenizer.from_pretrained(basepath / "gemma-2b-it")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of GemmaForSequenceClassification were not initialized from the model checkpoint at C:\Users\Ricardo\repos\text-generation-webui\models\gemma-2b-it and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
def predict(predictions, threshold=0.5):
    """Convert raw multi-label logits into hard 0/1 flags for the first row.

    Args:
        predictions: Raw (pre-sigmoid) scores as a tensor, NumPy array or
            nested sequence. The result is indexed with ``[0]``, so a 2-D
            input of shape ``(batch, labels)`` yields the first sample only.
        threshold: Probability at or above which a label is set to 1.

    Returns:
        A list of ``1.0`` / ``0.0`` floats for the first row (a single float
        when ``predictions`` is 1-D).

    Raises:
        IndexError: If ``predictions`` is empty along its first axis.
    """
    # torch.as_tensor replaces the legacy torch.Tensor(...) constructor, which
    # aliases an input tensor and keeps its dtype. Detaching, moving to CPU and
    # upcasting to float32 makes the thresholding independent of where the
    # logits live, whether they carry autograd history, and whether they are
    # in a low-precision dtype such as bfloat16.
    logits = torch.as_tensor(predictions).detach().to(device="cpu", dtype=torch.float32)
    probs = torch.sigmoid(logits)
    # The boolean mask cast to float gives the same 1.0/0.0 encoding as filling
    # a zero array at the positions that reach the threshold.
    y_pred = (probs >= threshold).to(torch.float32)
    return y_pred[0].tolist()

# Four-line sample input for the classifier.
# NOTE(review): looks like logic-program (ASP-style) rules, one per line - confirm.
text = """sel(X) :- e(X,E).
1 <= { sel(X): s(X) } <= k.
inter(X,Y) :- e(X,E); e(Y,E); X != Y.
 :- inter(X,Y); sel(X); sel(Y)."""

# Put the encoded inputs on the model's own device instead of a hard-coded
# "cuda", so the cell also runs wherever device_map="auto" placed the model.
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    output = model(**inputs)

# output[0] holds the logits; the last expression is shown as the cell result.
predict(output[0].cpu())

[1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]