## Hierarchical classification using Local Classification per Parent Node technique
Here we use the same approach as in /solutions/custom_tokens/xlm_roberta_with_classification_start_span_token.ipynb to train the first classifier to predict the first class, and we train extra three classifiers for each of those classes that will provide us fine-grained classification.

In [1]:
from google.colab import drive
drive.mount('/content/drive')
sub1 = 'drive/My Drive/Colab Notebooks/semeval_data/subtask1.parquet'
print(sub1)

# from pathlib import Path
# wd = Path.cwd()
# wd = wd.parent.parent
# wd = wd / 'merged_data'
# sub1 = str(wd) + '/subtask1.parquet'
# print(sub1)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
drive/My Drive/Colab Notebooks/semeval_data/subtask1.parquet


In [2]:
import pandas as pd
df = pd.read_parquet(sub1)

In [3]:
import re
def labelNum(row):
    if row['class1'] == 'Antagonist':
        return int(0)
    if row['class1'] == 'Innocent':
        return int(1)
    if row['class1'] == 'Protagonist':
        return int(2)
def cleanText(row):
    text = str(row['text'])
    #text = re.sub(r'[^\w\s]', ' ', text)
    text = text.replace('\n',' ').replace('  ', ' ')
    return text
df['label1'] = df.apply(labelNum,axis=1)
df['input'] = df.apply(cleanText,axis=1)

In [4]:
def labelNum2(row):
    labels2 = [0 for _ in range(12)]
    if row['label1'] == 2:
        #labels2 = [0 for _ in range(6)]
        if 'Guardian' in row['classes2']:
            labels2[0] = 1
        if 'Martyr' in row['classes2']:
            labels2[1] = 1
        if 'Peacemaker' in row['classes2']:
            labels2[2] = 1
        if 'Rebel' in row['classes2']:
            labels2[3] = 1
        if 'Underdog' in row['classes2']:
            labels2[4] = 1
        if 'Virtuous' in row['classes2']:
            labels2[5] = 1
    elif row['label1'] == 0:
        #labels2 = [0 for _ in range(12)]
        if 'Instigator' in row['classes2']:
           labels2[0] = 1
        if 'Conspirator' in row['classes2']:
            labels2[1] = 1
        if 'Tyrant' in row['classes2']:
            labels2[2] = 1
        if  'Foreign Adversary' in row['classes2']:
            labels2[3] = 1
        if 'Traitor' in row['classes2']:
            labels2[4] = 1
        if 'Spy' in row['classes2']:
            labels2[5] = 1
        if 'Saboteur' in row['classes2']:
            labels2[6] = 1
        if 'Corrupt' in row['classes2']:
            labels2[7] = 1
        if 'Incompetent' in row['classes2']:
            labels2[8] = 1
        if 'Terrorist' in row['classes2']:
            labels2[9] = 1
        if 'Deceiver' in row['classes2']:
            labels2[10] = 1
        if 'Bigot' in row['classes2']:
            labels2[11] = 1
    elif row['label1'] == 1:
        #labels2 = [0 for _ in range(4)]
        if 'Forgotten' in row['classes2']:
            labels2[0] = 1
        if 'Exploited' in row['classes2']:
            labels2[1] = 1
        if 'Victim' in row['classes2']:
            labels2[2] = 1
        if 'Scapegoat' in row['classes2']:
            labels2[3] = 1
    return labels2

df['label2'] = df.apply(labelNum2, axis=1)

In [5]:
def find_all_substring_start_end(text, substring):
    # Use re.finditer to find all occurrences of the substring in the text
    matches = re.finditer(re.escape(substring), text)

    # Collect the start and end indices of all matches
    positions = [(match.start(), match.end()) for match in matches]

    return positions
def adjust_start_end(row):
    org_text,cl_text,start,end,entity = str(row['text']),str(row['input']),int(row['start']),int(row['end']),str(row['entity'])
    ss1 = find_all_substring_start_end(org_text,entity)
    ss2 = find_all_substring_start_end(cl_text,entity)
    #print(ss1,ss2)
    #print(row['text'][start:end])
    a = 0
    for i in range(len(ss1)):
        if abs((ss1[i][0] - start) + (ss1[i][1] - end) ) <= 2:
            a = i
            break
    if org_text[ss1[a][0]:ss1[a][1]] != cl_text[ss2[a][0]:ss2[a][1]]:
        print("ERROR!")
    return ss2[a][0],ss2[a][1]
df['new_start_end'] = df.apply(adjust_start_end,axis=1)
print(df.loc[0])

lang                                                            BG
art_name                                                BG_670.txt
entity                                                       Запад
start                                                          152
end                                                            156
class1                                                  Antagonist
classes2              [Conspirator, Instigator, Foreign Adversary]
text             Опитът на колективния Запад да „обезкърви Руси...
label1                                                           0
input            Опитът на колективния Запад да „обезкърви Руси...
label2                        [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
new_start_end                                           (151, 156)
Name: 0, dtype: object


In [6]:
def addTokensToInput(row):
    inp = row['input']
    start,end = row['new_start_end']
    #print(start,end)
    start = int(start)
    end = int(end)
    token_input = inp[:start] + "[SPAN_START] " + inp[start:end] + " [SPAN_END]" + inp[end:]
    return token_input

df['span_input'] = df.apply(addTokensToInput,axis=1)

In [7]:
def upStartEnd(row):
    start,end = row['new_start_end']
    start += len("[SPAN_START] ")
    end += len("[SPAN_START] ")
    return start,end

df['new_start_end'] = df.apply(upStartEnd,axis = 1)

In [8]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizerFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = XLMRobertaForSequenceClassification.from_pretrained("xlm-roberta-base", num_labels=3).to(device)
tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base")

def preprocess_function(examples):
    return tokenizer(examples['span_input'], padding=True, truncation=True,max_length=8192,return_offsets_mapping=True)

Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
extraTokens = {
    "additional_special_tokens": ["[SPAN_START]", "[SPAN_END]"]
}
num_added_toks = tokenizer.add_special_tokens(extraTokens)
model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(250004, 768, padding_idx=1)

In [10]:
data = df.loc[ : , ['span_input', 'label1', 'label2', 'new_start_end', 'entity']]

In [11]:
data

Unnamed: 0,span_input,label1,label2,new_start_end,entity
0,Опитът на колективния Запад да „обезкърви Руси...,0,"[1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]","(164, 169)",Запад
1,Опитът на колективния Запад да „обезкърви Руси...,0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","(541, 544)",САЩ
2,Опитът на колективния Запад да „обезкърви Руси...,0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","(546, 550)",НАТО
3,Опитът на колективния Запад да „обезкърви Руси...,0,"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]","(589, 596)",Украйна
4,Опитът на колективния Запад да „обезкърви Руси...,1,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","(644, 661)",украински войници
...,...,...,...,...,...
2897,Медведев: Даже в случае признания поражения Ки...,2,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","(569, 576)",Россией
2898,Медведев: Даже в случае признания поражения Ки...,2,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","(1015, 1021)",Москва
2899,Медведев: Даже в случае признания поражения Ки...,0,"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]","(1308, 1312)",НАТО
2900,Медведев: Даже в случае признания поражения [S...,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]","(57, 63)",Киевом


In [12]:
data['tokenized']=data.apply(preprocess_function,axis=1)

In [13]:
def indexes(row):
    off_mask = row['tokenized']['offset_mapping']
    start,end = row['new_start_end'][0],row['new_start_end'][1]
    inds = list()
    for p in range(len(off_mask)):
        if off_mask[p][0] >= start and off_mask[p][1] <= end:
            if p != len(off_mask)-1:
                inds.append(p)
    #if len(inds) > 1:
        #print("GREATER THAN 1")
    if len(inds) == 0:
        print(start,end)
    return inds
data['indexes'] = data.apply(indexes,axis=1)

In [14]:
data['list'] = data['tokenized'].apply(lambda x: x['input_ids'])
data['attention'] = data['tokenized'].apply(lambda x: x['attention_mask'])
ids = data['list']
att = data['attention']
indexes = data['indexes']
tids = list()
tatt = list()
print(len(ids),len(att),len(indexes))
for i in range(len(ids)):
    tids.append(torch.tensor(ids[i]))
    tatt.append(torch.tensor(att[i]))

2902 2902 2902


In [15]:
sliced_ids = list()
sliced_ntids = list()
sliced_att = list()
key_inds = list()
key_ids = list()

def slices(index,size,context_size):
    if (size<context_size):
        return 0,size
    lower_c = int(context_size/2-1)
    upper_c = int(context_size/2)
    #print(lower_c,upper_c)
    if index < lower_c:
        return 0,context_size
    elif index >= lower_c:
        if index + upper_c > size:
            return index-(context_size-(size-index)), size
        else:
            return index-lower_c,index+upper_c+1


for i in range(len(tids)):
    slower,supper = slices(indexes[i][0],len(tids[i]),510)
    #key_tid = tids[i][indexes[i][0]]
    pid = ids[i][slower:supper]
    key_inds.append([])
    for j in indexes[i]:
        key_id = ids[i][j]
        if key_id not in pid:
           print(len(ids[i]),key_id,slower,supper,indexes[i])
        key_inds[i].append(pid.index(key_id))
    apid = tids[i][slower:supper]
    apatt = tatt[i][slower:supper]
    if 0 not in pid:
        apid = torch.cat((torch.tensor([0]),apid),dim=0)
        apatt = torch.cat((torch.tensor([1]),apatt),dim=0)
    if 2 not in pid:
        apid = torch.cat((apid,torch.tensor([2])),dim=0)
        apatt = torch.cat((apatt,torch.tensor([1])),dim=0)
    sliced_ids.append(apid)
    sliced_att.append(apatt)

Min = 10000
Max = 0
ind2 = 0
for i in range(len(indexes)):
    if len(sliced_ids[i]) < Min:
        Min = len(sliced_ids[i])
        ind2 = i

    if len(sliced_ids[i]) > Max:
        Max = len(sliced_ids[i])

In [16]:
input_ids = list()
att_mask = list()
for ten,att in zip(sliced_ids,sliced_att):
    if len(ten) < 512:
        padding_length = 512 - len(ten)
        padding_tensor = torch.full((padding_length,), tokenizer.pad_token_id, dtype=ten.dtype)
        padding_tensor2 = torch.full((padding_length,), 0, dtype=att.dtype)
        ten = torch.cat((ten,padding_tensor),dim=0)
        att = torch.cat((att,padding_tensor2),dim=0)
    input_ids.append(ten)
    att_mask.append(att)
inputIds = torch.stack(input_ids)
attMask = torch.stack(att_mask)

inputIds_np = inputIds.numpy()
attMask_np = attMask.numpy()
y1 = data['label1'].values
y2 = data['label2'].values

In [17]:
from sklearn.model_selection import train_test_split
X_train_ids, X_test_ids, X_train_mask, X_test_mask, y1_train, y1_test, y2_train, y2_test = train_test_split(
    inputIds_np, attMask_np, y1, y2, test_size=0.2, random_state=42, shuffle=True
)

In [18]:
import numpy as np
y2_train = np.array(y2_train.tolist(), dtype=np.int8)
y2_test = np.array(y2_test.tolist(), dtype=np.int8)

In [19]:
X_train_ids = torch.tensor(X_train_ids, dtype=torch.long).to(device)
X_test_ids = torch.tensor(X_test_ids, dtype=torch.long).to(device)
X_train_mask = torch.tensor(X_train_mask, dtype=torch.long).to(device)
X_test_mask = torch.tensor(X_test_mask, dtype=torch.long).to(device)
y1_train = torch.tensor(y1_train, dtype=torch.long).to(device)
y1_test = torch.tensor(y1_test, dtype=torch.long).to(device)
y2_train = torch.tensor(y2_train, dtype=torch.long).to(device)
y2_test = torch.tensor(y2_test, dtype=torch.long).to(device)

In [20]:
from torch.utils.data import DataLoader, TensorDataset

train_dataset = TensorDataset(X_train_ids, X_train_mask, y1_train, y2_train)
test_dataset = TensorDataset(X_test_ids, X_test_mask, y1_test, y2_test )

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True) #shuffle=True provides data shuffle for batches in different epochs
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [21]:
from torch.optim import AdamW
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import torch.nn as nn

#first layer classifier
classifier = nn.Linear(model.config.hidden_size, 3).to(device)
#optimizer = AdamW(list(classifier.parameters()) + list(model.parameters()), lr=8e-6)
optimizer = AdamW(model.parameters(), lr=8e-6)
criterion = nn.CrossEntropyLoss()

In [22]:
# not used?
# class FocalLoss(nn.Module):
#     def __init__(self, gamma=2., alpha=0.25, num_classes=3):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         self.alpha = alpha
#         self.num_classes = num_classes
#         self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')

#     def forward(self, inputs, targets):
#         ce_loss = self.cross_entropy_loss(inputs, targets)
#         p_t = torch.exp(-ce_loss)  # Probability of correct class
#         focal_loss = self.alpha * (1 - p_t) ** self.gamma * ce_loss
#         return focal_loss.mean()

#criterion = FocalLoss(gamma=2., alpha=0.25)

In [29]:
criterion2 = nn.BCEWithLogitsLoss()

class SecondLayerClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims, num_classes):
        """
        Multi-layer classifier for second-layer classification.

        Args:
            input_dim (int): Dimension of the input features.
            hidden_dims (list of int): List of hidden layer dimensions.
            num_classes (int): Number of output classes.
            dropout_prob (float): Dropout probability (default: 0.3). // this is removed
        """
        super(SecondLayerClassifier, self).__init__()

        self.num_classes = num_classes

        layers = []
        current_dim = input_dim

        # Add hidden layers
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(current_dim, hidden_dim))
            layers.append(nn.GELU())
            #layers.append(nn.Dropout(dropout_prob)) #probably dont need drouput, since we use dropout for final probability, try to remove this
            current_dim = hidden_dim

        # Final output layer
        layers.append(nn.Linear(current_dim, num_classes))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [24]:
#klasifikatori za finiju granulaciju

hidden_dimension = []

child_classifiers = {
    int(0) : SecondLayerClassifier(input_dim=model.config.hidden_size, hidden_dims=hidden_dimension, num_classes=12).to(device) , #Antagonist
    int(1) : SecondLayerClassifier(input_dim=model.config.hidden_size, hidden_dims=hidden_dimension, num_classes=4).to(device) , #Innocent
    int(2) : SecondLayerClassifier(input_dim=model.config.hidden_size, hidden_dims=hidden_dimension,  num_classes=6).to(device) , #Protagonist
}

second_layer_optimizers = {
    name: AdamW(sec_layer_classifier.parameters(), lr=0.0001)
    for name, sec_layer_classifier in child_classifiers.items()
}

In [25]:
for cls, child_classifier in child_classifiers.items():
  print(cls)
  print(child_classifier)

0
SecondLayerClassifier(
  (model): Sequential(
    (0): Linear(in_features=768, out_features=12, bias=True)
  )
)
1
SecondLayerClassifier(
  (model): Sequential(
    (0): Linear(in_features=768, out_features=4, bias=True)
  )
)
2
SecondLayerClassifier(
  (model): Sequential(
    (0): Linear(in_features=768, out_features=6, bias=True)
  )
)


In [26]:
import numpy as np
# for the confusion matrix in the end
all_preds = np.array([], dtype=np.int8)
all_labels = np.array([], dtype=np.int8)

In [27]:
from sklearn.metrics import precision_recall_fscore_support, multilabel_confusion_matrix, f1_score

second_layer_true_labels = {cls: [] for cls in child_classifiers.keys()}
second_layer_pred_labels = {cls: [] for cls in child_classifiers.keys()}
confusion_matrices = {cls: [] for cls in child_classifiers.keys()}

num_epochs = 6
debug = 0
for epoch in range(num_epochs):

    model.train()
    classifier.train()
    child_classifiers[0].train()
    child_classifiers[1].train()
    child_classifiers[2].train()

    total_loss = 0
    correct_predictions = 0
    total_predictions = 0

    second_layer_stats = {
        cls: {'loss': 0.0, 'correct': 0, 'total': 0}
        for cls in child_classifiers.keys()
    }

    # reset for cm metrics
    for cls in child_classifiers.keys():
        second_layer_true_labels[cls] = []
        second_layer_pred_labels[cls] = []
        confusion_matrices[cls] = []

    train_progress_bar = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{num_epochs}")

    for batch in train_progress_bar:
        optimizer.zero_grad()
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        labels_1 = batch[2].to(device)
        labels_2 = batch[3].to(device)  # second-layer labels

        batch_size = input_ids.size(0)

        #taking the output from BERT model
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels_1, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]
        # finding the embedding that "represents" start_span special token
        span_start_token_id = tokenizer.convert_tokens_to_ids('[SPAN_START]')
        start_mask = (input_ids == span_start_token_id)
        entity_representations = []
        start_indices = start_mask.nonzero(as_tuple=True)[1]
        valid_spans = (start_indices != -1)
        valid_start_indices = start_indices[valid_spans]

        for i in range(batch_size):
            entity_tokens = hidden_states[i, valid_start_indices[i]] #for this version only start span tokens
            entity_representations.append(entity_tokens)

        entity_representations = torch.stack(entity_representations, dim=0)

        #first layer classification
        logits = classifier(entity_representations)
        loss = criterion(logits, labels_1)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        preds = torch.argmax(logits, dim=-1)
        correct_predictions += (preds == labels_1).sum().item()
        total_predictions += labels_1.size(0)

        train_progress_bar.set_postfix({'loss': loss.item()})

        #second layer classification
        for cls, second_layer_classifier in child_classifiers.items():
            second_layer_indices = (labels_1 == cls).nonzero(as_tuple=True)[0]

            if second_layer_indices.size(0) > 0:
                second_layer_inputs = entity_representations[second_layer_indices].detach()

                num_classes = second_layer_classifier.num_classes
                second_layer_labels = labels_2[second_layer_indices, :num_classes].float()

                child_logits = second_layer_classifier(second_layer_inputs)
                child_loss = criterion2(child_logits, second_layer_labels)

                second_layer_optimizers[cls].zero_grad()
                child_loss.backward()
                second_layer_optimizers[cls].step()

                #for accuracy
                second_layer_stats[cls]['loss'] += child_loss.item()
                child_preds = (torch.sigmoid(child_logits) > 0.35).int()
                correct = ((child_preds == second_layer_labels.int()).all(dim=1)).sum().item() #strict accuracy, all the labels have to be well predicted
                second_layer_stats[cls]['correct'] += correct
                second_layer_stats[cls]['total'] += second_layer_labels.size(0)
                #for confusion matrix
                child_preds = child_preds.cpu().numpy()
                second_layer_labels = second_layer_labels.cpu().numpy()
                second_layer_true_labels[cls].append(second_layer_labels)
                second_layer_pred_labels[cls].append(child_preds)
                confusion_matrices[cls].append(multilabel_confusion_matrix(second_layer_labels, child_preds))

    avg_train_loss = total_loss / len(train_dataloader)
    train_accuracy = correct_predictions / total_predictions

    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Training loss: {avg_train_loss:.4f}, Training accuracy: {train_accuracy:.4f}")

    for cls, stats in second_layer_stats.items():
        if stats['total'] > 0:
            avg_loss = stats['loss'] / stats['total']
            accuracy = stats['correct'] / stats['total']
        else:
            avg_loss = 0.0
            accuracy = 0.0
        print(f"Second-Layer Classifier: {cls}, Avg Loss: {avg_loss:.4f}, (strict) Accuracy: {accuracy:.4f}, Total: {stats['total']}")

    for cls in child_classifiers.keys():
        true_labels = np.vstack(second_layer_true_labels[cls])
        pred_labels = np.vstack(second_layer_pred_labels[cls])

        # Compute precision, recall, F1 for the current classifier
        precision, recall, f1, _ = precision_recall_fscore_support(true_labels, pred_labels, average='micro')
        macro_f1 = f1_score(true_labels, pred_labels, average='macro')
        print(f"Second-Layer Classifier {cls} - Precision: {precision:.4f}, Recall: {recall:.4f}, Micro-F1-Score: {f1:.4f}, Macro-F1-Score: {macro_f1:.4f}")

        # Compute confusion matrix for the current classifier
        confusion_matrix = np.sum(confusion_matrices[cls], axis=0)
        print(f"Confusion Matrix for Classifier {cls}:\n{confusion_matrix}")

    model.eval()
    classifier.eval()
    child_classifiers[0].eval()
    child_classifiers[1].eval()
    child_classifiers[2].eval()

    test_loss = 0
    correct_test_predictions = 0
    total_test_predictions = 0

    test_stats_per_classifier = {
        cls: {'loss': 0.0, 'correct': 0, 'total': 0}
        for cls in child_classifiers.keys()
    }
    overall_stats = {'loss': 0.0, 'correct': 0, 'total': 0}

    test_progress_bar = tqdm(test_dataloader, desc=f"Test Epoch {epoch + 1}/{num_epochs}")

    with torch.no_grad():
        for batch in test_progress_bar:

            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)
            labels_1 = batch[2].to(device)
            labels_2 = batch[3].to(device)

            batch_size = input_ids.size(0)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels_1, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]

            span_start_token_id = tokenizer.convert_tokens_to_ids('[SPAN_START]')
            start_mask = (input_ids == span_start_token_id)
            entity_representations = []
            start_indices = start_mask.nonzero(as_tuple=True)[1]
            valid_spans = (start_indices != -1)
            valid_start_indices = start_indices[valid_spans]

            # extract entity tokens for every sample in batch
            for i in range(batch_size):
                entity_tokens = hidden_states[i, valid_start_indices[i]]
                entity_representations.append(entity_tokens)

            entity_representations = torch.stack(entity_representations, dim=0)

            logits = classifier(entity_representations)
            loss = criterion(logits, labels_1)
            test_loss += loss.item()

            preds = torch.argmax(logits, dim=-1)

            #if epoch is the last epoch we want to redirect data to second layer classifier according to predicted label
            if epoch == num_epochs-1:

                #second layer classification
                for cls, second_layer_classifier in child_classifiers.items():
                    incorrect_label_1_indices = ((preds != labels_1) & (preds == cls)).nonzero(as_tuple=True)[0]
                    if incorrect_label_1_indices.size(0) > 0:
                        overall_stats['total'] += incorrect_label_1_indices.size(0)
                        overall_stats['correct'] += 0

                    second_layer_indices = ((preds == labels_1) & (labels_1 == cls)).nonzero(as_tuple=True)[0]

                    # only correct first layer predictions for second layer evaluation
                    if second_layer_indices.size(0) > 0:
                        second_layer_inputs = entity_representations[second_layer_indices]

                        num_classes = second_layer_classifier.num_classes
                        second_layer_labels = labels_2[second_layer_indices, :num_classes].float()

                        child_logits = second_layer_classifier(second_layer_inputs)
                        child_loss = criterion2(child_logits, second_layer_labels)

                        test_stats_per_classifier[cls]['loss'] += child_loss.item()

                        child_preds = (torch.sigmoid(child_logits) > 0.35).int()

                        # Count strict accuracy (all labels must match)
                        correct = ((child_preds == second_layer_labels.int()).all(dim=1)).sum().item()
                        test_stats_per_classifier[cls]['correct'] += correct
                        test_stats_per_classifier[cls]['total'] += second_layer_labels.size(0)

                        overall_stats['loss'] += child_loss.item() * second_layer_labels.size(0)
                        overall_stats['correct'] += correct
                        overall_stats['total'] += second_layer_labels.size(0)

                #for confusion matrix in determing first class
                all_preds = np.concatenate((all_preds, preds.cpu().numpy()))
                all_labels = np.concatenate((all_labels, labels_1.cpu().numpy()))

            correct_test_predictions += (preds == labels_1).sum().item()
            total_test_predictions += labels_1.size(0)

            test_progress_bar.set_postfix({'loss': loss.item()})

    avg_test_loss = test_loss / len(test_dataloader)
    test_accuracy = correct_test_predictions / total_test_predictions

    print(f"Test loss: {avg_test_loss:.4f}, Test accuracy: {test_accuracy:.4f}")

    if epoch == num_epochs-1:

        for cls, stats in test_stats_per_classifier.items():
            if stats['total'] > 0:
                avg_loss = stats['loss'] / stats['total']
                accuracy = stats['correct'] / stats['total']
            else:
                avg_loss = 0.0
                accuracy = 0.0
            print(f"Second-Layer Classifier: {cls}, Avg Loss: {avg_loss:.4f}, (strict) Accuracy: {accuracy:.4f}, Total: {stats['total']}")

        overall_accuracy = overall_stats['correct'] / overall_stats['total'] if overall_stats['total'] > 0 else 0.0
        print(f"OVERALL STRICT ACCURACY AFTER THE SECOND LAYER CLASSIFICATION: {overall_accuracy:.4f}")
        #strict accuracy means all the labels are correctly predicted for the input

Training Epoch 1/6: 100%|██████████| 146/146 [03:50<00:00,  1.58s/it, loss=1.44]


Epoch 1/6
Training loss: 0.8660, Training accuracy: 0.6114
Second-Layer Classifier: 0, Avg Loss: 0.0582, (strict) Accuracy: 0.0646, Total: 1145
Second-Layer Classifier: 1, Avg Loss: 0.1387, (strict) Accuracy: 0.1831, Total: 508
Second-Layer Classifier: 2, Avg Loss: 0.1045, (strict) Accuracy: 0.1063, Total: 668
Second-Layer Classifier 0 - Precision: 0.1109, Recall: 0.4439, Micro-F1-Score: 0.1775, Macro-F1-Score: 0.1355
Confusion Matrix for Classifier 0:
[[[ 828  113]
  [ 186   18]]

 [[ 626  417]
  [  60   42]]

 [[ 700  342]
  [  72   31]]

 [[ 257  546]
  [  79  263]]

 [[ 607  506]
  [  19   13]]

 [[ 775  356]
  [  10    4]]

 [[1000  100]
  [  40    5]]

 [[ 732  341]
  [  47   25]]

 [[ 614  389]
  [  78   64]]

 [[ 477  556]
  [  51   61]]

 [[ 577  473]
  [  55   40]]

 [[ 695  429]
  [  17    4]]]
Second-Layer Classifier 1 - Precision: 0.3782, Recall: 0.9094, Micro-F1-Score: 0.5342, Macro-F1-Score: 0.3046
Confusion Matrix for Classifier 1:
[[[334 152]
  [ 16   6]]

 [[103 355]


Test Epoch 1/6: 100%|██████████| 37/37 [00:17<00:00,  2.16it/s, loss=0.427]


Test loss: 0.5969, Test accuracy: 0.7814


Training Epoch 2/6: 100%|██████████| 146/146 [03:50<00:00,  1.58s/it, loss=0.399]


Epoch 2/6
Training loss: 0.5534, Training accuracy: 0.7897
Second-Layer Classifier: 0, Avg Loss: 0.0359, (strict) Accuracy: 0.1057, Total: 1145
Second-Layer Classifier: 1, Avg Loss: 0.0746, (strict) Accuracy: 0.8071, Total: 508
Second-Layer Classifier: 2, Avg Loss: 0.0837, (strict) Accuracy: 0.2530, Total: 668
Second-Layer Classifier 0 - Precision: 0.4116, Recall: 0.1160, Micro-F1-Score: 0.1810, Macro-F1-Score: 0.0523
Confusion Matrix for Classifier 0:
[[[ 933    8]
  [ 201    3]]

 [[1040    3]
  [ 102    0]]

 [[1040    2]
  [ 103    0]]

 [[ 641  162]
  [ 205  137]]

 [[1111    2]
  [  31    1]]

 [[1131    0]
  [  14    0]]

 [[1099    1]
  [  45    0]]

 [[1073    0]
  [  72    0]]

 [[1000    3]
  [ 141    1]]

 [[1004   29]
  [ 106    6]]

 [[1047    3]
  [  94    1]]

 [[1124    0]
  [  21    0]]]
Second-Layer Classifier 1 - Precision: 0.8265, Recall: 0.8170, Micro-F1-Score: 0.8217, Macro-F1-Score: 0.2275
Confusion Matrix for Classifier 1:
[[[486   0]
  [ 22   0]]

 [[453   5]


Test Epoch 2/6: 100%|██████████| 37/37 [00:17<00:00,  2.17it/s, loss=0.327]


Test loss: 0.4761, Test accuracy: 0.8382


Training Epoch 3/6: 100%|██████████| 146/146 [03:51<00:00,  1.58s/it, loss=0.0941]


Epoch 3/6
Training loss: 0.3883, Training accuracy: 0.8587
Second-Layer Classifier: 0, Avg Loss: 0.0342, (strict) Accuracy: 0.1546, Total: 1145
Second-Layer Classifier: 1, Avg Loss: 0.0737, (strict) Accuracy: 0.8130, Total: 508
Second-Layer Classifier: 2, Avg Loss: 0.0822, (strict) Accuracy: 0.2246, Total: 668
Second-Layer Classifier 0 - Precision: 0.5023, Recall: 0.1737, Micro-F1-Score: 0.2581, Macro-F1-Score: 0.0627
Confusion Matrix for Classifier 0:
[[[ 917   24]
  [ 192   12]]

 [[1042    1]
  [ 102    0]]

 [[1040    2]
  [ 103    0]]

 [[ 628  175]
  [ 136  206]]

 [[1112    1]
  [  32    0]]

 [[1131    0]
  [  14    0]]

 [[1100    0]
  [  45    0]]

 [[1073    0]
  [  72    0]]

 [[1000    3]
  [ 142    0]]

 [[1019   14]
  [ 108    4]]

 [[1049    1]
  [  94    1]]

 [[1124    0]
  [  21    0]]]
Second-Layer Classifier 1 - Precision: 0.8330, Recall: 0.8170, Micro-F1-Score: 0.8249, Macro-F1-Score: 0.2275
Confusion Matrix for Classifier 1:
[[[486   0]
  [ 22   0]]

 [[457   1]


Test Epoch 3/6: 100%|██████████| 37/37 [00:17<00:00,  2.16it/s, loss=0.315]


Test loss: 0.4768, Test accuracy: 0.8158


Training Epoch 4/6: 100%|██████████| 146/146 [03:50<00:00,  1.58s/it, loss=0.0163]


Epoch 4/6
Training loss: 0.2763, Training accuracy: 0.9121
Second-Layer Classifier: 0, Avg Loss: 0.0334, (strict) Accuracy: 0.1659, Total: 1145
Second-Layer Classifier: 1, Avg Loss: 0.0750, (strict) Accuracy: 0.8110, Total: 508
Second-Layer Classifier: 2, Avg Loss: 0.0811, (strict) Accuracy: 0.2725, Total: 668
Second-Layer Classifier 0 - Precision: 0.5178, Recall: 0.1815, Micro-F1-Score: 0.2687, Macro-F1-Score: 0.0625
Confusion Matrix for Classifier 0:
[[[ 928   13]
  [ 201    3]]

 [[1040    3]
  [ 101    1]]

 [[1034    8]
  [ 103    0]]

 [[ 630  173]
  [ 119  223]]

 [[1110    3]
  [  32    0]]

 [[1131    0]
  [  14    0]]

 [[1100    0]
  [  45    0]]

 [[1073    0]
  [  72    0]]

 [[ 999    4]
  [ 140    2]]

 [[1024    9]
  [ 110    2]]

 [[1046    4]
  [  93    2]]

 [[1124    0]
  [  21    0]]]
Second-Layer Classifier 1 - Precision: 0.8252, Recall: 0.8189, Micro-F1-Score: 0.8221, Macro-F1-Score: 0.2369
Confusion Matrix for Classifier 1:
[[[484   2]
  [ 22   0]]

 [[456   2]


Test Epoch 4/6: 100%|██████████| 37/37 [00:17<00:00,  2.16it/s, loss=0.373]


Test loss: 0.4700, Test accuracy: 0.8296


Training Epoch 5/6: 100%|██████████| 146/146 [03:50<00:00,  1.58s/it, loss=0.0053]


Epoch 5/6
Training loss: 0.1797, Training accuracy: 0.9405
Second-Layer Classifier: 0, Avg Loss: 0.0330, (strict) Accuracy: 0.1686, Total: 1145
Second-Layer Classifier: 1, Avg Loss: 0.0689, (strict) Accuracy: 0.8110, Total: 508
Second-Layer Classifier: 2, Avg Loss: 0.0808, (strict) Accuracy: 0.2575, Total: 668
Second-Layer Classifier 0 - Precision: 0.5436, Recall: 0.1846, Micro-F1-Score: 0.2756, Macro-F1-Score: 0.0608
Confusion Matrix for Classifier 0:
[[[ 927   14]
  [ 200    4]]

 [[1042    1]
  [ 102    0]]

 [[1041    1]
  [ 103    0]]

 [[ 630  173]
  [ 114  228]]

 [[1113    0]
  [  32    0]]

 [[1131    0]
  [  14    0]]

 [[1100    0]
  [  45    0]]

 [[1073    0]
  [  72    0]]

 [[1001    2]
  [ 141    1]]

 [[1028    5]
  [ 108    4]]

 [[1047    3]
  [  95    0]]

 [[1124    0]
  [  21    0]]]
Second-Layer Classifier 1 - Precision: 0.8288, Recall: 0.8208, Micro-F1-Score: 0.8248, Macro-F1-Score: 0.2453
Confusion Matrix for Classifier 1:
[[[486   0]
  [ 22   0]]

 [[454   4]


Test Epoch 5/6: 100%|██████████| 37/37 [00:17<00:00,  2.17it/s, loss=0.4]


Test loss: 0.4847, Test accuracy: 0.8485


Training Epoch 6/6: 100%|██████████| 146/146 [03:50<00:00,  1.58s/it, loss=0.00478]


Epoch 6/6
Training loss: 0.1270, Training accuracy: 0.9595
Second-Layer Classifier: 0, Avg Loss: 0.0322, (strict) Accuracy: 0.1712, Total: 1145
Second-Layer Classifier: 1, Avg Loss: 0.0698, (strict) Accuracy: 0.8110, Total: 508
Second-Layer Classifier: 2, Avg Loss: 0.0782, (strict) Accuracy: 0.2769, Total: 668
Second-Layer Classifier 0 - Precision: 0.5614, Recall: 0.1815, Micro-F1-Score: 0.2743, Macro-F1-Score: 0.0783
Confusion Matrix for Classifier 0:
[[[ 933    8]
  [ 200    4]]

 [[1042    1]
  [ 102    0]]

 [[1040    2]
  [ 103    0]]

 [[ 648  155]
  [ 135  207]]

 [[1112    1]
  [  32    0]]

 [[1131    0]
  [  14    0]]

 [[1100    0]
  [  45    0]]

 [[1073    0]
  [  72    0]]

 [[ 997    6]
  [ 131   11]]

 [[1027    6]
  [ 102   10]]

 [[1047    3]
  [  94    1]]

 [[1124    0]
  [  21    0]]]
Second-Layer Classifier 1 - Precision: 0.8243, Recall: 0.8227, Micro-F1-Score: 0.8235, Macro-F1-Score: 0.2724
Confusion Matrix for Classifier 1:
[[[486   0]
  [ 22   0]]

 [[451   7]


Test Epoch 6/6: 100%|██████████| 37/37 [00:17<00:00,  2.17it/s, loss=0.518]

Test loss: 0.5277, Test accuracy: 0.8485
Second-Layer Classifier: 0, Avg Loss: 0.0395, (strict) Accuracy: 0.1667, Total: 240
Second-Layer Classifier: 1, Avg Loss: 0.0792, (strict) Accuracy: 0.8416, Total: 101
Second-Layer Classifier: 2, Avg Loss: 0.0853, (strict) Accuracy: 0.2566, Total: 152
OVERALL STRICT ACCURACY AFTER THE SECOND LAYER CLASSIFICATION: 0.2823





In [28]:

# from sklearn.metrics import confusion_matrix
# import seaborn as sns
# import matplotlib.pyplot as plt

# cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])

# plt.figure(figsize=(8, 6))
# sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Antagonist', 'Innocent', 'Protagonist'], yticklabels=['Antagonist', 'Innocent', 'Protagonist'])
# plt.xlabel('Predicted Labels')
# plt.ylabel('True Labels')
# plt.title('Confusion Matrix')
# plt.show()