#TODO
1. Costruisci un classifier in pytorch (o in tensorflow se riesci), che prende in input un abstract, espresso tramite word embedding, e ritorna, per ogni parola nell'abstract, una label che indica il tipo di classe a cui quella appartiene (BIO tagging).
 * **PROBLEMA**: capire come funziona internamente il "Classify and Count" di QuaPy, perche' nel tuo caso devi creare una classe simile a  "Classify and Count" che pero', nella funzione "aggregate", riduce **"B-Claim, I-Claim, B-Evidence, I-Evidence e Outside"** in sole 3 classi **"Claim, Evidence e Outside"**.
2. Utilizza dei metodi presenti nel paper, che fanno *quantization* senza fare classification+counting
3. Paragona i risultati di questi due approcci, su diverse metriche.

# IMPORTANTE per QuaPy
1. Vedi come fare Model Selection usando la GridSearch di QuaPy
2. Vedi come testare i tuoi modelli, attraverso l'uso di  *natural prevalence protocol (NPP)* e  *artificial prevalence protocol (APP)*
3. Vedi come fare dei grafici, da cui e' possibile discutere i risultati
4. Vedi come poter usare QuaNet

`QuaPy provides an implementation of QuaNet, a deep-learning-based method for performing quantification on samples of textual documents, presented in [8].8 QuaNet processes as input a list of document embeddings (see below),
one for each unlabelled document along with their posterior probabilities generated by a probabilistic classifier. The
list is processed by a bidirectional LSTM that generates a sample embedding (i.e., a dense representation of the entire
sample), which is then concatenated with a vector of class prevalence estimates produced by an ensemble of simpler
quantification methods (CC, ACC, PCC, PACC, SLD). This vector is then transformed by a set of feed-forward layers,
followed by ReLU activations and dropout, to compute the final estimations.`

QuaNet thus requires a probabilistic classifier that can provide embedded representations of the inputs. QuaPy offers
a basic implementation of such a classifier, based on convolutional neural networks, that returns its next-to-last representation as the document embedding. The following is a working example showing how to index a textual dataset
(see Section 3) and how to instantiate QuaNet:
```python
1 import quapy as qp
2 from quapy.method.meta import QuaNet
3 from classification.neural import NeuralClassifierTrainer, CNNnet
4
5 qp.environ['SAMPLE_SIZE'] = 500
6
7 # load the kindle dataset as plain text, and convert words
8 # to numerical indexes
9 dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
10 qp.data.preprocessing.index(dataset, min_df=5, inplace=True)
11
12 cnn = CNNnet(dataset.vocabulary_size, dataset.n_classes)
13 learner = NeuralClassifierTrainer(cnn, device='cuda')
14 model = QuaNet(learner, sample_size=500, device='cuda')
```
QuaNet paper link : Esuli, A., Moreo, A., & Sebastiani, F. (2018, October). A recurrent neural network for sentiment quantification. In Proceedings of the 27th ACM International Conference on Information and Knowledge Management (pp. 1775-1778).
# IMPORTANTE per Preprocessing
1. Vedi quale Word Embedding utilizzare
2. Vedi se e come utilizzare Bert per generare degli abstract embeddings
3. Vedi se ha senso usare altri tipi, come SciBert (che e' stato allenato su papers scientifici)


PROBLEMA IMPORTANTISSIMO

In [1]:
!pip install transformers pytorch-crf==0.7.2 torchmetrics



In [2]:
import csv
import sys
import os
import collections
import random
from tqdm import trange,tqdm

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, Subset
from transformers import BasicTokenizer, BertPreTrainedModel, BertModel, BertConfig,  BertTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW
import numpy as np
from torch import nn
from torchcrf import CRF
import torchmetrics

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
data_dir = '/content/drive/My Drive/NLPProjectWork/data/neoplasm'
# Ensure the directory exists
if os.path.exists(data_dir):
    os.chdir(data_dir)
    print(f"Current directory: {os.getcwd()}")
else:
    print("Directory not found.")

Current directory: /content/drive/My Drive/NLPProjectWork/data/neoplasm


In [5]:
def set_seed(seed, n_gpu):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

In [6]:
n_gpu = torch.cuda.device_count()
set_seed(seed = 2, n_gpu = n_gpu)

# PREPROCESSING


In [7]:
class InputExample(object):
    def __init__(self, guid, text, labels=None):
        self.guid = guid
        self.text = text
        self.labels = labels


class InputFeatures(object):
    def __init__(self, input_ids, input_mask, segment_ids, label_ids, label_proba):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        self.label_proba = label_proba

In [8]:
class DataProcessor(object):
    def __init__(self):
        self.labels = ["X", "B-Claim", "I-Claim", "B-Premise", "I-Premise", 'O']
        self.label_map = self._create_label_map()
        self.replace_labels = {
            'B-MajorClaim': 'B-Claim',
            'I-MajorClaim': 'I-Claim',
        }

    def _create_label_map(self):
        label_map = collections.OrderedDict()
        for i, label in enumerate(self.labels):
            label_map[label] = i
        return label_map

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_conll(os.path.join(data_dir, "train.conll"), replace=self.replace_labels), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_conll(os.path.join(data_dir, "dev.conll"), replace=self.replace_labels), "dev")

    def get_test_examples(self, data_dir, setname="test.conll"):
        """See base class."""
        return self._create_examples(
            self.read_conll(os.path.join(data_dir, setname), replace=self.replace_labels), "test")

    def get_labels(self):
        """ See base class."""
        return self.labels

    def convert_labels_to_ids(self, labels):
        idx_list = []
        for label in labels:
            idx_list.append(self.label_map[label])
        return idx_list

    def convert_ids_to_labels(self, idx_list):
        labels_list = []
        for idx in idx_list:
            labels_list.append([key for key in self.label_map.keys() if self.label_map[key] == idx][0])
        return labels_list

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, str(i))
            text = line[0]
            labels = line[-1]
            examples.append(
                InputExample(guid=guid, text=text, labels=labels))
        return examples

    def convert_examples_to_features(self, examples, max_seq_length, tokenizer,
                                     cls_token='[CLS]',
                                     sep_token='[SEP]'):
        """Loads a data file into a list of `InputBatch`s."""

        features = []
        for (ex_index, example) in enumerate(examples):
            tokens_a, labels = tokenizer.tokenize_with_label_extension(example.text, example.labels,
                                                                       copy_previous_label=True)

            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[:(max_seq_length - 2)]
                labels = labels[:(max_seq_length - 2)]
            labels = ["X"] + labels + ["X"]

            tokens = [cls_token] + tokens_a + [sep_token]
            segment_ids = [0] * len(tokens)
            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)
            # Zero-pad up to the sequence length.
            padding = [0] * (max_seq_length - len(input_ids))
            input_ids += padding
            input_mask += padding
            segment_ids += padding
            label_ids = self.convert_labels_to_ids(labels)
            # Se il mask e' None allora non fare nulla
            # Altrimenti maska sta label
            label_proba = self.compute_label_probadist(label_ids)
            # leva il primo e l'ultimo elemento
            # assert(non contiene 0 all'interno)
            # nel caso delle predizioni se contiene 0 che devo fare? Direi di skippare tutti gli zeri dal calcolo della proba di dist

            label_ids += padding

            assert len(label_ids) == max_seq_length
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_ids=label_ids,
                              label_proba=label_proba))
        return features

    def compute_label_probadist(self, label_ids, mask=None):
        label_ids_np = np.array(label_ids)
        if mask is not None:
            mask_np = mask.cpu().detach().numpy()
            label_ids_np = label_ids_np[mask_np == 1]
        label_ids_np = label_ids_np[1:-1]
        label_ids_np = label_ids_np[label_ids_np != 0]
        if len(label_ids_np) == 0:
            return [0.0, 0.0, 0.0]
        assert (0 not in label_ids_np)
        map = {self.label_map["B-Claim"]: 0,
               self.label_map["I-Claim"]: 0,
               self.label_map["B-Premise"]: 1,
               self.label_map["I-Premise"]: 1,
               self.label_map["O"]: 2}
        remapped_label_ids = [map[label] for label in label_ids_np]
        counts = np.bincount(remapped_label_ids, minlength=3)
        probabilities = counts / len(remapped_label_ids)
        return probabilities

    @classmethod
    def features_to_dataset(cls, feature_list):
        all_input_ids = torch.tensor([f.input_ids for f in feature_list], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in feature_list], dtype=torch.uint8)
        all_segment_ids = torch.tensor([f.segment_ids for f in feature_list], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_ids for f in feature_list], dtype=torch.long)
        all_label_probas = torch.tensor([np.array(f.label_proba) for f in feature_list])
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_label_probas)
        return dataset

    @classmethod
    def read_conll(cls, input_file, token_number=0, token_column=1, label_column=4, replace=None):
        """Reads a conll type file."""
        with open(input_file, "r", encoding='utf-8') as f:
            lines = f.readlines()
            sentences = []
            tokenizer = BasicTokenizer()
            sent_tokens = []
            sent_labels = []

            for idx, line in enumerate(lines):

                line = line.split('\t')
                # Skip the lines in which there is /n
                if len(line) < 2:
                    continue

                # Controllare se sono all'ID numero 1, assicurandomi nel mentre che NON sia la prima iterazione
                if idx != 0 and int(line[token_number]) == 0:
                    assert len(sent_tokens) == len(sent_labels)
                    if replace:
                        sent_labels = [replace[label] if label in replace.keys() else label for label in sent_labels]
                    sentences.append([' '.join(sent_tokens), sent_labels])
                    sent_tokens = []
                    sent_labels = []

                token = line[token_column]
                label = line[label_column].replace('\n', '')
                tokenized = tokenizer.tokenize(token)

                if len(tokenized) > 1:
                    for i in range(len(tokenized)):
                        sent_tokens.append(tokenized[i])
                        sent_labels.append(label)
                else:
                    sent_tokens.append(tokenized[0])
                    sent_labels.append(label)
            if sent_tokens != []:
                assert len(sent_tokens) == len(sent_labels)
                if replace:
                    sent_labels = [replace[label] if label in replace.keys() else label for label in sent_labels]
                sentences.append([' '.join(sent_tokens), sent_labels])
        return sentences

    def load_examples(self, data_dir, max_seq_length, tokenizer, evaluate=False, isval=False):
        if evaluate:
            examples = self.get_test_examples(data_dir)
        elif isval:
            examples = self.get_dev_examples(data_dir)
        else:
            examples = self.get_train_examples(data_dir)
        features = self.convert_examples_to_features(examples, max_seq_length=max_seq_length, tokenizer=tokenizer)
        dataset = self.features_to_dataset(features)
        return dataset

In [9]:
class ExtendedBertTokenizer:
    """Extended tokenizer that wraps a base tokenizer."""

    def __init__(self, base_tokenizer):
        self.tokenizer = base_tokenizer

    def tokenize_with_label_extension(self, text, labels, copy_previous_label=False, extension_label='X'):
        tok_text = self.tokenizer.tokenize(text)
        for i in range(len(tok_text)):
            if '##' in tok_text[i]:
                if copy_previous_label:
                    labels.insert(i, labels[i - 1])
                else:
                    labels.insert(i, extension_label)
        return tok_text, labels

    def convert_tokens_to_ids(self, tokens):
      return self.tokenizer.convert_tokens_to_ids(tokens)

# MODELS

In [10]:
num_labels = 3  # Let's say we have 3 labels (0, 1, 2)
batch_size = 1
seq_length = 5

emissions = torch.tensor([[[ 0.3427, -0.6740, -0.6637],
         [-0.6571,  0.0615, -0.1830],
         [ 0.0216,  0.1059, -0.7968],
         [ 0.3336,  0.0276,  0.5800],
         [ 0.3955, -0.4380, -0.3168]]])
emissions.shape

torch.Size([1, 5, 3])

In [11]:
# Instantiate the CRF layer
crf = CRF(num_labels, batch_first=True)

# Create a tensor for labels (ground truth)
# Shape: (batch_size, seq_length)
# Using -1 for padding (not a real label)
labels = torch.tensor([[0, 1, -1, -1, -1],
                       ])

In [12]:
loss = -crf(emissions, labels, reduction='mean')  # Negative log likelihood loss
print("CRF Loss:", loss.item())

# Decode the predicted labels (best path)
decoded_labels = crf.decode(emissions)
print("Decoded Labels:", decoded_labels)

CRF Loss: 5.336999893188477
Decoded Labels: [[0, 1, 1, 2, 0]]


  score = torch.where(mask[i].unsqueeze(1), next_score, score)


In [13]:
class BertForSequenceTagging(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.rnn = nn.GRU(config.hidden_size, config.hidden_size, batch_first=True, bidirectional=True)
        self.crf = CRF(config.num_labels, batch_first=True)
        self.classifier = nn.Linear(2 * config.hidden_size, config.num_labels)
        self.custom_init_weights()
        self.name = "seq_tag_model"

    def custom_init_weights(self):
        # Initialize weights of GRU
        for name, param in self.rnn.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param, gain=0.01)
            elif 'bias' in name:
                nn.init.zeros_(param)  # Initialize biases to zero

        # Initialize weights of CRF transitions
        nn.init.normal_(self.crf.start_transitions, mean=0, std=0.1)
        nn.init.normal_(self.crf.end_transitions, mean=0, std=0.1)
        nn.init.normal_(self.crf.transitions, mean=0, std=0.1)

        # Initialize weights of the classifier
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)
    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            labels=None,
            labels_proba=None
    ):

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        sequence_output = outputs[0]

        print(f"Sequence output (BERT): {sequence_output.min()}, {sequence_output.max()}")
        rnn_out, _ = self.rnn(sequence_output)
        emissions = self.classifier(rnn_out)

        mask = attention_mask.bool()
        log_likelihood = self.crf(emissions, labels, mask=mask)
        path = self.crf.decode(emissions)
        path = torch.LongTensor(path)
        return (-1 * log_likelihood, emissions, path, None)

In [38]:
class BertForSequenceTaggingAndQuantification(BertPreTrainedModel):
    def __init__(self, config, distribution_loss_fn, loss_fn_name):
        super().__init__(config, lambda_tagging=0.0001, lambda_distribution=1.0)
        self.num_labels = config.num_labels  # 6 labels for sequence tagging (Claim, Premise, etc.)
        self.num_proba_labels = 3  # 3 labels for probability distribution (Claim, Premise, Not Relevant)

        self.bert = BertModel(config)
        self.rnn = nn.GRU(config.hidden_size, config.hidden_size, batch_first=True, bidirectional=True)
        self.crf = CRF(config.num_labels, batch_first=True)
        self.classifier = nn.Linear(2 * config.hidden_size, config.num_labels)  # For sequence tagging

        # For the quantification task
        self.projection_head = nn.Sequential(
            nn.Linear(config.num_labels * config.max_length, 128),  # Concatenate emissions across sequence
            nn.ReLU(),                                                  # Non-linear activation
            nn.Linear(128, self.num_proba_labels)  # Output 3 classes (Claim, Premise, Not Relevant)
        )
        self.distribution_loss_fn = distribution_loss_fn
        self.lambda_tagging = 0.0001
        self.lambda_distribution = 1.0

        self.custom_init_weights()
        self.name = f"seq_tag_quant_model_{loss_fn_name}"

    def custom_init_weights(self):
        # Initialize weights of GRU
        for name, param in self.rnn.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param, gain=0.01)
            elif 'bias' in name:
                nn.init.zeros_(param)  # Initialize biases to zero

        # Initialize weights of CRF transitions
        nn.init.normal_(self.crf.start_transitions, mean=0, std=0.1)
        nn.init.normal_(self.crf.end_transitions, mean=0, std=0.1)
        nn.init.normal_(self.crf.transitions, mean=0, std=0.1)

        # Initialize weights of the classifier
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

        # Initialize weights of the projection head
        for layer in self.projection_head:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            labels=None,
            labels_proba=None  # True distribution of Claim, Premise, Not Relevant
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        sequence_output = outputs[0]  # Shape: (batch_size, sequence_length, hidden_size)

        rnn_out, _ = self.rnn(sequence_output)  # Shape: (batch_size, sequence_length, 2 * hidden_size)
        emissions = self.classifier(rnn_out)    # Shape: (batch_size, sequence_length, num_labels)

        mask = attention_mask.bool()
        log_likelihood = self.crf(emissions, labels, mask=mask)  # Sequence tagging loss
        path = self.crf.decode(emissions)  # CRF-decoded sequence
        path = torch.LongTensor(path)

        # Part 2: Quantification (Predict probability distribution)

        # 1. Flatten emissions along the sequence dimension (this is where we avoid pooling)
        # Emissions shape: (batch_size, sequence_length, num_labels)
        # After flattening: (batch_size, sequence_length * num_labels)
        flattened_emissions = emissions.view(emissions.size(0), -1)  # Concatenates emissions along sequence

        # 2. Non-linear projection to 3-class probability distribution
        predicted_distribution = self.projection_head(flattened_emissions)  # Shape: (batch_size, num_proba_labels)

        # 3. Apply softmax to get a valid probability distribution
        predicted_distribution = nn.Softmax(dim=-1)(predicted_distribution)

        # Loss for the distribution prediction (e.g., KLDiv or MAE)
        distribution_loss = None
        # Use Kullback-Leibler divergence or Mean Absolute Error
        distribution_loss = distribution_loss_fn(predicted_distribution.log(), labels_proba)

        total_loss =self.lambda_tagging* -log_likelihood + self.lambda_distribution * distribution_loss  # Sequence tagging loss

        return (total_loss, emissions, path, predicted_distribution)

In [15]:
class ClassificationAndCounting:
    def __init__(self, learner, processor: DataProcessor):
        self.learner = learner
        self.processor = processor
        self.name = "bert_classify_and_count"

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            labels=None,
            labels_proba=None
    ):

        loss, emissions, path, _ = self.learner(input_ids, attention_mask, token_type_ids, labels, labels_proba)

        proba_dists = self.aggregate(path, attention_mask)
        return loss, emissions, path, proba_dists

    def aggregate(self, predictions, attention_mask):
        all_probabilities = []
        for sequence, mask in zip(predictions, attention_mask):
            label_proba = self.processor.compute_label_probadist(sequence, mask=mask)
            all_probabilities.append(label_proba)
        return np.array(all_probabilities)

    # Wrapper methods mimicking PyTorch model behavior
    def train(self, mode=True):
        self.learner.train(mode)

    def eval(self):
        self.learner.eval()

    def state_dict(self, *args, **kwargs):
        return self.learner.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        return self.learner.load_state_dict(*args, **kwargs)

    def to(self, *args, **kwargs):
        self.learner.to(*args, **kwargs)

    def parameters(self):
        return self.learner.parameters()

    def named_parameters(self):
        return self.learner.named_parameters()

    def zero_grad(self):
        self.learner.zero_grad()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

In [16]:
class BertForLabelDistribution(BertPreTrainedModel):
    def __init__(self, config, loss_fn, loss_fn_name):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.rnn = nn.GRU(config.hidden_size, config.hidden_size, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * config.hidden_size, config.num_labels)
        self.softmax = nn.Softmax(dim=-1)
        self.init_weights()
        self.loss_fn = loss_fn
        self.name = f"bert_quantify_{loss_fn_name}"
        self.custom_init_weights()

    def custom_init_weights(self):
        # Initialize weights of GRU
        for name, param in self.rnn.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param, gain=0.01)
            elif 'bias' in name:
                nn.init.zeros_(param)  # Initialize biases to zero

        # Initialize weights of the classifier
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            labels=None,
            labels_proba=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        sequence_output = outputs[0]  # [batch_size, seq_len, hidden_size]
        rnn_out, _ = self.rnn(sequence_output)  # [batch_size, seq_len, 2*hidden_size]
        emissions = self.classifier(rnn_out)  # [batch_size, seq_len, num_labels]
        cls_emissions = emissions[:, 0, :]  # [batch_size, num_labels]
        # posos calcolare la loss e dopo la softmax
        cls_probs = self.softmax(cls_emissions)  # [batch_size, num_labels]
        loss = self.loss_fn(cls_probs, labels_proba)
        return loss, emissions, None, cls_probs

# TRAIN/EVALUATE

In [17]:
def get_quantization_preds_labels(inputs, outputs):
    _, _, _, pred_proba = outputs
    pred_proba = torch.tensor(pred_proba).to(device)
    labels_proba = inputs["labels_proba"]
    return pred_proba,labels_proba

def get_classf_preds_labels(inputs, outputs):
    loss, emissions, path, _ = outputs
    batch_logits = path.detach().cpu().numpy().flatten()
    batch_labels = inputs["labels"].detach().cpu().numpy().flatten()
    attention_mask = inputs["attention_mask"].detach().cpu().numpy().flatten()
    valid_batch_logits = batch_logits[attention_mask == 1]
    valid_batch_labels = batch_labels[attention_mask == 1]
    return torch.tensor(valid_batch_logits), torch.tensor(valid_batch_labels)

In [18]:
def mae_loss(pred_probas, label_probas):
    loss = nn.L1Loss(reduction="mean")
    return loss(pred_probas, label_probas)

def kldiv_loss(pred_probas, label_probas):
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    pred_probas_log = torch.log(pred_probas)
    output = kl_loss(pred_probas_log, label_probas)
    return output

In [19]:
def process_batch(batch, device, model, gen_preds_labels_fn, eval=False):
    model.eval() if eval else model.train()
    batch = tuple(t.to(device) for t in batch)
    inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2], "labels": batch[3],
              "labels_proba": batch[4]}
    with torch.no_grad() if eval else torch.enable_grad():
        outputs = model(**inputs)
    preds, labels = gen_preds_labels_fn(inputs, outputs)
    return outputs, preds, labels

In [20]:
def evaluate(device, model, eval_dataset, eval_batch_size, metrics, gen_preds_labels_fn):
    eval_dataloader = DataLoader(eval_dataset, batch_size=eval_batch_size)
    epoch_loss_sum = 0.0
    for metric in metrics:
        metric.reset()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            outputs, preds, labels = process_batch(batch, device, model, gen_preds_labels_fn, eval=True)
            loss = outputs[0]
            epoch_loss_sum += loss.item()
            for metric in metrics:
                if metric.__class__.__name__ == "KLDivergence":
                      labels = labels.clamp(min=1e-9)
                      preds = preds.clamp(min=1e-9)
                metric.update(preds, labels)
    eval_loss = epoch_loss_sum / len(eval_dataloader)
    eval_metrics = {metric.__class__.__name__: metric.compute().item() for metric in metrics}
    return eval_loss, eval_metrics

In [21]:
def update_metrics(history, metrics, metric_type, epoch):
    for metric_name, metric_value in metrics.items():
        if epoch == 0:
            history[metric_type][metric_name] = [metric_value]
        else:
            history[metric_type][metric_name].append(metric_value)
        print(f"{metric_name}: {metric_value:.4f}")
def train(
        device,
        train_dataset,
        model,
        eval_dataset,
        generate_preds_labels_fn,
        metrics,
        num_train_epochs=30,
        train_batch_size=2,
        eval_batch_size=32,
        weight_decay=0.3,
        learning_rate=2e-5,
        adam_epsilon=1e-8,
        warmup_steps=5,
        max_grad_norm = 3.0
):
    history = {
        'train_loss': [],
        'train_metrics': {},
        'eval_loss': [],
        'eval_metrics': {}
    }
    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size)
    num_steps_per_epoch = len(train_dataloader)
    t_total = num_steps_per_epoch * num_train_epochs
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
    )
    epochs_trained = 0

    train_iterator = trange(epochs_trained, int(num_train_epochs), desc="Epoch")
    print(f"TRAINING MODEL {model.name}")
    for epoch in train_iterator:
        epoch_loss_sum = 0.0
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            optimizer.zero_grad()
            outputs, preds, labels = process_batch(batch, device, model, generate_preds_labels_fn, eval=False)
            loss = outputs[0]
            print(f"SINGLE LOSS: {loss}")
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            epoch_loss_sum += loss.item()
            for metric in metrics:
                if metric.__class__.__name__ == "KLDivergence":
                      labels = labels.clamp(min=1e-9)
                      preds = preds.clamp(min=1e-9)
                metric.update(preds, labels)
            optimizer.step()
            scheduler.step()

        print(f"CUMULATIVE LOSSES: {epoch_loss_sum}")
        epoch_loss = epoch_loss_sum / num_steps_per_epoch
        history['train_loss'].append(epoch_loss)
        epoch_metrics = {metric.__class__.__name__: metric.compute().item() for metric in metrics}
        eval_loss, eval_metrics = evaluate(device, model, eval_dataset, eval_batch_size, metrics,
                                           generate_preds_labels_fn)
        history['eval_loss'].append(eval_loss)
        print(f"Epoch {epoch + 1}/{num_train_epochs}")
        print(f"Train Loss: {epoch_loss:.4f}")
        update_metrics(history, epoch_metrics, 'train_metrics', epoch)
        print(f"Eval Loss: {eval_loss:.4f}")
        update_metrics(history, eval_metrics, 'eval_metrics', epoch)
    return history

# RESULTS

In [22]:
model_name_or_path = 'allenai/scibert_scivocab_uncased'
do_lower_case = True
base_tokenizer = BertTokenizer.from_pretrained(model_name_or_path, do_lower_case=do_lower_case)
extended_tokenizer = ExtendedBertTokenizer(base_tokenizer)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataprocessor = DataProcessor()
max_seq_length = 510

In [24]:
train_ds = dataprocessor.load_examples(data_dir=data_dir, max_seq_length=max_seq_length,
                        tokenizer=extended_tokenizer)
val_ds = dataprocessor.load_examples( data_dir=data_dir, max_seq_length=max_seq_length,
                        tokenizer=extended_tokenizer, isval=True)
test_ds = dataprocessor.load_examples(data_dir=data_dir, max_seq_length=max_seq_length,
                        tokenizer=extended_tokenizer, evaluate=True)

  all_label_probas = torch.tensor([np.array(f.label_proba) for f in feature_list])


# CREATE CLASSIFICATION+COUNT MODEL

In [25]:
num_class_labels = len(dataprocessor.get_labels())

"""
seq_tagging_config = BertConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_class_labels
)
seq_tagging_model = BertForSequenceTagging.from_pretrained(
    model_name_or_path,
    config=seq_tagging_config
)
multiclass_metrics = [torchmetrics.Accuracy(num_classes=6, average='macro', task = "multiclass").to(device), torchmetrics.F1Score(num_classes=6, average='macro', task = "multiclass").to(device)]
history = train(device=device, train_dataset=train_ds, model=seq_tagging_model, eval_dataset=val_ds, generate_preds_labels_fn=get_classf_preds_labels, metrics=multiclass_metrics)
"""

'\nseq_tagging_config = BertConfig.from_pretrained(\n    model_name_or_path,\n    num_labels=num_class_labels\n)\nseq_tagging_model = BertForSequenceTagging.from_pretrained(\n    model_name_or_path,\n    config=seq_tagging_config\n)\nmulticlass_metrics = [torchmetrics.Accuracy(num_classes=6, average=\'macro\', task = "multiclass").to(device), torchmetrics.F1Score(num_classes=6, average=\'macro\', task = "multiclass").to(device)]\nhistory = train(device=device, train_dataset=train_ds, model=seq_tagging_model, eval_dataset=val_ds, generate_preds_labels_fn=get_classf_preds_labels, metrics=multiclass_metrics)\n'

In [22]:
classify_count_config = BertConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_class_labels)
classify_count_base_model = BertForSequenceTagging.from_pretrained(
    model_name_or_path,
    config=classify_count_config
)
classify_and_count_model = ClassificationAndCounting(learner=classify_count_base_model, processor=dataprocessor)
classify_and_count_model.to(device)
dist_metrics = [torchmetrics.KLDivergence(log_prob=False, reduction="mean").to(device),
                torchmetrics.MeanAbsoluteError().to(device),
                torchmetrics.MeanSquaredError().to(device)]
history = train(device=device, train_dataset=train_ds, model=classify_and_count_model, eval_dataset=val_ds, generate_preds_labels_fn=get_quantization_preds_labels, metrics=dist_metrics)

Some weights of BertForSequenceTagging were not initialized from the model checkpoint at allenai/scibert_scivocab_uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'crf.end_transitions', 'crf.start_transitions', 'crf.transitions', 'rnn.bias_hh_l0', 'rnn.bias_hh_l0_reverse', 'rnn.bias_ih_l0', 'rnn.bias_ih_l0_reverse', 'rnn.weight_hh_l0', 'rnn.weight_hh_l0_reverse', 'rnn.weight_ih_l0', 'rnn.weight_ih_l0_reverse']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

TRAINING MODEL bert_classify_and_count



Iteration:   0%|          | 0/175 [00:00<?, ?it/s][A

Sequence output (BERT): -15.10572624206543, 16.894390106201172


  score = torch.where(mask[i].unsqueeze(1), next_score, score)
  return torch.tensor(pred_proba).to(device),labels_proba


SINGLE LOSS: 1.5419175033895463e+35



Iteration:   1%|          | 1/175 [00:01<05:26,  1.88s/it][A

Sequence output (BERT): -13.15503215789795, 16.602663040161133
SINGLE LOSS: 1.2212381484161545e+35



Iteration:   1%|          | 2/175 [00:02<03:37,  1.26s/it][A

Sequence output (BERT): -17.15949821472168, 16.852718353271484
SINGLE LOSS: 1.3150639018032658e+35



Iteration:   2%|▏         | 3/175 [00:03<03:01,  1.05s/it][A

Sequence output (BERT): -16.76675033569336, 16.991296768188477
SINGLE LOSS: 1.245897715928313e+35



Iteration:   2%|▏         | 4/175 [00:04<02:43,  1.05it/s][A

Sequence output (BERT): -14.336128234863281, 15.983817100524902
SINGLE LOSS: 1.4710946566364202e+35


Iteration:   2%|▏         | 4/175 [00:04<03:33,  1.25s/it]
Epoch:   0%|          | 0/30 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [None]:
import os
import torch
import pickle

# Create directory if it doesn't exist
save_dir = "content/save"
classify_and_counnt_dir = os.path.join(save_dir, "classify_and_count")
os.makedirs(classify_and_counnt_dir, exist_ok=True)

# Save model's state_dict
torch.save(classify_and_count_model.state_dict(), os.path.join(classify_and_counnt_dir, 'classify_and_count_model.pth'))

# Save the configuration (if needed)
classify_count_config.save_pretrained(classify_and_counnt_dir)

In [None]:
# Load the configuration
classify_count_config = BertConfig.from_pretrained(classify_and_counnt_dir)

# Load the model and its state_dict
classify_count_base_model = BertForSequenceTagging.from_pretrained(
    model_name_or_path,  # Replace this with the same model path you used for saving
    config=classify_count_config
)

# Initialize the ClassificationAndCounting model
classify_and_count_model = ClassificationAndCounting(learner=classify_count_base_model, processor=dataprocessor)

# Load the state dictionary
classify_and_count_model.load_state_dict(torch.load(os.path.join(classify_and_counnt_dir, 'classify_and_count_model.pth')))

# Move model to the appropriate device (CPU/GPU)
classify_and_count_model.to(device)


How would you use THIS idea of predicting each example, and then using this prediction to then compute a distribution, BUT in a different scenario:
In a scenario in which you have a dataset composed of abstracts, and where I want to predict, for each abstract, a probability distribution that tells me how much of the content of the abstract is a Claim, how much is a Premise and how much is Not Relevant. I thought that maybe a solution similar to QuaNet, would be to predict, for each token in the abstract a class (so we are at the beginning solving a sequence tagging task), and then to use that prediction, concatenated with all the other predictions, to predict a probability distribution. So I imagine the loss to be a sum of the single losses of the sequence tagging task, plus the loss of the kldivergence (or the mae) of the predicted distribution and the true distribution

In [None]:
eval_loss, eval_metrics = evaluate(device, classify_and_count_model, test_ds, eval_batch_size = 32, metrics = dist_metrics,
                                           gen_preds_labels_fn  = get_quantization_preds_labels)

In [None]:
eval_loss

In [None]:
eval_metrics

In [88]:
# metrics = [torchmetrics.Accuracy(num_classes=3, average='macro', task = "multiclass"),
#      torchmetrics.F1Score(num_classes=3, average='macro', task = "multiclass")]
# history = train(device=device, train_dataset=train_ds, model=wrapper, eval_dataset=val_ds,
#                generate_preds_labels_fn=classification_preds_labels_fn, metrics=metrics)
# _ , test_metric = evaluate(device = device, model = model,eval_dataset= test_ds, eval_batch_size= 32,
# gen_preds_labels_fn=classification_preds_labels_fn)
losses = {"kldiv": kldiv_loss,
          "mae" : mae_loss}
quantify_models = []
for loss_name, loss in losses.items():
    num_quantify_classes = 3
    quantify_config = BertConfig.from_pretrained(
        model_name_or_path,
        num_labels=num_quantify_classes)
    quantify_model = BertForLabelDistribution.from_pretrained(model_name_or_path, config=quantify_config, loss_fn=loss, loss_fn_name =loss_name )
    quantify_model.to(device)
    quantify_models.append(quantify_model)

histories = {}
test_metrics = {}
for quantify_model in quantify_models:
    qua_metrics = [torchmetrics.KLDivergence(log_prob=False, reduction="mean").to(device),
                torchmetrics.MeanAbsoluteError().to(device),
                torchmetrics.MeanSquaredError().to(device)]
    history = train(device=device, train_dataset=train_ds, model=quantify_model, eval_dataset=val_ds,
                    generate_preds_labels_fn=get_quantization_preds_labels, metrics=qua_metrics, train_batch_size = 5)
    histories[quantify_model.name] = history
    test_metric = evaluate(device, quantify_model, test_ds, eval_batch_size=32, metrics=qua_metrics,
                            gen_preds_labels_fn=get_quantization_preds_labels)
    test_metrics[quantify_model.name] = test_metric

KeyboardInterrupt: 

In [40]:
classify_count_config = BertConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_class_labels)
classify_count_config.max_length = max_seq_length
sequence_tag_quant_model = BertForSequenceTaggingAndQuantification.from_pretrained(
    model_name_or_path,
    config=classify_count_config,
    distribution_loss_fn = kldiv_loss,
    loss_fn_name = "kldiv"
)
sequence_tag_quant_model.to(device)
dist_metrics = [torchmetrics.KLDivergence(log_prob=False, reduction="mean").to(device),
                torchmetrics.MeanAbsoluteError().to(device),
                torchmetrics.MeanSquaredError().to(device)]
history = train(device=device, train_dataset=train_ds, model=sequence_tag_quant_model, eval_dataset=val_ds, generate_preds_labels_fn=get_quantization_preds_labels, metrics=dist_metrics)

Some weights of BertForSequenceTaggingAndQuantification were not initialized from the model checkpoint at allenai/scibert_scivocab_uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'crf.end_transitions', 'crf.start_transitions', 'crf.transitions', 'projection_head.0.bias', 'projection_head.0.weight', 'projection_head.2.bias', 'projection_head.2.weight', 'rnn.bias_hh_l0', 'rnn.bias_hh_l0_reverse', 'rnn.bias_ih_l0', 'rnn.bias_ih_l0_reverse', 'rnn.weight_hh_l0', 'rnn.weight_hh_l0_reverse', 'rnn.weight_ih_l0', 'rnn.weight_ih_l0_reverse']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

TRAINING MODEL seq_tag_quant_model_kldiv



Iteration:   0%|          | 0/175 [00:06<?, ?it/s]
Epoch:   0%|          | 0/30 [00:06<?, ?it/s]


NameError: name 'distribution_loss_fn' is not defined