# TODO
1. Build a classifier in PyTorch (or in TensorFlow, if feasible) that takes as input an abstract, represented through word embeddings, and returns, for each word in the abstract, a label indicating the class that word belongs to (BIO tagging).
 * **PROBLEM**: understand how QuaPy's "Classify and Count" works internally, because this project needs a similar class whose "aggregate" function reduces **"B-Claim, I-Claim, B-Evidence, I-Evidence and Outside"** to just 3 classes: **"Claim, Evidence and Outside"**.
2. Use the methods from the paper that perform *quantification* without going through classification + counting.
3. Compare the results of the two approaches on several metrics.
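
The reduction described in the **PROBLEM** note can be sketched in plain Python. This is only a minimal illustration of the "classify, collapse, count" idea, not QuaPy's implementation: the names `BIO_TO_CLASS`, `CLASSES` and this `aggregate` function are assumptions made for the example.

```python
from collections import Counter

# Hypothetical mapping from BIO tags to the three coarse classes.
BIO_TO_CLASS = {
    "B-Claim": "Claim", "I-Claim": "Claim",
    "B-Evidence": "Evidence", "I-Evidence": "Evidence",
    "O": "Outside",
}
CLASSES = ["Claim", "Evidence", "Outside"]


def aggregate(predicted_tags):
    """Collapse per-token BIO tags into 3 classes and return their prevalences."""
    counts = Counter(BIO_TO_CLASS[tag] for tag in predicted_tags)
    total = sum(counts.values())
    if total == 0:
        # no valid token: return an all-zero vector instead of dividing by zero
        return [0.0] * len(CLASSES)
    return [counts[c] / total for c in CLASSES]


tags = ["B-Claim", "I-Claim", "O", "B-Evidence", "I-Evidence", "I-Evidence", "O", "O"]
print(aggregate(tags))  # [0.25, 0.375, 0.375]
```

The output vector is ordered as `CLASSES`, so 2 of 8 tokens are Claim, 3 are Evidence and 3 are Outside.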

# IMPORTANT for QuaPy
1. Look into how to do model selection with QuaPy's GridSearch.
2. Look into how to test the models using the *natural prevalence protocol (NPP)* and the *artificial prevalence protocol (APP)*.
3. Look into how to produce plots that support a discussion of the results.
4. Look into how to use QuaNet.
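
To make the difference between the two protocols concrete, here is a rough sketch in plain Python. It does not use QuaPy (which ships its own protocol implementations); `sample_at_prevalence` and `true_prevalence` are hypothetical helpers, and the label list is toy data. The idea is that APP draws test samples at prevalences chosen on purpose, while NPP samples uniformly and so keeps roughly the natural class proportions.

```python
import random


def sample_at_prevalence(labels, prevalence, sample_size, rng):
    """APP-style draw: class c makes up about prevalence[c] of the sample."""
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)
    sample = []
    for cls, p in prevalence.items():
        n = round(p * sample_size)
        # sample with replacement so that rare classes can still reach the target
        sample.extend(rng.choices(by_class[cls], k=n))
    rng.shuffle(sample)
    return sample


def true_prevalence(labels, indices, classes):
    """Fraction of each class among the selected indices."""
    return [sum(labels[i] == c for i in indices) / len(indices) for c in classes]


rng = random.Random(0)
classes = ["Claim", "Evidence", "Outside"]
labels = ["Claim"] * 20 + ["Evidence"] * 30 + ["Outside"] * 50  # toy data

# APP: one point of an artificial prevalence grid
target = {"Claim": 0.5, "Evidence": 0.3, "Outside": 0.2}
app_sample = sample_at_prevalence(labels, target, 100, rng)
print(true_prevalence(labels, app_sample, classes))  # [0.5, 0.3, 0.2]

# NPP: uniform sampling without replacement
npp_sample = rng.sample(range(len(labels)), 40)
print(true_prevalence(labels, npp_sample, classes))
```

In a real evaluation the APP step would be repeated over a whole grid of target prevalences, and the quantifier's estimate would be compared with `true_prevalence` on each sample.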

`QuaPy provides an implementation of QuaNet, a deep-learning-based method for performing quantification on samples of textual documents, presented in [8]. QuaNet processes as input a list of document embeddings (see below),
one for each unlabelled document along with their posterior probabilities generated by a probabilistic classifier. The
list is processed by a bidirectional LSTM that generates a sample embedding (i.e., a dense representation of the entire
sample), which is then concatenated with a vector of class prevalence estimates produced by an ensemble of simpler
quantification methods (CC, ACC, PCC, PACC, SLD). This vector is then transformed by a set of feed-forward layers,
followed by ReLU activations and dropout, to compute the final estimations.`

QuaNet thus requires a probabilistic classifier that can provide embedded representations of the inputs. QuaPy offers
a basic implementation of such a classifier, based on convolutional neural networks, that returns its next-to-last representation as the document embedding. The following is a working example showing how to index a textual dataset
(see Section 3) and how to instantiate QuaNet:
```python
import quapy as qp
from quapy.method.meta import QuaNet
from quapy.classification.neural import NeuralClassifierTrainer, CNNnet

qp.environ['SAMPLE_SIZE'] = 500

# load the kindle dataset as plain text, and convert words
# to numerical indexes
dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
qp.data.preprocessing.index(dataset, min_df=5, inplace=True)

cnn = CNNnet(dataset.vocabulary_size, dataset.n_classes)
learner = NeuralClassifierTrainer(cnn, device='cuda')
model = QuaNet(learner, sample_size=500, device='cuda')
```
QuaNet paper: Esuli, A., Moreo, A., & Sebastiani, F. (2018, October). A recurrent neural network for sentiment quantification. In Proceedings of the 27th ACM International Conference on Information and Knowledge Management (pp. 1775-1778).

# IMPORTANT for Preprocessing
1. Decide which word embedding to use.
2. Check whether and how to use BERT to generate abstract embeddings.
3. Check whether other models make sense, such as SciBERT (which was trained on scientific papers).


VERY IMPORTANT PROBLEM

In [1]:
!pip install transformers pytorch-crf==0.7.2 torchmetrics

Collecting pytorch-crf==0.7.2
  Downloading pytorch_crf-0.7.2-py3-none-any.whl.metadata (2.4 kB)
Collecting torchmetrics
  Downloading torchmetrics-1.5.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Downloading torchmetrics-1.5.0-py3-none-any.whl (890 kB)
Downloading lightning_utilities-0.11.8-py3-none-any.whl (26 kB)
Installing collected packages: pytorch-crf, lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.8 pytorch-crf-0.7.2 torchmetrics-1.5.0


In [2]:
import csv
import sys
import os
import collections
import random
from tqdm import trange,tqdm

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, Subset
from transformers import BasicTokenizer, BertPreTrainedModel, BertModel, BertConfig,  BertTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW
import numpy as np
from torchcrf import CRF
import torchmetrics



In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
data_dir = '/content/drive/My Drive/NLPProjectWork/data/neoplasm'
# Ensure the directory exists
if os.path.exists(data_dir):
    os.chdir(data_dir)
    print(f"Current directory: {os.getcwd()}")
else:
    print("Directory not found.")

Current directory: /content/drive/My Drive/NLPProjectWork/data/neoplasm


In [5]:
def set_seed(seed, n_gpu):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

In [6]:
n_gpu = torch.cuda.device_count()
set_seed(seed = 2, n_gpu = n_gpu)

# PREPROCESSING


In [7]:
class InputExample(object):
    def __init__(self, guid, text, labels=None):
        self.guid = guid
        self.text = text
        self.labels = labels


class InputFeatures(object):
    def __init__(self, input_ids, input_mask, segment_ids, label_ids, label_proba):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        self.label_proba = label_proba

In [8]:
class DataProcessor(object):
    def __init__(self):
        self.labels = ["X", "B-Claim", "I-Claim", "B-Premise", "I-Premise", 'O']
        self.label_map = self._create_label_map()
        self.replace_labels = {
            'B-MajorClaim': 'B-Claim',
            'I-MajorClaim': 'I-Claim',
        }

    def _create_label_map(self):
        label_map = collections.OrderedDict()
        for i, label in enumerate(self.labels):
            label_map[label] = i
        return label_map

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_conll(os.path.join(data_dir, "train.conll"), replace=self.replace_labels), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_conll(os.path.join(data_dir, "dev.conll"), replace=self.replace_labels), "dev")

    def get_test_examples(self, data_dir, setname="test.conll"):
        """See base class."""
        return self._create_examples(
            self.read_conll(os.path.join(data_dir, setname), replace=self.replace_labels), "test")

    def get_labels(self):
        """ See base class."""
        return self.labels

    def convert_labels_to_ids(self, labels):
        idx_list = []
        for label in labels:
            idx_list.append(self.label_map[label])
        return idx_list

    def convert_ids_to_labels(self, idx_list):
        labels_list = []
        for idx in idx_list:
            labels_list.append([key for key in self.label_map.keys() if self.label_map[key] == idx][0])
        return labels_list

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, str(i))
            text = line[0]
            labels = line[-1]
            examples.append(
                InputExample(guid=guid, text=text, labels=labels))
        return examples

    def convert_examples_to_features(self, examples, max_seq_length, tokenizer,
                                     cls_token='[CLS]',
                                     sep_token='[SEP]'):
        """Converts examples into a list of padded `InputFeatures`."""

        features = []
        for (ex_index, example) in enumerate(examples):
            tokens_a, labels = tokenizer.tokenize_with_label_extension(example.text, example.labels,
                                                                       copy_previous_label=True)

            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[:(max_seq_length - 2)]
                labels = labels[:(max_seq_length - 2)]
            labels = ["X"] + labels + ["X"]

            tokens = [cls_token] + tokens_a + [sep_token]
            segment_ids = [0] * len(tokens)
            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)
            # Zero-pad up to the sequence length.
            padding = [0] * (max_seq_length - len(input_ids))
            input_ids += padding
            input_mask += padding
            segment_ids += padding
            label_ids = self.convert_labels_to_ids(labels)
            # If the mask is None, do nothing;
            # otherwise apply it to the labels
            label_proba = self.compute_label_probadist(label_ids)
            # drop the first and the last element ([CLS] and [SEP])
            # assert(no 0 inside)
            # for predictions, what if they contain 0? Skip all zeros when computing the probability distribution

            label_ids += padding

            assert len(label_ids) == max_seq_length
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_ids=label_ids,
                              label_proba=label_proba))
        return features

    def compute_label_probadist(self, label_ids, mask=None):
        """Return the [Claim, Premise, O] prevalences of one tagged sequence."""
        label_ids_np = np.array(label_ids)
        if mask is not None:
            mask_np = mask.cpu().detach().numpy()
            label_ids_np = label_ids_np[mask_np == 1]
        # drop the [CLS] and [SEP] positions, then any remaining "X" (id 0)
        label_ids_np = label_ids_np[1:-1]
        label_ids_np = label_ids_np[label_ids_np != 0]
        if len(label_ids_np) == 0:
            return np.zeros(3)
        assert (0 not in label_ids_np)
        # collapse the BIO tags into the three coarse classes
        coarse_map = {self.label_map["B-Claim"]: 0,
                      self.label_map["I-Claim"]: 0,
                      self.label_map["B-Premise"]: 1,
                      self.label_map["I-Premise"]: 1,
                      self.label_map["O"]: 2}
        remapped_label_ids = [coarse_map[label] for label in label_ids_np]
        counts = np.bincount(remapped_label_ids, minlength=3)
        probabilities = counts / len(remapped_label_ids)
        return probabilities

    @classmethod
    def features_to_dataset(cls, feature_list):
        all_input_ids = torch.tensor([f.input_ids for f in feature_list], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in feature_list], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in feature_list], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_ids for f in feature_list], dtype=torch.long)
        all_label_probas = torch.tensor([np.array(f.label_proba) for f in feature_list])
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_label_probas)
        return dataset

    @classmethod
    def read_conll(cls, input_file, token_number=0, token_column=1, label_column=4, replace=None):
        """Reads a conll type file."""
        with open(input_file, "r", encoding='utf-8') as f:
            lines = f.readlines()
            sentences = []
            tokenizer = BasicTokenizer()
            sent_tokens = []
            sent_labels = []

            for idx, line in enumerate(lines):

                line = line.split('\t')
                # Skip separator lines that contain only '\n'
                if len(line) < 2:
                    continue

                # A token index of 0 marks the start of a new abstract; make sure this is NOT the first iteration
                if idx != 0 and int(line[token_number]) == 0:
                    assert len(sent_tokens) == len(sent_labels)
                    if replace:
                        sent_labels = [replace[label] if label in replace.keys() else label for label in sent_labels]
                    sentences.append([' '.join(sent_tokens), sent_labels])
                    sent_tokens = []
                    sent_labels = []

                token = line[token_column]
                label = line[label_column].replace('\n', '')
                tokenized = tokenizer.tokenize(token)

                if len(tokenized) > 1:
                    for i in range(len(tokenized)):
                        sent_tokens.append(tokenized[i])
                        sent_labels.append(label)
                else:
                    sent_tokens.append(tokenized[0])
                    sent_labels.append(label)
            if sent_tokens != []:
                assert len(sent_tokens) == len(sent_labels)
                if replace:
                    sent_labels = [replace[label] if label in replace.keys() else label for label in sent_labels]
                sentences.append([' '.join(sent_tokens), sent_labels])
        return sentences

    def load_examples(self, data_dir, max_seq_length, tokenizer, evaluate=False, isval=False):
        if evaluate:
            examples = self.get_test_examples(data_dir)
        elif isval:
            examples = self.get_dev_examples(data_dir)
        else:
            examples = self.get_train_examples(data_dir)
        features = self.convert_examples_to_features(examples, max_seq_length=max_seq_length, tokenizer=tokenizer)
        dataset = self.features_to_dataset(features)
        return dataset
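
For reference, the file layout that `read_conll` appears to expect (tab-separated columns, token index in column 0, token in column 1, label in column 4, index 0 opening a new abstract) can be illustrated with a small standalone sketch. The sample text, the `_` placeholder columns and the `read_toy_conll` helper are all hypothetical; the real data files may carry different content in the middle columns.

```python
import io

# Hypothetical 5-column sample: index, token, two placeholder columns, label.
SAMPLE = (
    "0\tTreatment\t_\t_\tB-Claim\n"
    "1\timproved\t_\t_\tI-Claim\n"
    "2\tsurvival\t_\t_\tI-Claim\n"
    "\n"
    "0\tWe\t_\t_\tO\n"
    "1\tenrolled\t_\t_\tB-Premise\n"
)


def read_toy_conll(handle, index_col=0, token_col=1, label_col=4):
    """Group tokens and labels into documents, splitting whenever the index is 0."""
    documents, tokens, labels = [], [], []
    for raw in handle:
        fields = raw.rstrip("\n").split("\t")
        if len(fields) < 2:  # blank separator line
            continue
        if int(fields[index_col]) == 0 and tokens:  # index 0 opens a new document
            documents.append((tokens, labels))
            tokens, labels = [], []
        tokens.append(fields[token_col])
        labels.append(fields[label_col])
    if tokens:  # flush the last document
        documents.append((tokens, labels))
    return documents


docs = read_toy_conll(io.StringIO(SAMPLE))
for doc_tokens, doc_labels in docs:
    print(doc_tokens, doc_labels)
```

Unlike `read_conll`, this sketch does not run `BasicTokenizer` on each token, so it only shows the document-splitting logic.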

In [9]:
class ExtendedBertTokenizer:
    """Extended tokenizer that wraps a base tokenizer."""

    def __init__(self, base_tokenizer):
        self.tokenizer = base_tokenizer

    def tokenize_with_label_extension(self, text, labels, copy_previous_label=False, extension_label='X'):
        tok_text = self.tokenizer.tokenize(text)
        # work on a copy so the caller's label list is not modified in place
        labels = list(labels)
        for i in range(len(tok_text)):
            # WordPiece continuation pieces start with '##'
            if tok_text[i].startswith('##'):
                if copy_previous_label:
                    labels.insert(i, labels[i - 1])
                else:
                    labels.insert(i, extension_label)
        return tok_text, labels

    def convert_tokens_to_ids(self, tokens):
        return self.tokenizer.convert_tokens_to_ids(tokens)
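
The label extension can be checked by hand on a toy input. The sketch below is a standalone variant written for illustration (the `extend_labels` helper and the wordpieces are hypothetical, not real SciBERT output); it builds a new list rather than inserting into the existing one, but it follows the same rule: a `##` continuation piece either copies the previous label or receives the extension label.

```python
def extend_labels(wordpieces, word_labels, copy_previous_label=True, extension_label="X"):
    """Return one label per wordpiece without modifying word_labels."""
    extended = []
    word_idx = -1
    for piece in wordpieces:
        if piece.startswith("##") and extended:
            # continuation piece: reuse the previous label or mark it with "X"
            extended.append(extended[-1] if copy_previous_label else extension_label)
        else:
            # first piece of a new word: consume the next word-level label
            word_idx += 1
            extended.append(word_labels[word_idx])
    return extended


pieces = ["chemo", "##therapy", "improved", "surviv", "##al"]
labels = ["B-Claim", "I-Claim", "I-Claim"]

print(extend_labels(pieces, labels))
# ['B-Claim', 'B-Claim', 'I-Claim', 'I-Claim', 'I-Claim']
print(extend_labels(pieces, labels, copy_previous_label=False))
# ['B-Claim', 'X', 'I-Claim', 'I-Claim', 'X']
```

Note that copying the previous label repeats a `B-` tag on the continuation piece. That does not affect the collapsed Claim/Premise/O counts, but it does matter for span-level evaluation.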

# MODELS

In [10]:
class BertForSequenceTagging(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.rnn = nn.GRU(config.hidden_size, config.hidden_size, batch_first=True, bidirectional=True)
        self.crf = CRF(config.num_labels, batch_first=True)
        self.classifier = nn.Linear(2 * config.hidden_size, config.num_labels)
        self.custom_init_weights()
        self.name = "seq_tag_model"

    def custom_init_weights(self):
        # Initialize weights of GRU
        for name, param in self.rnn.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param, gain=0.01)
            elif 'bias' in name:
                nn.init.zeros_(param)  # Initialize biases to zero

        # Initialize weights of CRF transitions
        nn.init.normal_(self.crf.start_transitions, mean=0, std=0.1)
        nn.init.normal_(self.crf.end_transitions, mean=0, std=0.1)
        nn.init.normal_(self.crf.transitions, mean=0, std=0.1)

        # Initialize weights of the classifier
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            labels=None,
            labels_proba=None
    ):

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        sequence_output = outputs[0]

        print(f"Sequence output (BERT): {sequence_output.min()}, {sequence_output.max()}")
        rnn_out, _ = self.rnn(sequence_output)
        emissions = self.classifier(rnn_out)

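        mask = attention_mask.bool()
        # pytorch-crf returns the log-likelihood summed over the batch; it is negated below to get a loss
        log_likelihood = self.crf(emissions, labels, mask=mask)
        # decode without the mask so that every path has length seq_len and can be stacked;
        # padded positions are filtered out later with the attention mask
        path = self.crf.decode(emissions)
        path = torch.LongTensor(path)
        return (-1 * log_likelihood, emissions, path, None)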

In [11]:
class ClassificationAndCounting:
    def __init__(self, learner, processor: DataProcessor):
        self.learner = learner
        self.processor = processor
        self.name = "bert_classify_and_count"

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            labels=None,
            labels_proba=None
    ):

        loss, emissions, path, _ = self.learner(input_ids, attention_mask, token_type_ids, labels, labels_proba)

        proba_dists = self.aggregate(path, attention_mask)
        return loss, emissions, path, proba_dists

    def aggregate(self, predictions, attention_mask):
        all_probabilities = []
        for sequence, mask in zip(predictions, attention_mask):
            label_proba = self.processor.compute_label_probadist(sequence, mask=mask)
            all_probabilities.append(label_proba)
        return np.array(all_probabilities)

    # Wrapper methods mimicking PyTorch model behavior
    def train(self, mode=True):
        self.learner.train(mode)

    def eval(self):
        self.learner.eval()

    def state_dict(self, *args, **kwargs):
        return self.learner.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        return self.learner.load_state_dict(*args, **kwargs)

    def to(self, *args, **kwargs):
        self.learner.to(*args, **kwargs)

    def parameters(self):
        return self.learner.parameters()

    def named_parameters(self):
        return self.learner.named_parameters()

    def zero_grad(self):
        self.learner.zero_grad()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

In [12]:
class BertForLabelDistribution(BertPreTrainedModel):
    def __init__(self, config, loss_fn, loss_fn_name):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.rnn = nn.GRU(config.hidden_size, config.hidden_size, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * config.hidden_size, config.num_labels)
        self.softmax = nn.Softmax(dim=-1)
        self.init_weights()
        self.loss_fn = loss_fn
        self.name = f"bert_quantify_{loss_fn_name}"
        self.custom_init_weights()

    def custom_init_weights(self):
        # Initialize weights of GRU
        for name, param in self.rnn.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param, gain=0.01)
            elif 'bias' in name:
                nn.init.zeros_(param)  # Initialize biases to zero

        # Initialize weights of the classifier
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            labels=None,
            labels_proba=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        sequence_output = outputs[0]  # [batch_size, seq_len, hidden_size]
        rnn_out, _ = self.rnn(sequence_output)  # [batch_size, seq_len, 2*hidden_size]
        emissions = self.classifier(rnn_out)  # [batch_size, seq_len, num_labels]
        cls_emissions = emissions[:, 0, :]  # [batch_size, num_labels]
        # note: the loss could be computed first, with the softmax applied afterwards
        cls_probs = self.softmax(cls_emissions)  # [batch_size, num_labels]
        loss = self.loss_fn(cls_probs, labels_proba)
        return loss, emissions, None, cls_probs

# TRAIN/EVALUATE

In [13]:
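def get_quantization_preds_labels(inputs, outputs):
    """Return predicted and true class distributions as tensors on the same device."""
    _, _, _, pred_proba = outputs
    labels_proba = inputs["labels_proba"]
    # torch.as_tensor accepts numpy arrays and tensors alike and avoids the
    # copy warning raised by calling torch.tensor on an existing tensor
    pred_proba = torch.as_tensor(pred_proba, dtype=labels_proba.dtype, device=labels_proba.device)
    return pred_proba.detach(), labels_proba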

def get_classf_preds_labels(inputs, outputs):
    loss, emissions, path, _ = outputs
    batch_logits = path.detach().cpu().numpy().flatten()
    batch_labels = inputs["labels"].detach().cpu().numpy().flatten()
    attention_mask = inputs["attention_mask"].detach().cpu().numpy().flatten()
    valid_batch_logits = batch_logits[attention_mask == 1]
    valid_batch_labels = batch_labels[attention_mask == 1]
    return torch.tensor(valid_batch_logits), torch.tensor(valid_batch_labels)

In [14]:
def mae_loss(pred_probas, label_probas):
    loss = nn.L1Loss(reduction="mean")
    return loss(pred_probas, label_probas)

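def kldiv_loss(pred_probas, label_probas):
    # nn.KLDivLoss expects log-probabilities as input and probabilities as target
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    # clamp before the log so that zero probabilities do not produce -inf
    pred_probas_log = torch.log(pred_probas.clamp(min=1e-9))
    output = kl_loss(pred_probas_log, label_probas)
    return output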
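
As a sanity check on what the two losses measure, the same quantities can be computed by hand for a single pair of prevalence vectors. This is a plain-Python sketch with hypothetical helper names and toy vectors. It assumes the usual convention for `nn.KLDivLoss`, where the target distribution comes first, i.e. KL(true || predicted).

```python
import math


def mean_absolute_error(pred, true):
    """Average absolute difference between predicted and true prevalences."""
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(true)


def kl_divergence(true, pred, eps=1e-9):
    """KL(true || pred), clamping both vectors away from zero."""
    total = 0.0
    for t, p in zip(true, pred):
        t, p = max(t, eps), max(p, eps)
        total += t * math.log(t / p)
    return total


true = [0.25, 0.25, 0.5]   # toy [Claim, Premise, O] prevalences
pred = [0.5, 0.25, 0.25]

print(mean_absolute_error(pred, true))  # (0.25 + 0 + 0.25) / 3
print(kl_divergence(true, pred))        # 0.25*ln(0.5) + 0.5*ln(2) = 0.25*ln(2)
```

Both values are zero only when the predicted distribution matches the true one. MAE treats every class equally, while KL divergence penalises more heavily the classes to which the prediction assigns too little mass.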

In [15]:
def process_batch(batch, device, model, gen_preds_labels_fn, eval=False):
    model.eval() if eval else model.train()
    batch = tuple(t.to(device) for t in batch)
    inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2], "labels": batch[3],
              "labels_proba": batch[4]}
    with torch.no_grad() if eval else torch.enable_grad():
        outputs = model(**inputs)
    preds, labels = gen_preds_labels_fn(inputs, outputs)
    return outputs, preds, labels

In [16]:
def evaluate(device, model, eval_dataset, eval_batch_size, metrics, gen_preds_labels_fn):
    eval_dataloader = DataLoader(eval_dataset, batch_size=eval_batch_size)
    epoch_loss_sum = 0.0
    for metric in metrics:
        metric.reset()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            outputs, preds, labels = process_batch(batch, device, model, gen_preds_labels_fn, eval=True)
            loss = outputs[0]
            epoch_loss_sum += loss.item()
            for metric in metrics:
                if metric.__class__.__name__ == "KLDivergence":
                    # clamp only for KL so that the other metrics see the raw values
                    metric.update(preds.clamp(min=1e-9), labels.clamp(min=1e-9))
                else:
                    metric.update(preds, labels)
    eval_loss = epoch_loss_sum / len(eval_dataloader)
    eval_metrics = {metric.__class__.__name__: metric.compute().item() for metric in metrics}
    return eval_loss, eval_metrics

In [37]:
def update_metrics(history, metrics, metric_type, epoch):
    for metric_name, metric_value in metrics.items():
        if epoch == 0:
            history[metric_type][metric_name] = [metric_value]
        else:
            history[metric_type][metric_name].append(metric_value)
        print(f"{metric_name}: {metric_value:.4f}")


def train(
        device,
        train_dataset,
        model,
        eval_dataset,
        generate_preds_labels_fn,
        metrics,
        num_train_epochs=30,
        train_batch_size=5,
        eval_batch_size=32,
        weight_decay=0.3,
        learning_rate=2e-5,
        adam_epsilon=1e-8,
        warmup_steps=5,
        max_grad_norm = 3.0
):
    history = {
        'train_loss': [],
        'train_metrics': {},
        'eval_loss': [],
        'eval_metrics': {}
    }
    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size)
    num_steps_per_epoch = len(train_dataloader)
    t_total = num_steps_per_epoch * num_train_epochs
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
    )
    epochs_trained = 0

    train_iterator = trange(epochs_trained, int(num_train_epochs), desc="Epoch")
    print(f"TRAINING MODEL {model.name}")
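    for epoch in train_iterator:
        epoch_loss_sum = 0.0
        # evaluate() leaves validation state in the shared metrics; clear it before training
        for metric in metrics:
            metric.reset()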
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            optimizer.zero_grad()
            outputs, preds, labels = process_batch(batch, device, model, generate_preds_labels_fn, eval=False)
            loss = outputs[0]
            print(f"SINGLE LOSS: {loss}")
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            epoch_loss_sum += loss.item()
            for metric in metrics:
                if metric.__class__.__name__ == "KLDivergence":
                    # clamp only for KL so that the other metrics see the raw values
                    metric.update(preds.clamp(min=1e-9), labels.clamp(min=1e-9))
                else:
                    metric.update(preds, labels)
            optimizer.step()
            scheduler.step()

        print(f"CUMULATIVE LOSSES: {epoch_loss_sum}")
        epoch_loss = epoch_loss_sum / num_steps_per_epoch
        history['train_loss'].append(epoch_loss)
        epoch_metrics = {metric.__class__.__name__: metric.compute().item() for metric in metrics}
        eval_loss, eval_metrics = evaluate(device, model, eval_dataset, eval_batch_size, metrics,
                                           generate_preds_labels_fn)
        history['eval_loss'].append(eval_loss)
        print(f"Epoch {epoch + 1}/{num_train_epochs}")
        print(f"Train Loss: {epoch_loss:.4f}")
        update_metrics(history, epoch_metrics, 'train_metrics', epoch)
        print(f"Eval Loss: {eval_loss:.4f}")
        update_metrics(history, eval_metrics, 'eval_metrics', epoch)
    return history
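
The learning-rate schedule used in `train` can be reproduced by hand to see what `warmup_steps=5` implies. The function below is a sketch of the multiplier that, to my understanding, `get_linear_schedule_with_warmup` applies to the base learning rate: a linear ramp from 0 to 1 during warmup, then a linear decay to 0. The name `linear_warmup_factor` is hypothetical, and the step count assumes 70 batches per epoch, as shown by the progress bar of the run in this notebook.

```python
def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # linear ramp from 0 to 1 over the warmup steps
        return step / max(1, num_warmup_steps)
    # linear decay from 1 to 0 over the remaining steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 2e-5
total_steps = 70 * 30  # assumed: 70 batches per epoch x 30 epochs

for step in (0, 1, 5, 1000, total_steps):
    print(step, base_lr * linear_warmup_factor(step, 5, total_steps))
```

With only 5 warmup steps out of 2100, the model reaches the full learning rate almost immediately. Whether that is appropriate for the newly initialised GRU and CRF layers is worth checking.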

# RESULTS

In [38]:
model_name_or_path = 'allenai/scibert_scivocab_uncased'
do_lower_case = True
base_tokenizer = BertTokenizer.from_pretrained(model_name_or_path, do_lower_case=do_lower_case)
extended_tokenizer = ExtendedBertTokenizer(base_tokenizer)

In [39]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataprocessor = DataProcessor()
max_seq_length = 510

In [40]:
train_ds = dataprocessor.load_examples(data_dir=data_dir, max_seq_length=max_seq_length,
                                       tokenizer=extended_tokenizer)
val_ds = dataprocessor.load_examples(data_dir=data_dir, max_seq_length=max_seq_length,
                                     tokenizer=extended_tokenizer, isval=True)
test_ds = dataprocessor.load_examples(data_dir=data_dir, max_seq_length=max_seq_length,
                                      tokenizer=extended_tokenizer, evaluate=True)

# CREATE CLASSIFICATION+COUNT MODEL

In [41]:

num_class_labels = len(dataprocessor.get_labels())
"""
seq_tagging_config = BertConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_class_labels
)
seq_tagging_model = BertForSequenceTagging.from_pretrained(
    model_name_or_path,
    config=seq_tagging_config
)
multiclass_metrics = [torchmetrics.Accuracy(num_classes=6, average='macro', task = "multiclass").to(device), torchmetrics.F1Score(num_classes=6, average='macro', task = "multiclass").to(device)]
history = train(device=device, train_dataset=train_ds, model=seq_tagging_model, eval_dataset=val_ds, generate_preds_labels_fn=get_classf_preds_labels, metrics=multiclass_metrics)

"""


In [42]:
classify_count_config = BertConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_class_labels)
classify_count_base_model = BertForSequenceTagging.from_pretrained(
    model_name_or_path,
    config=classify_count_config
)
classify_and_count_model = ClassificationAndCounting(learner=classify_count_base_model, processor=dataprocessor)
classify_and_count_model.to(device)
dist_metrics = [torchmetrics.KLDivergence(log_prob=False, reduction="mean").to(device),
                torchmetrics.MeanAbsoluteError().to(device),
                torchmetrics.MeanSquaredError().to(device)]
history = train(device=device, train_dataset=train_ds, model=classify_and_count_model, eval_dataset=val_ds, generate_preds_labels_fn=get_quantization_preds_labels, metrics=dist_metrics)


Some weights of BertForSequenceTagging were not initialized from the model checkpoint at allenai/scibert_scivocab_uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'crf.end_transitions', 'crf.start_transitions', 'crf.transitions', 'rnn.bias_hh_l0', 'rnn.bias_hh_l0_reverse', 'rnn.bias_ih_l0', 'rnn.bias_ih_l0_reverse', 'rnn.weight_hh_l0', 'rnn.weight_hh_l0_reverse', 'rnn.weight_ih_l0', 'rnn.weight_ih_l0_reverse']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch:   0%|          | 0/30 [00:00<?, ?it/s]
TRAINING MODEL bert_classify_and_count
Iteration:   0%|          | 0/70 [00:00<?, ?it/s]
Sequence output (BERT): -14.715432167053223, 17.918481826782227
SINGLE LOSS: 2.534143844386498e+20
Iteration:   1%|▏         | 1/70 [00:04<05:16,  4.59s/it]
Sequence output (BERT): -15.671463012695312, 17.03908348083496
SINGLE LOSS: 2.534143844386498e+20
Iteration:   3%|▎         | 2/70 [00:11<06:41,  5.91s/it]
Sequence output (BERT): -15.093567848205566, 17.265661239624023
SINGLE LOSS: 2.53414067779301e+20
Iteration:   4%|▍         | 3/70 [00:15<05:30,  4.93s/it]
Sequence output (BERT): -17.243614196777344, 17.31109619140625
SINGLE LOSS: 2.5341348723716155e+20
Iteration:   6%|▌         | 4/70 [00:21<06:13,  5.65s/it]
Sequence output (BERT): -16.56304359436035, 17.92195701599121
SINGLE LOSS: 2.534125900356733e+20
Iteration:   7%|▋         | 5/70 [00:27<06:03,  5.59s/it]
Sequence output (BERT): -14.891651153564453, 17.696239471435547
SINGLE LOSS: 2.5341135858265018e+20
Iteration:   9%|▊         | 6/70 [00:33<05:57,  5.59s/it]
Sequence output (BERT): -17.54294204711914, 16.878068923950195
SINGLE LOSS: 2.534098280624643e+20
Iteration:  10%|█         | 7/70 [00:36<05:04,  4.83s/it]
Sequence output (BERT): -17.714153289794922, 17.510231018066406
SINGLE LOSS: 2.534083151344645e+20
Iteration:  11%|█▏        | 8/70 [00:40<04:51,  4.70s/it]
Sequence output (BERT): -17.389965057373047, 17.783634185791016
SINGLE LOSS: 2.5340680220646467e+20
Iteration:  13%|█▎        | 9/70 [00:48<05:36,  5.52s/it]
Sequence output (BERT): -17.308269500732422, 17.728979110717773
SINGLE LOSS: 2.534052716862788e+20
Iteration:  14%|█▍        | 10/70 [00:56<06:17,  6.29s/it]
Sequence output (BERT): -18.500185012817383, 17.827577590942383
SINGLE LOSS: 2.53403758758279e+20
Iteration:  16%|█▌        | 11/70 [01:00<05:38,  5.73s/it]
Sequence output (BERT): -17.943553924560547, 17.277103424072266
SINGLE LOSS: 2.5340224583027917e+20
Iteration:  17%|█▋        | 12/70 [01:04<04:59,  5.17s/it]
Sequence output (BERT): -17.4008846282959, 17.08854103088379
SINGLE LOSS: 2.5340073290227935e+20
Iteration:  19%|█▊        | 13/70 [01:07<04:20,  4.57s/it]
Sequence output (BERT): -17.47627830505371, 17.1535701751709
SINGLE LOSS: 2.5339921997427953e+20
Iteration:  20%|██        | 14/70 [01:10<03:48,  4.08s/it]
Sequence output (BERT): -17.040822982788086, 17.19114875793457
SINGLE LOSS: 2.5339768945409366e+20
Iteration:  21%|██▏       | 15/70 [01:13<03:22,  3.69s/it]
Sequence output (BERT): -15.144891738891602, 16.745859146118164
SINGLE LOSS: 2.5339617652609384e+20
Iteration:  23%|██▎       | 16/70 [01:17<03:24,  3.79s/it]
Sequence output (BERT): -17.739423751831055, 16.314838409423828
SINGLE LOSS: 2.5339466359809402e+20
Iteration:  24%|██▍       | 17/70 [01:20<03:12,  3.64s/it]
Sequence output (BERT): -16.806970596313477, 17.811214447021484
SINGLE LOSS: 2.5339313307790816e+20
Iteration:  26%|██▌       | 18/70 [01:23<03:02,  3.52s/it]
Sequence output (BERT): -16.530794143676758, 16.860170364379883
SINGLE LOSS: 2.5339162014990834e+20
Iteration:  27%|██▋       | 19/70 [01:26<02:51,  3.36s/it]
Sequence output (BERT): -14.891195297241211, 16.54011344909668
SINGLE LOSS: 2.5339010722190852e+20
Iteration:  29%|██▊       | 20/70 [01:29<02:41,  3.24s/it]
Sequence output (BERT): -17.376096725463867, 16.689987182617188
SINGLE LOSS: 2.533885942939087e+20
Iteration:  30%|███       | 21/70 [01:33<02:40,  3.27s/it]
Sequence output (BERT): -16.15793800354004, 16.612966537475586
SINGLE LOSS: 2.5338708136590888e+20
Iteration:  31%|███▏      | 22/70 [01:39<03:15,  4.07s/it]
Sequence output (BERT): -17.561309814453125, 16.444318771362305
SINGLE LOSS: 2.5338555084572302e+20



Iteration:  33%|███▎      | 23/70 [01:42<02:55,  3.73s/it][A

Sequence output (BERT): -17.671537399291992, 16.71509552001953
SINGLE LOSS: 2.533840379177232e+20



Iteration:  34%|███▍      | 24/70 [01:44<02:40,  3.49s/it][A

Sequence output (BERT): -17.29938507080078, 16.063871383666992
SINGLE LOSS: 2.5338252498972338e+20



Iteration:  36%|███▌      | 25/70 [01:48<02:33,  3.42s/it][A

Sequence output (BERT): -16.938488006591797, 15.96292495727539
SINGLE LOSS: 2.533809944695375e+20



Iteration:  37%|███▋      | 26/70 [01:55<03:24,  4.65s/it][A

Sequence output (BERT): -15.730973243713379, 16.14704704284668
SINGLE LOSS: 2.533794815415377e+20



Iteration:  39%|███▊      | 27/70 [01:58<02:58,  4.16s/it][A

Sequence output (BERT): -17.444814682006836, 15.888656616210938
SINGLE LOSS: 2.5337796861353787e+20



Iteration:  40%|████      | 28/70 [02:01<02:41,  3.84s/it][A

Sequence output (BERT): -16.692922592163086, 16.07956886291504
SINGLE LOSS: 2.5337645568553805e+20



Iteration:  41%|████▏     | 29/70 [02:04<02:29,  3.64s/it][A

Sequence output (BERT): -17.14959716796875, 15.720462799072266
SINGLE LOSS: 2.5337494275753823e+20



Iteration:  43%|████▎     | 30/70 [02:08<02:18,  3.46s/it][A

Sequence output (BERT): -17.30865478515625, 15.541646957397461
SINGLE LOSS: 2.5337341223735237e+20



Iteration:  44%|████▍     | 31/70 [02:11<02:09,  3.31s/it][A

Sequence output (BERT): -16.76171112060547, 16.06256103515625
SINGLE LOSS: 2.533719169015386e+20



Iteration:  46%|████▌     | 32/70 [02:13<02:00,  3.18s/it][A

Sequence output (BERT): -17.56482696533203, 16.0620174407959
SINGLE LOSS: 2.5337042156572482e+20



Iteration:  47%|████▋     | 33/70 [02:17<02:00,  3.26s/it][A

Sequence output (BERT): -17.183528900146484, 15.909631729125977
SINGLE LOSS: 2.5336892622991104e+20



Iteration:  49%|████▊     | 34/70 [02:20<01:58,  3.29s/it][A

Sequence output (BERT): -16.65809440612793, 15.411611557006836
SINGLE LOSS: 2.5336743089409727e+20



Iteration:  50%|█████     | 35/70 [02:23<01:52,  3.22s/it][A

Sequence output (BERT): -17.051801681518555, 15.821745872497559
SINGLE LOSS: 2.533659355582835e+20



Iteration:  51%|█████▏    | 36/70 [02:27<01:50,  3.25s/it][A

Sequence output (BERT): -16.95025062561035, 15.576169967651367
SINGLE LOSS: 2.5336444022246972e+20



Iteration:  53%|█████▎    | 37/70 [02:31<02:01,  3.70s/it][A

Sequence output (BERT): -14.38640308380127, 15.553511619567871
SINGLE LOSS: 2.5336294488665594e+20



Iteration:  54%|█████▍    | 38/70 [02:36<02:04,  3.91s/it][A

Sequence output (BERT): -17.725435256958008, 15.370058059692383
SINGLE LOSS: 2.5336144955084217e+20



Iteration:  56%|█████▌    | 39/70 [02:40<02:02,  3.94s/it][A

Sequence output (BERT): -16.452442169189453, 15.581738471984863
SINGLE LOSS: 2.533599542150284e+20



Iteration:  57%|█████▋    | 40/70 [02:45<02:07,  4.24s/it][A

Sequence output (BERT): -17.04998016357422, 15.420795440673828
SINGLE LOSS: 2.5335845887921462e+20



Iteration:  59%|█████▊    | 41/70 [02:49<02:05,  4.32s/it][A

Sequence output (BERT): -17.233415603637695, 14.919551849365234
SINGLE LOSS: 2.5335696354340084e+20



Iteration:  60%|██████    | 42/70 [02:53<01:58,  4.22s/it][A

Sequence output (BERT): -16.576765060424805, 15.161785125732422
SINGLE LOSS: 2.5335546820758707e+20



Iteration:  61%|██████▏   | 43/70 [02:57<01:49,  4.04s/it][A

Sequence output (BERT): -17.11048126220703, 14.799696922302246
SINGLE LOSS: 2.533539728717733e+20



Iteration:  63%|██████▎   | 44/70 [03:02<01:52,  4.33s/it][A

Sequence output (BERT): -16.872833251953125, 15.398710250854492
SINGLE LOSS: 2.533524775359595e+20



Iteration:  64%|██████▍   | 45/70 [03:08<01:59,  4.79s/it][A

Sequence output (BERT): -16.37388801574707, 14.881319999694824
SINGLE LOSS: 2.5335098220014574e+20



Iteration:  66%|██████▌   | 46/70 [03:14<02:08,  5.37s/it][A

Sequence output (BERT): -14.512471199035645, 15.207059860229492
SINGLE LOSS: 2.5334948686433196e+20



Iteration:  67%|██████▋   | 47/70 [03:19<01:57,  5.09s/it][A

Sequence output (BERT): -13.757569313049316, 14.557049751281738
SINGLE LOSS: 2.533479915285182e+20



Iteration:  69%|██████▊   | 48/70 [03:23<01:46,  4.85s/it][A

Sequence output (BERT): -16.855470657348633, 15.287718772888184
SINGLE LOSS: 2.533464961927044e+20



Iteration:  70%|███████   | 49/70 [03:28<01:43,  4.94s/it][A

Sequence output (BERT): -15.244176864624023, 14.825758934020996
SINGLE LOSS: 2.5334500085689064e+20



Iteration:  71%|███████▏  | 50/70 [03:34<01:40,  5.04s/it][A

Sequence output (BERT): -16.375181198120117, 15.610196113586426
SINGLE LOSS: 2.5334350552107686e+20



Iteration:  73%|███████▎  | 51/70 [03:39<01:37,  5.13s/it][A

Sequence output (BERT): -15.788134574890137, 14.931154251098633
SINGLE LOSS: 2.533420101852631e+20



Iteration:  74%|███████▍  | 52/70 [03:44<01:31,  5.09s/it][A

Sequence output (BERT): -15.755375862121582, 14.216964721679688
SINGLE LOSS: 2.5334053244163536e+20



Iteration:  76%|███████▌  | 53/70 [03:49<01:24,  4.98s/it][A

Sequence output (BERT): -16.79708480834961, 15.536029815673828
SINGLE LOSS: 2.5333907229019367e+20



Iteration:  77%|███████▋  | 54/70 [03:54<01:22,  5.14s/it][A

Sequence output (BERT): -15.521397590637207, 14.905611991882324
SINGLE LOSS: 2.5333759454656594e+20



Iteration:  79%|███████▊  | 55/70 [03:59<01:16,  5.11s/it][A

Sequence output (BERT): -15.918844223022461, 15.242597579956055
SINGLE LOSS: 2.533361168029382e+20



Iteration:  80%|████████  | 56/70 [04:03<01:08,  4.88s/it][A

Sequence output (BERT): -15.063349723815918, 15.274470329284668
SINGLE LOSS: 2.5333465665149652e+20



Iteration:  81%|████████▏ | 57/70 [04:09<01:04,  4.99s/it][A

Sequence output (BERT): -15.395509719848633, 14.381698608398438
SINGLE LOSS: 2.533331789078688e+20



Iteration:  83%|████████▎ | 58/70 [04:13<00:55,  4.65s/it][A

Sequence output (BERT): -15.066682815551758, 15.208032608032227
SINGLE LOSS: 2.5333170116424106e+20



Iteration:  84%|████████▍ | 59/70 [04:16<00:46,  4.18s/it][A

Sequence output (BERT): -16.391067504882812, 14.546611785888672
SINGLE LOSS: 2.5333022342061333e+20



Iteration:  86%|████████▌ | 60/70 [04:22<00:47,  4.76s/it][A

Sequence output (BERT): -14.393691062927246, 14.270003318786621
SINGLE LOSS: 2.533287456769856e+20



Iteration:  87%|████████▋ | 61/70 [04:29<00:48,  5.37s/it][A

Sequence output (BERT): -14.340975761413574, 14.55958366394043
SINGLE LOSS: 2.533272855255439e+20



Iteration:  89%|████████▊ | 62/70 [04:34<00:42,  5.28s/it][A

Sequence output (BERT): -15.71900749206543, 14.878067016601562
SINGLE LOSS: 2.5332580778191618e+20



Iteration:  90%|█████████ | 63/70 [04:38<00:35,  5.12s/it][A

Sequence output (BERT): -13.614900588989258, 13.65998649597168
SINGLE LOSS: 2.5332433003828845e+20



Iteration:  91%|█████████▏| 64/70 [04:44<00:31,  5.29s/it][A

Sequence output (BERT): -14.201986312866211, 14.102134704589844
SINGLE LOSS: 2.5332286988684676e+20



Iteration:  93%|█████████▎| 65/70 [04:50<00:27,  5.43s/it][A

Sequence output (BERT): -12.484776496887207, 13.939964294433594
SINGLE LOSS: 2.5332139214321903e+20



Iteration:  94%|█████████▍| 66/70 [04:53<00:19,  4.80s/it][A

Sequence output (BERT): -13.694860458374023, 13.550265312194824
SINGLE LOSS: 2.533199143995913e+20



Iteration:  96%|█████████▌| 67/70 [04:56<00:13,  4.34s/it][A

Sequence output (BERT): -12.239469528198242, 12.783123016357422
SINGLE LOSS: 2.5331843665596357e+20



Iteration:  97%|█████████▋| 68/70 [05:00<00:08,  4.07s/it][A

Sequence output (BERT): -13.015257835388184, 12.459232330322266
SINGLE LOSS: 2.5331695891233584e+20



Iteration:  99%|█████████▊| 69/70 [05:03<00:03,  3.77s/it][A

Sequence output (BERT): -11.853110313415527, 13.234981536865234
SINGLE LOSS: 2.5331549876089415e+20



Iteration: 100%|██████████| 70/70 [05:06<00:00,  4.38s/it]


CUMULATIVE LOSSES: 1.7735671432231466e+22



Evaluating:   0%|          | 0/2 [00:00<?, ?it/s][A

Sequence output (BERT): -12.511577606201172, 13.620159149169922



Evaluating:  50%|█████     | 1/2 [00:05<00:05,  5.92s/it][A

Sequence output (BERT): -12.847799301147461, 13.59028148651123



Evaluating: 100%|██████████| 2/2 [00:09<00:00,  4.59s/it]
Epoch:   3%|▎         | 1/30 [05:15<2:32:38, 315.82s/it]

Epoch 1/30
Train Loss: 253366734746163806208.0000
KLDivergence: 0.1494
MeanAbsoluteError: 0.1101
MeanSquaredError: 0.0188
Eval Loss: 1266570122678518153216.0000
KLDivergence: 0.0864
MeanAbsoluteError: 0.0861
MeanSquaredError: 0.0127



Iteration:   0%|          | 0/70 [00:00<?, ?it/s][A

Sequence output (BERT): -12.11197566986084, 12.757888793945312
SINGLE LOSS: 2.5331402101726642e+20



Iteration:   1%|▏         | 1/70 [00:03<03:46,  3.28s/it][A

Sequence output (BERT): -13.127034187316895, 12.045297622680664
SINGLE LOSS: 2.533125432736387e+20



Iteration:   3%|▎         | 2/70 [00:06<03:36,  3.19s/it][A

Sequence output (BERT): -11.412140846252441, 13.758880615234375
SINGLE LOSS: 2.53311083122197e+20



Iteration:   4%|▍         | 3/70 [00:09<03:27,  3.10s/it][A

Sequence output (BERT): -11.05139446258545, 12.98026180267334
SINGLE LOSS: 2.5330960537856927e+20



Iteration:   6%|▌         | 4/70 [00:12<03:19,  3.02s/it][A

Sequence output (BERT): -11.432028770446777, 13.282960891723633
SINGLE LOSS: 2.5330812763494154e+20



Iteration:   7%|▋         | 5/70 [00:15<03:19,  3.06s/it][A

Sequence output (BERT): -9.864572525024414, 13.075700759887695
SINGLE LOSS: 2.533066498913138e+20



Iteration:   9%|▊         | 6/70 [00:18<03:17,  3.09s/it][A

Sequence output (BERT): -12.537796020507812, 12.396660804748535
SINGLE LOSS: 2.5330517214768608e+20



Iteration:  10%|█         | 7/70 [00:21<03:12,  3.06s/it][A

Sequence output (BERT): -11.982110023498535, 12.428192138671875
SINGLE LOSS: 2.533037119962444e+20



Iteration:  11%|█▏        | 8/70 [00:24<03:15,  3.16s/it][A

Sequence output (BERT): -10.837635040283203, 11.934967994689941
SINGLE LOSS: 2.5330223425261666e+20



Iteration:  13%|█▎        | 9/70 [00:27<03:07,  3.07s/it][A

Sequence output (BERT): -13.48060417175293, 12.248676300048828
SINGLE LOSS: 2.5330075650898893e+20



Iteration:  14%|█▍        | 10/70 [00:30<03:02,  3.05s/it][A

Sequence output (BERT): -11.10561466217041, 12.552366256713867
SINGLE LOSS: 2.5329929635754725e+20



Iteration:  16%|█▌        | 11/70 [00:34<03:02,  3.09s/it][A

Sequence output (BERT): -11.248446464538574, 12.917572975158691
SINGLE LOSS: 2.532978186139195e+20



Iteration:  17%|█▋        | 12/70 [00:37<02:59,  3.10s/it][A

Sequence output (BERT): -11.750938415527344, 13.190348625183105
SINGLE LOSS: 2.532963408702918e+20



Iteration:  19%|█▊        | 13/70 [00:42<03:38,  3.84s/it][A

Sequence output (BERT): -10.665806770324707, 13.824458122253418
SINGLE LOSS: 2.5329486312666405e+20



Iteration:  20%|██        | 14/70 [00:45<03:22,  3.62s/it][A

Sequence output (BERT): -10.841449737548828, 13.348066329956055
SINGLE LOSS: 2.5329338538303632e+20



Iteration:  21%|██▏       | 15/70 [00:51<03:47,  4.14s/it][A

Sequence output (BERT): -10.234513282775879, 13.251004219055176
SINGLE LOSS: 2.5329192523159464e+20



Iteration:  23%|██▎       | 16/70 [00:54<03:24,  3.79s/it][A

Sequence output (BERT): -11.301548957824707, 13.906079292297363
SINGLE LOSS: 2.532904474879669e+20



Iteration:  24%|██▍       | 17/70 [00:57<03:08,  3.55s/it][A

Sequence output (BERT): -12.255313873291016, 13.9600248336792
SINGLE LOSS: 2.5328896974433917e+20



Iteration:  26%|██▌       | 18/70 [01:00<02:55,  3.37s/it][A

Sequence output (BERT): -11.14191722869873, 13.551706314086914
SINGLE LOSS: 2.532875095928975e+20



Iteration:  27%|██▋       | 19/70 [01:03<02:59,  3.52s/it][A

Sequence output (BERT): -12.265870094299316, 13.990854263305664
SINGLE LOSS: 2.5328603184926976e+20



Iteration:  29%|██▊       | 20/70 [01:06<02:48,  3.37s/it][A

Sequence output (BERT): -12.351841926574707, 13.68467903137207
SINGLE LOSS: 2.5328455410564202e+20



Iteration:  30%|███       | 21/70 [01:10<02:41,  3.31s/it][A

Sequence output (BERT): -10.817895889282227, 13.075139999389648
SINGLE LOSS: 2.532830763620143e+20



Iteration:  31%|███▏      | 22/70 [01:13<02:38,  3.29s/it][A

Sequence output (BERT): -11.410187721252441, 13.702019691467285
SINGLE LOSS: 2.5328159861838656e+20



Iteration:  33%|███▎      | 23/70 [01:16<02:38,  3.38s/it][A

Sequence output (BERT): -10.212099075317383, 13.67501163482666
SINGLE LOSS: 2.5328015605913092e+20



Iteration:  34%|███▍      | 24/70 [01:22<03:06,  4.06s/it][A

Sequence output (BERT): -11.759846687316895, 13.975378036499023
SINGLE LOSS: 2.5327871349987528e+20



Iteration:  36%|███▌      | 25/70 [01:25<02:52,  3.84s/it][A

Sequence output (BERT): -11.249343872070312, 13.304853439331055
SINGLE LOSS: 2.532772533484336e+20



Iteration:  37%|███▋      | 26/70 [01:31<03:07,  4.26s/it][A

Sequence output (BERT): -10.81851863861084, 13.409655570983887
SINGLE LOSS: 2.532757931969919e+20



Iteration:  39%|███▊      | 27/70 [01:34<02:51,  3.99s/it][A

Sequence output (BERT): -12.343364715576172, 13.725702285766602
SINGLE LOSS: 2.5327435063773626e+20



Iteration:  40%|████      | 28/70 [01:38<02:44,  3.91s/it][A

Sequence output (BERT): -13.245479583740234, 13.066840171813965
SINGLE LOSS: 2.5327290807848062e+20



Iteration:  41%|████▏     | 29/70 [01:43<02:56,  4.29s/it][A

Sequence output (BERT): -10.5408296585083, 14.101422309875488
SINGLE LOSS: 2.5327144792703894e+20



Iteration:  43%|████▎     | 30/70 [01:46<02:39,  3.98s/it][A

Sequence output (BERT): -8.50461196899414, 13.436603546142578
SINGLE LOSS: 2.5326998777559725e+20



Iteration:  44%|████▍     | 31/70 [01:50<02:33,  3.94s/it][A

Sequence output (BERT): -12.431669235229492, 12.71997356414795
SINGLE LOSS: 2.532685452163416e+20



Iteration:  46%|████▌     | 32/70 [01:54<02:28,  3.92s/it][A

Sequence output (BERT): -11.698928833007812, 12.521946907043457
SINGLE LOSS: 2.5326710265708596e+20



Iteration:  47%|████▋     | 33/70 [01:57<02:17,  3.73s/it][A

Sequence output (BERT): -11.477239608764648, 14.139301300048828
SINGLE LOSS: 2.5326564250564428e+20



Iteration:  49%|████▊     | 34/70 [02:01<02:11,  3.66s/it][A

Sequence output (BERT): -11.5579833984375, 12.801414489746094
SINGLE LOSS: 2.532641823542026e+20



Iteration:  50%|█████     | 35/70 [02:04<02:02,  3.50s/it][A

Sequence output (BERT): -11.6135835647583, 14.076586723327637
SINGLE LOSS: 2.5326273979494695e+20



Iteration:  51%|█████▏    | 36/70 [02:07<01:53,  3.33s/it][A

Sequence output (BERT): -10.761434555053711, 14.052170753479004
SINGLE LOSS: 2.532612972356913e+20


Iteration:  51%|█████▏    | 36/70 [02:10<02:02,  3.62s/it]
Epoch:   3%|▎         | 1/30 [07:25<3:35:33, 445.97s/it]


KeyboardInterrupt: 
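
As a reading aid for the epoch summary above, the three reported metrics compare a predicted class-prevalence vector with the true one. The sketch below computes them in plain Python for one pair of 3-class distributions. The function names, the `eps` smoothing constant, and the example prevalences are illustrative assumptions; this is not the `torchmetrics` implementation, nor the notebook's `kldiv_loss` or `mae_loss`. For probability vectors, MAE and MSE cannot exceed 1, so the `SINGLE LOSS` values near 2.5e20 are on a very different scale from the reported metrics, which may be worth checking.

```python
import math


def kl_divergence(true_prev, pred_prev, eps=1e-12):
    """KL(true || pred) for two prevalence vectors; eps avoids log(0)."""
    return sum(
        p * math.log((p + eps) / (q + eps))
        for p, q in zip(true_prev, pred_prev)
        if p > 0  # terms with p == 0 contribute nothing
    )


def mean_absolute_error(true_prev, pred_prev):
    """Average absolute difference between the two vectors."""
    return sum(abs(p - q) for p, q in zip(true_prev, pred_prev)) / len(true_prev)


def mean_squared_error(true_prev, pred_prev):
    """Average squared difference between the two vectors."""
    return sum((p - q) ** 2 for p, q in zip(true_prev, pred_prev)) / len(true_prev)


# Hypothetical prevalences for (Claim, Evidence, Outside).
true_prev = [0.2, 0.3, 0.5]
pred_prev = [0.1, 0.3, 0.6]

kld = kl_divergence(true_prev, pred_prev)
mae = mean_absolute_error(true_prev, pred_prev)
mse = mean_squared_error(true_prev, pred_prev)
print(f"KLD={kld:.4f} MAE={mae:.4f} MSE={mse:.4f}")
```

Comparing a hand-computed value like this against the logged loss for a single batch is one way to tell whether the loss function and the metric are measuring the same quantity.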

In [None]:
# Classification run (disabled), kept for reference.
# metrics = [torchmetrics.Accuracy(num_classes=3, average='macro', task="multiclass"),
#            torchmetrics.F1Score(num_classes=3, average='macro', task="multiclass")]
# history = train(device=device, train_dataset=train_ds, model=wrapper, eval_dataset=val_ds,
#                 generate_preds_labels_fn=classification_preds_labels_fn, metrics=metrics)
# _, test_metric = evaluate(device=device, model=model, eval_dataset=test_ds, eval_batch_size=32,
#                           gen_preds_labels_fn=classification_preds_labels_fn)
# Build one quantification model per loss function.
# The three output classes are Claim, Evidence and Outside.
num_quantify_classes = 3
losses = {"kldiv": kldiv_loss,
          "mae": mae_loss}
quantify_models = []
for loss_name, loss in losses.items():
    quantify_config = BertConfig.from_pretrained(
        model_name_or_path,
        num_labels=num_quantify_classes)
    quantify_model = BertForLabelDistribution.from_pretrained(
        model_name_or_path, config=quantify_config,
        loss_fn=loss, loss_fn_name=loss_name)
    quantify_model.to(device)
    quantify_models.append(quantify_model)

# Train each model on the training set, then evaluate it on the test set.
histories = {}
test_metrics = {}
for quantify_model in quantify_models:
    # Fresh metric objects per model so no state leaks between runs.
    qua_metrics = [torchmetrics.KLDivergence(log_prob=False, reduction="mean").to(device),
                   torchmetrics.MeanAbsoluteError().to(device),
                   torchmetrics.MeanSquaredError().to(device)]
    history = train(device=device, train_dataset=train_ds, model=quantify_model, eval_dataset=val_ds,
                    generate_preds_labels_fn=get_quantization_preds_labels, metrics=qua_metrics)
    histories[quantify_model.name] = history
    # Clear any state accumulated during training/validation before testing.
    for metric in qua_metrics:
        metric.reset()
    # NOTE: the commented-out call above unpacks two values from evaluate();
    # here the full return value is stored as-is.
    test_metric = evaluate(device, quantify_model, test_ds, eval_batch_size=32, metrics=qua_metrics,
                           gen_preds_labels_fn=get_quantization_preds_labels)
    test_metrics[quantify_model.name] = test_metric
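
The models above predict a distribution over three classes, while the TODO notes describe token-level tags with five BIO labels. As a minimal sketch of how a tagged abstract could be reduced to a 3-class prevalence vector, the code below strips the `B-`/`I-` prefixes and counts tokens per class. The label strings (`B-Claim`, `O`, and so on) and the helper names `collapse_bio` and `prevalence` are assumptions made for illustration; `get_quantization_preds_labels` in this notebook may work differently.

```python
from collections import Counter

CLASSES = ("Claim", "Evidence", "Outside")


def collapse_bio(tag):
    """Map a BIO tag such as 'B-Claim' or 'I-Evidence' to its class name."""
    if tag.startswith(("B-", "I-")):
        return tag[2:]
    return "Outside"  # assumption: every tag without a B-/I- prefix is Outside


def prevalence(tags):
    """Return the fraction of tokens per class, in CLASSES order."""
    counts = Counter(collapse_bio(t) for t in tags)
    total = len(tags)
    if total == 0:
        return [0.0] * len(CLASSES)
    return [counts[c] / total for c in CLASSES]


# Hypothetical tag sequence for one short abstract.
tags = ["B-Claim", "I-Claim", "O", "B-Evidence",
        "I-Evidence", "I-Evidence", "O", "O"]
prev = prevalence(tags)
print(dict(zip(CLASSES, prev)))
```

Counting after collapsing the prefixes keeps the aggregation independent of span boundaries, which is all a prevalence estimate needs.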