# Automatic predictor

In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd

## Dataset

In [2]:
class MetricsCorrelationDataset(Dataset):

    def __init__(self, texts, summaries, labels, tokenizer, max_length):
        self.texts = texts
        self.summaries = summaries
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        summary = self.summaries[idx]
        label = self.labels[idx]
        text_encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        summary_encoding = self.tokenizer(summary, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        encoding = self.tokenizer(text, summary, truncation='only_first')
        ans = {
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'label': torch.tensor(label)
        }
        
        return ans

## Model

## Data

In [3]:
import pandas as pd
import numpy as np

index = "Ind"
title = "title"
article = "text"
ground_truth = "summary"

files = [
    "mbart_predictions.txt",
    "mt5_predictions.txt",
    "summarunner_predictions.txt",
    "llama_7b_predictions.csv",
    "starling_predictions.csv",
    "yagpt_predictions.csv",
    "yagpt3_predictions.csv"
]
summaries_fields = []
human_metrics = [
    "Актуальность",
    "Последовательность",
    "Беглость",
    "Согласованность",
    "Комментарий"
]

for file in files:
    model_name = file.split(".")[0]
    summaries_fields.append(model_name)
    for metric in human_metrics:
        summaries_fields.append(f"{model_name}_{metric}")

summaries_fields_types = {field : ('Float64' if field.split("_")[-1] != human_metrics[-1] and field + ".txt" not in files and field + ".csv" not in files else str) for field in summaries_fields}
print(summaries_fields_types)
summaries_fields_types[index] = 'Int64'
summaries_fields_types[article] = summaries_fields_types[ground_truth] = str

metrics_data = pd.read_csv("metrics_data.csv", dtype=summaries_fields_types)
expert_data = pd.read_csv("compiled_expert_data.csv", dtype=summaries_fields_types)
# data = pd.read_csv("export_data.csv")

{'mbart_predictions': <class 'str'>, 'mbart_predictions_Актуальность': 'Float64', 'mbart_predictions_Последовательность': 'Float64', 'mbart_predictions_Беглость': 'Float64', 'mbart_predictions_Согласованность': 'Float64', 'mbart_predictions_Комментарий': <class 'str'>, 'mt5_predictions': <class 'str'>, 'mt5_predictions_Актуальность': 'Float64', 'mt5_predictions_Последовательность': 'Float64', 'mt5_predictions_Беглость': 'Float64', 'mt5_predictions_Согласованность': 'Float64', 'mt5_predictions_Комментарий': <class 'str'>, 'summarunner_predictions': <class 'str'>, 'summarunner_predictions_Актуальность': 'Float64', 'summarunner_predictions_Последовательность': 'Float64', 'summarunner_predictions_Беглость': 'Float64', 'summarunner_predictions_Согласованность': 'Float64', 'summarunner_predictions_Комментарий': <class 'str'>, 'llama_7b_predictions': <class 'str'>, 'llama_7b_predictions_Актуальность': 'Float64', 'llama_7b_predictions_Последовательность': 'Float64', 'llama_7b_predictions_Бегло

In [4]:
import math

human_metrics = human_metrics[:-1]

models = {
    "mbart_predictions",
    "mt5_predictions",
    "summarunner_predictions",
    "llama_7b_predictions",
    "starling_predictions",
    "yagpt_predictions",
    "yagpt3_predictions"
}
metrics = {
    "bleu",
    "rouge1",
    "meteor",
    "bertscore_f1"
}

texts = []
summaries = []
labels = {}
human_scores = []
auto_scores = {"bleu": [], "bertscore": [], "rouge": [], "meteor": []}
totals = {"bleu": 0, "bertscore": 0, "rouge": 0, "meteor": 0}
deviations = {"bleu": [], "bertscore": [], "rouge": [], "meteor": []}
orig_scores = {}


for model in models:
    for (mteric_index, metric_row), (expert_index, expert_row)  in zip(metrics_data.iterrows(), expert_data.iterrows()):
        # if any([row[f"{model}_{metric}"] is None or math.isnan(row[f"{model}_{metric}"]) for metric in metrics]) or row[model] is None or type(row[model]) != str:
        #     continue
        texts.append(metric_row["summary"])
        summaries.append(metric_row[model])
        human_scores.append(np.mean([expert_row[f"{model}_{metric}"] for metric in human_metrics]) / 5)
        # print(model, human_scores[-1])

        auto_scores["bleu"].append(metric_row[f"{model}_bleu"])
        auto_scores["rouge"].append(metric_row[f"{model}_rouge1"])
        auto_scores["meteor"].append(metric_row[f"{model}_meteor"])
        auto_scores["bertscore"].append(metric_row[f"{model}_bertscore_f1"])


mean_human = np.mean(human_scores)
mean_bleu = np.mean(auto_scores["bleu"])
mean_rouge = np.mean(auto_scores["rouge"])
mean_meteor = np.mean(auto_scores["meteor"])
mean_bertscore = np.mean(auto_scores["bertscore"])

orig_scores["bleu"] = auto_scores["bleu"].copy()
orig_scores["rouge"] = auto_scores["rouge"].copy()
orig_scores["meteor"] = auto_scores["meteor"].copy()
orig_scores["bertscore"] = auto_scores["bertscore"].copy()

for i in range(len(texts)):
    human_scores[i] -= mean_human
    auto_scores["bleu"][i] -= mean_bleu
    auto_scores["rouge"][i] -= mean_rouge
    auto_scores["meteor"][i] -= mean_meteor
    auto_scores["bertscore"][i] -= mean_bertscore

    deviations["bleu"].append(abs(auto_scores["bleu"][i] - human_scores[i]))
    deviations["rouge"].append(abs(auto_scores["rouge"][i] - human_scores[i]))
    deviations["meteor"].append(abs(auto_scores["meteor"][i] - human_scores[i]))
    deviations["bertscore"].append(abs(auto_scores["bertscore"][i] - human_scores[i]))

print("Median devs:")
print("\tBLEU: ", np.median(deviations["bleu"]), np.max(deviations["bleu"]))
print("\tROUGE: ", np.median(deviations["rouge"]), np.max(deviations["rouge"]))
print("\tMETEOR: ", np.median(deviations["meteor"]), np.max(deviations["meteor"]))
print("\tBERTSCORE: ", np.median(deviations["bertscore"]), np.max(deviations["bertscore"]))

labels["bleu"] = [1.0 if deviations["bleu"][i] < np.quantile(deviations["bleu"], 0.5) else 0.0 for i in range(len(texts))]
labels["rouge"] = [1.0 if deviations["rouge"][i] < np.quantile(deviations["rouge"], 0.5) else 0.0 for i in range(len(texts))]
labels["meteor"] = [1.0 if deviations["meteor"][i] < np.quantile(deviations["meteor"], 0.5) else 0.0 for i in range(len(texts))]
labels["bertscore"] = [1.0 if deviations["bertscore"][i] < np.quantile(deviations["bertscore"], 0.5) else 0.0  for i in range(len(texts))]

Median devs:
	BLEU:  0.08652315462924738 0.5889453306820962
	ROUGE:  0.08388135593220336 0.69087546980561
	METEOR:  0.0933422032178387 0.5547894310323825
	BERTSCORE:  0.08324829339981082 0.510270966206278


## Dataset instance

In [5]:
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer

def get_dataset(tokenizer, metric_name):
    texts_len = len(texts)
    train = int(texts_len * 0.9)
    test_val = int(texts_len * 0.1)
    max_length = 512
    return {
        "train": MetricsCorrelationDataset(texts=texts[0:train], summaries=summaries[0:train], labels=labels[metric_name][0:train], tokenizer=tokenizer, max_length=max_length),
        "test": MetricsCorrelationDataset(texts=texts[train:texts_len], summaries=summaries[train:texts_len], labels=labels[metric_name][train:texts_len], tokenizer=tokenizer, max_length=max_length),
        #"val": MetricsCorrelationDataset(texts=texts[train + test_val:train + 2 * test_val], summaries=summaries[train + test_val:train + 2 * test_val], labels=labels[metric_name][train + test_val:train + 2 * test_val], tokenizer=tokenizer, max_length=max_length)
    }

## Trainer

In [7]:
from transformers import TrainingArguments
from transformers import Trainer

training_args = TrainingArguments(
    output_dir="automatic_predictor",
    learning_rate=3e-4,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=3,
    num_train_epochs=25,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
)

## Work

In [8]:
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score
)
from sklearn.preprocessing import label_binarize
import numpy as np

def compute_metrics(eval_pred):
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids
    probabilities = np.exp(predictions) / np.sum(np.exp(predictions), axis=-1, keepdims=True)
    predictions = torch.tensor([float(round(x)) for x in predictions.flatten()])
    accuracy = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average="macro")
    p = precision_score(labels, predictions, average="macro")
    r = recall_score(labels, predictions, average="macro")
    return {"precision": p, "recall": r, "f1": f1, "accuracy": accuracy}

In [9]:
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    PromptEncoderConfig,
)
peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=30, encoder_hidden_size=256)

### Rouge

In [19]:
from transformers import AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split



model = AutoModelForSequenceClassification.from_pretrained("ai-forever/ruRoberta-large", num_labels=1)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

tokenizer = AutoTokenizer.from_pretrained("ai-forever/ruRoberta-large")
tokenizer.model_max_length=482

rouge_dataset = get_dataset(tokenizer, "rouge")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=rouge_dataset["train"],
    eval_dataset=rouge_dataset["test"],
    tokenizer=rouge_dataset["train"].tokenizer,
    # data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at ai-forever/ruRoberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 1,672,705 || all params: 357,033,474 || trainable%: 0.4685


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,No log,0.26785,0.445455,0.463235,0.419152,0.471429
2,No log,0.250753,0.526667,0.52451,0.516635,0.528571
3,0.455300,0.291932,0.242857,0.5,0.326923,0.485714
4,0.455300,0.255059,0.509989,0.507353,0.468662,0.5
5,0.338500,0.244385,0.62,0.610294,0.60452,0.614286
6,0.338500,0.253965,0.420398,0.486928,0.357124,0.5
7,0.338500,0.277135,0.5625,0.525327,0.436553,0.514286
8,0.315800,0.244136,0.617452,0.616013,0.613576,0.614286
9,0.315800,0.404505,0.242857,0.5,0.326923,0.485714
10,0.308800,0.305762,0.257143,0.5,0.339623,0.514286


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


TypeError: 'method' object is not subscriptable

In [None]:
model

In [None]:
import gc
import torch

model = None
tokenizer = None
gc.collect()
torch.cuda.empty_cache() 

### Bertscore

In [21]:
from transformers import AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split



model = AutoModelForSequenceClassification.from_pretrained("ai-forever/ruRoberta-large", num_labels=1)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

tokenizer = AutoTokenizer.from_pretrained("ai-forever/ruRoberta-large")
tokenizer.model_max_length=492

rouge_dataset = get_dataset(tokenizer, "bertscore")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=rouge_dataset["train"],
    eval_dataset=rouge_dataset["test"],
    tokenizer=rouge_dataset["train"].tokenizer,
    # data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at ai-forever/ruRoberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 1,350,913 || all params: 356,711,682 || trainable%: 0.3787


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.236685,0.555556
2,0.351800,0.232414,0.555556
3,0.286600,0.242504,0.555556
4,0.286600,0.257218,0.555556
5,0.289000,0.254714,0.555556
6,0.303900,0.259226,0.555556
7,0.293200,0.265025,0.555556
8,0.293200,0.262129,0.555556
9,0.270000,0.273124,0.555556
10,0.290800,0.278212,0.555556




KeyboardInterrupt: 

In [None]:
import gc
import torch

model = None
tokenizer = None
gc.collect()
torch.cuda.empty_cache() 

### METEOR

In [None]:
from transformers import AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split



model = AutoModelForSequenceClassification.from_pretrained("ai-forever/ruRoberta-large", num_labels=1)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

tokenizer = AutoTokenizer.from_pretrained("ai-forever/ruRoberta-large")
tokenizer.model_max_length=492

rouge_dataset = get_dataset(tokenizer, "meteor")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=rouge_dataset["train"],
    eval_dataset=rouge_dataset["test"],
    tokenizer=rouge_dataset["train"].tokenizer,
    # data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
import gc
import torch

model = None
tokenizer = None
gc.collect()
torch.cuda.empty_cache() 

# Advanced

In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd

In [3]:
import pandas as pd
import numpy as np

index = "Ind"
title = "title"
article = "text"
ground_truth = "summary"

files = [
    "mbart_predictions.txt",
    "mt5_predictions.txt",
    "summarunner_predictions.txt",
    "llama_7b_predictions.csv",
    "starling_predictions.csv",
    "yagpt_predictions.csv",
    "yagpt3_predictions.csv"
]
summaries_fields = []
human_metrics = [
    "Актуальность",
    "Последовательность",
    "Беглость",
    "Согласованность",
    "Комментарий"
]

for file in files:
    model_name = file.split(".")[0]
    summaries_fields.append(model_name)
    for metric in human_metrics:
        summaries_fields.append(f"{model_name}_{metric}")

summaries_fields_types = {field : ('Float64' if field.split("_")[-1] != human_metrics[-1] and field + ".txt" not in files and field + ".csv" not in files else str) for field in summaries_fields}
print(summaries_fields_types)
summaries_fields_types[index] = 'Int64'
summaries_fields_types[article] = summaries_fields_types[ground_truth] = str

metrics_data = pd.read_csv("metrics_data.csv", dtype=summaries_fields_types)
expert_data = pd.read_csv("compiled_expert_data.csv", dtype=summaries_fields_types)
# data = pd.read_csv("export_data.csv")

{'mbart_predictions': <class 'str'>, 'mbart_predictions_Актуальность': 'Float64', 'mbart_predictions_Последовательность': 'Float64', 'mbart_predictions_Беглость': 'Float64', 'mbart_predictions_Согласованность': 'Float64', 'mbart_predictions_Комментарий': <class 'str'>, 'mt5_predictions': <class 'str'>, 'mt5_predictions_Актуальность': 'Float64', 'mt5_predictions_Последовательность': 'Float64', 'mt5_predictions_Беглость': 'Float64', 'mt5_predictions_Согласованность': 'Float64', 'mt5_predictions_Комментарий': <class 'str'>, 'summarunner_predictions': <class 'str'>, 'summarunner_predictions_Актуальность': 'Float64', 'summarunner_predictions_Последовательность': 'Float64', 'summarunner_predictions_Беглость': 'Float64', 'summarunner_predictions_Согласованность': 'Float64', 'summarunner_predictions_Комментарий': <class 'str'>, 'llama_7b_predictions': <class 'str'>, 'llama_7b_predictions_Актуальность': 'Float64', 'llama_7b_predictions_Последовательность': 'Float64', 'llama_7b_predictions_Бегло

In [4]:
import math

human_metrics = human_metrics[:-1]

models = {
    "mbart_predictions",
    "mt5_predictions",
    "summarunner_predictions",
    "llama_7b_predictions",
    "starling_predictions",
    "yagpt_predictions",
    "yagpt3_predictions"
}
metrics = {
    "bleu",
    "rouge1",
    "meteor",
    "bertscore_f1"
}

texts = []
summaries = []
labels = {}
human_scores = []
auto_scores = {"bleu": [], "bertscore": [], "rouge": [], "meteor": []}
totals = {"bleu": 0, "bertscore": 0, "rouge": 0, "meteor": 0}
deviations = {"bleu": [], "bertscore": [], "rouge": [], "meteor": []}
orig_scores = {}


for model in models:
    for (mteric_index, metric_row), (expert_index, expert_row)  in zip(metrics_data.iterrows(), expert_data.iterrows()):
        # if any([row[f"{model}_{metric}"] is None or math.isnan(row[f"{model}_{metric}"]) for metric in metrics]) or row[model] is None or type(row[model]) != str:
        #     continue
        texts.append(metric_row["summary"])
        summaries.append(metric_row[model])
        human_scores.append(np.mean([expert_row[f"{model}_{metric}"] for metric in human_metrics]) / 5)
        # print(model, human_scores[-1])

        auto_scores["bleu"].append(metric_row[f"{model}_bleu"])
        auto_scores["rouge"].append(metric_row[f"{model}_rouge1"])
        auto_scores["meteor"].append(metric_row[f"{model}_meteor"])
        auto_scores["bertscore"].append(metric_row[f"{model}_bertscore_f1"])


mean_human = np.mean(human_scores)
mean_bleu = np.mean(auto_scores["bleu"])
mean_rouge = np.mean(auto_scores["rouge"])
mean_meteor = np.mean(auto_scores["meteor"])
mean_bertscore = np.mean(auto_scores["bertscore"])

orig_scores["bleu"] = auto_scores["bleu"].copy()
orig_scores["rouge"] = auto_scores["rouge"].copy()
orig_scores["meteor"] = auto_scores["meteor"].copy()
orig_scores["bertscore"] = auto_scores["bertscore"].copy()

for i in range(len(texts)):
    human_scores[i] -= mean_human
    auto_scores["bleu"][i] -= mean_bleu
    auto_scores["rouge"][i] -= mean_rouge
    auto_scores["meteor"][i] -= mean_meteor
    auto_scores["bertscore"][i] -= mean_bertscore

    deviations["bleu"].append(abs(auto_scores["bleu"][i] - human_scores[i]))
    deviations["rouge"].append(abs(auto_scores["rouge"][i] - human_scores[i]))
    deviations["meteor"].append(abs(auto_scores["meteor"][i] - human_scores[i]))
    deviations["bertscore"].append(abs(auto_scores["bertscore"][i] - human_scores[i]))

print("Median devs:")
print("\tBLEU: ", np.median(deviations["bleu"]), np.max(deviations["bleu"]))
print("\tROUGE: ", np.median(deviations["rouge"]), np.max(deviations["rouge"]))
print("\tMETEOR: ", np.median(deviations["meteor"]), np.max(deviations["meteor"]))
print("\tBERTSCORE: ", np.median(deviations["bertscore"]), np.max(deviations["bertscore"]))

labels["bleu"] = [1.0 if deviations["bleu"][i] < np.quantile(deviations["bleu"], 0.5) else 0.0 for i in range(len(texts))]
labels["rouge"] = [1.0 if deviations["rouge"][i] < np.quantile(deviations["rouge"], 0.5) else 0.0 for i in range(len(texts))]
labels["meteor"] = [1.0 if deviations["meteor"][i] < np.quantile(deviations["meteor"], 0.5) else 0.0 for i in range(len(texts))]
labels["bertscore"] = [1.0 if deviations["bertscore"][i] < np.quantile(deviations["bertscore"], 0.5) else 0.0  for i in range(len(texts))]

Median devs:
	BLEU:  0.08652315462924738 0.5889453306820962
	ROUGE:  0.08388135593220336 0.69087546980561
	METEOR:  0.0933422032178387 0.5547894310323825
	BERTSCORE:  0.08324829339981082 0.510270966206278


In [5]:
from Levenshtein import ratio

class MetricsCorrelationAdvancedDataset(Dataset):

    def __init__(self, texts, summaries, labels, scores, tokenizer, max_length):
        self.texts = texts
        self.summaries = summaries
        self.scores = scores
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        summary = self.summaries[idx]
        label = self.labels[idx]
        text_encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        summary_encoding = self.tokenizer(summary, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        encoding = self.tokenizer(text, summary, truncation='only_first')
        # print([self.scores[idx], len(text) / len(summary)])
        ans = {
            # 'text_input_ids': text_encoding['input_ids'].flatten(), 
            # 'text_attention_mask': text_encoding['attention_mask'].flatten(), 
            # 'summary_input_ids': summary_encoding['input_ids'].flatten(),
            # 'summary_attention_mask': summary_encoding['attention_mask'].flatten(),
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'labels': torch.tensor(label),
            'extra_data': [self.scores[idx], len(text) / len(summary), len(text.split(".")) / len(summary.split(".")), ratio(text, summary)]
        }
        
        return ans

In [6]:
from transformers import TrainingArguments
from transformers import Trainer

training_args = TrainingArguments(
    output_dir="automatic_predictor",
    learning_rate=5e-3,
    per_device_train_batch_size=6,
    per_device_eval_batch_size=6,
    num_train_epochs=25,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
)

In [7]:
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer

def get_dataset(tokenizer, metric_name):
    texts_len = len(texts)
    train = int(texts_len * 0.9)
    test_val = int(texts_len * 0.1)
    max_length = 1020
    return {
        "train": MetricsCorrelationAdvancedDataset(texts=texts[0:train], summaries=summaries[0:train], labels=labels[metric_name][0:train], scores=orig_scores[metric_name][0:train], tokenizer=tokenizer, max_length=max_length),
        "test": MetricsCorrelationAdvancedDataset(texts=texts[train:texts_len], summaries=summaries[train:texts_len], labels=labels[metric_name][train:texts_len], scores=orig_scores[metric_name][train:texts_len], tokenizer=tokenizer, max_length=max_length),
        #"val": MetricsCorrelationDataset(texts=texts[train + test_val:train + 2 * test_val], summaries=summaries[train + test_val:train + 2 * test_val], labels=labels[metric_name][train + test_val:train + 2 * test_val], tokenizer=tokenizer, max_length=max_length)
    }

In [8]:
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score
)
from sklearn.preprocessing import label_binarize
import numpy as np

def compute_metrics(eval_pred):
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids
    probabilities = np.exp(predictions) / np.sum(np.exp(predictions), axis=-1, keepdims=True)
    predictions = torch.tensor([float(round(x)) for x in predictions.flatten()])
    accuracy = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average="macro")
    p = precision_score(labels, predictions, average="macro")
    r = recall_score(labels, predictions, average="macro")
    return {"precision": p, "recall": r, "f1": f1, "accuracy": accuracy}

In [9]:
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftType,
    PromptEncoderConfig,
)
peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=30, encoder_hidden_size=256)

In [10]:
import torch
from torch import nn
from transformers import AutoConfig, RobertaModel, RobertaForSequenceClassification, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional, Union, Tuple

class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config, num_extra_dims):
        super().__init__()
        total_dims = config.hidden_size+num_extra_dims
        self.dense = nn.Linear(total_dims, total_dims)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(total_dims, config.num_labels)

    def forward(self, features, **kwargs):
        x = self.dropout(features)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

class CustomSequenceClassification(RobertaForSequenceClassification):

    def __init__(self, config, num_extra_dims):
        print(config)
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # self.bert_model =  RobertaModel.from_pretrained("ai-forever/ruRoberta-large")
        # self.bert_model = RobertaMode(config)
        self.classifier = ClassificationHead(config, num_extra_dims)

        self.post_init()

    
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        extra_data: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # outputs = self.bert_model(
        #     input_ids,
        #     attention_mask=attention_mask,
        #     token_type_ids=token_type_ids,
        #     position_ids=position_ids,
        #     head_mask=head_mask,
        #     inputs_embeds=inputs_embeds,
        #     output_attentions=output_attentions,
        #     output_hidden_states=output_hidden_states,
        #     return_dict=return_dict,
        # )
        outputs = super().base_model(#forward(
            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds#, labels=labels
        )

        sequence_output = outputs[0]

        cls_embedding = sequence_output[:, 0, :]

        output = torch.cat((cls_embedding, extra_data), dim=-1)

        logits = self.classifier(output)

        loss = None

        if labels is not None:
            labels = labels
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.squeeze(), labels)
        
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [13]:
from transformers import AutoModelForSequenceClassification, AutoModel
from sklearn.model_selection import train_test_split


model = CustomSequenceClassification.from_pretrained("ai-forever/ruRoberta-large", num_labels=1, num_extra_dims=4)
#model = CustomSequenceClassification(bert_model, )
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

tokenizer = AutoTokenizer.from_pretrained("ai-forever/ruRoberta-large")
tokenizer.model_max_length=482

rouge_dataset = get_dataset(tokenizer, "rouge")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=rouge_dataset["train"],
    eval_dataset=rouge_dataset["test"],
    tokenizer=rouge_dataset["train"].tokenizer,
    # data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of CustomSequenceClassification were not initialized from the model checkpoint at ai-forever/ruRoberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RobertaConfig {
  "_name_or_path": "ai-forever/ruRoberta-large",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 1,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.41.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

trainable params: 1,680,921 || all params: 357,049,906 || trainable%: 0.4708
torch.Size([6])
torch.Size([6, 1])


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,No log,0.782872,0.789855,0.516667,0.399231,0.585714
2,No log,0.683978,0.214286,0.5,0.3,0.428571
3,No log,1.087464,0.0,0.0,0.0,0.0
4,No log,0.781751,0.0,0.0,0.0,0.0
5,0.936400,1.098836,0.0,0.0,0.0,0.0
6,0.936400,0.748996,0.285714,0.5,0.363636,0.571429
7,0.936400,0.820825,0.285714,0.5,0.363636,0.571429
8,0.936400,1.112153,0.0,0.0,0.0,0.0
9,0.936400,0.866638,0.0,0.0,0.0,0.0
10,0.936400,1.178655,0.0,0.0,0.0,0.0


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc



torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc



torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torch.Size([6, 1])
torch.Size([6])
torc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


TypeError: 'method' object is not subscriptable