In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer
import auxiliary as aux
from tqdm import tqdm
import traceback

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer
from auxiliary import ensembler, json_to_Dataset_ensemble
from tqdm import tqdm

class KingBert(nn.Module):
    def __init__(self, distilbert_tuned, albert_tuned):
        super().__init__()
        self.distilbert = distilbert_tuned
        self.albert = albert_tuned

        for distilbert_param in self.distilbert.parameters():
            distilbert_param.requires_grad = False

        for albert_param in self.albert.parameters():
            albert_param.requires_grad = False 
        
        self.alpha = nn.Parameter(0.5 * torch.ones(47), requires_grad=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, distilbert_input_ids, albert_input_ids, distil_attention_mask, alb_attention_mask, distilbert_word_ids, albert_word_ids):
        distilbert_output = self.distilbert(input_ids=distilbert_input_ids, attention_mask=distil_attention_mask)
        albert_output = self.albert(input_ids=albert_input_ids, attention_mask=alb_attention_mask)
        distilbert_fixed, albert_fixed = aux.ensembler(distilbert_output['logits'].squeeze(), albert_output['logits'].squeeze(), distilbert_word_ids.squeeze(), albert_word_ids.squeeze())

        distilbert_fixed = self.softmax(distilbert_fixed)
        albert_fixed = self.softmax(albert_fixed)

        final_output = distilbert_fixed * self.alpha + albert_fixed * (torch.ones(47) - self.alpha)

        return self.softmax(final_output)

train_dataset = aux.json_to_Dataset_ensemble('datasets/ensemble_train.json')

distilbert_tuned = AutoModelForTokenClassification.from_pretrained('models/distilbert1')
albert_tuned = AutoModelForTokenClassification.from_pretrained('models/albert1')

kingbert_model = KingBert(distilbert_tuned, albert_tuned)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(kingbert_model.parameters(), lr=2e-5)


num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0
    for i in tqdm(range(len(train_dataset)), desc="Steps in epoch"):
        try:

            item = train_dataset[i]
            
            distilbert_input_ids = torch.tensor(item['distilbert_inputids']).unsqueeze(0)
            albert_input_ids = torch.tensor(item['albert_inputids']).unsqueeze(0)
            distil_attention_mask = torch.tensor(item['distilbert_attention_masks']).unsqueeze(0)
            alb_attention_mask = torch.tensor(item['albert_attention_masks']).unsqueeze(0)
            distilbert_word_ids = torch.tensor([-100] + item['distilbert_wordids'][1:-1] + [-100]).unsqueeze(0)
            albert_word_ids = torch.tensor([-100] + item['albert_wordids'][1:-1] + [-100]).unsqueeze(0)
            targets = torch.tensor(item['spacy_labels']).unsqueeze(0)
            
            optimizer.zero_grad()
            
            output = kingbert_model(distilbert_input_ids, albert_input_ids, distil_attention_mask, alb_attention_mask, distilbert_word_ids, albert_word_ids)
            
            ohe_targets = torch.zeros(output.shape[0], output.shape[1])
            for i,j in enumerate(targets):
                ohe_targets[i][j] = 1

            loss = criterion(output, ohe_targets)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        except:
            continue
        
    avg_loss = total_loss / len(train_dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

torch.save(kingbert_model.state_dict(), 'model_state.pth')

print('Training complete.')


  stacked_tensors1 = torch.stack([torch.tensor(i) for i in output1])
  stacked_tensors2 = torch.stack([torch.tensor(i) for i in output2])
Steps in epoch:   3%|▎         | 539/18651 [00:49<27:32, 10.96it/s]

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForTokenClassification
from tqdm import tqdm
import auxiliary as aux

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dataset loading
train_dataset = aux.json_to_Dataset_ensemble('datasets/ensemble_train.json')

# Load fine-tuned base models
distilbert_tuned = AutoModelForTokenClassification.from_pretrained('models/distilbert1').to(device)
albert_tuned = AutoModelForTokenClassification.from_pretrained('models/albert1').to(device)

# Ensembler model
class KingBert(nn.Module):
    def __init__(self, distilbert_tuned, albert_tuned):
        super().__init__()
        self.distilbert = distilbert_tuned
        self.albert = albert_tuned

        for p in self.distilbert.parameters():
            p.requires_grad = False
        for p in self.albert.parameters():
            p.requires_grad = False

        self.alpha = nn.Parameter(0.5 * torch.ones(47, device=device), requires_grad=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, distilbert_input_ids, albert_input_ids,
                distil_attention_mask, alb_attention_mask,
                distilbert_word_ids, albert_word_ids):

        distilbert_output = self.distilbert(input_ids=distilbert_input_ids, attention_mask=distil_attention_mask)
        albert_output = self.albert(input_ids=albert_input_ids, attention_mask=alb_attention_mask)

        distilbert_logits = distilbert_output['logits'].squeeze()
        albert_logits = albert_output['logits'].squeeze()
        distilbert_word_ids = distilbert_word_ids.squeeze()
        albert_word_ids = albert_word_ids.squeeze()

        distilbert_fixed, albert_fixed = aux.ensembler(
            distilbert_logits, albert_logits,
            distilbert_word_ids, albert_word_ids
        )

        distilbert_fixed = distilbert_fixed.to(self.alpha.device)
        albert_fixed = albert_fixed.to(self.alpha.device)

        distilbert_fixed = self.softmax(distilbert_fixed)
        albert_fixed = self.softmax(albert_fixed)

        alpha = self.alpha
        one_minus_alpha = (1.0 - alpha).to(alpha.device)  # ensure same device

        combined = distilbert_fixed * alpha + albert_fixed * one_minus_alpha
        return self.softmax(combined)

# Instantiate model
kingbert_model = KingBert(distilbert_tuned, albert_tuned).to(device)

# Optimizer and loss
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([kingbert_model.alpha], lr=2e-5)

# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    total_loss = 0
    kingbert_model.train()

    for i in tqdm(range(len(train_dataset)), desc=f"Epoch {epoch+1:02d}"):
        try:
            item = train_dataset[i]

            distilbert_input_ids = torch.tensor(item['distilbert_inputids']).unsqueeze(0).to(device)
            albert_input_ids = torch.tensor(item['albert_inputids']).unsqueeze(0).to(device)
            distil_attention_mask = torch.tensor(item['distilbert_attention_masks']).unsqueeze(0).to(device)
            alb_attention_mask = torch.tensor(item['albert_attention_masks']).unsqueeze(0).to(device)
            distilbert_word_ids = torch.tensor([-100] + item['distilbert_wordids'][1:-1] + [-100]).unsqueeze(0).to(device)
            albert_word_ids = torch.tensor([-100] + item['albert_wordids'][1:-1] + [-100]).unsqueeze(0).to(device)
            labels = torch.tensor(item['spacy_labels']).to(device)

            optimizer.zero_grad()
            output = kingbert_model(distilbert_input_ids, albert_input_ids,
                                    distil_attention_mask, alb_attention_mask,
                                    distilbert_word_ids, albert_word_ids)

            if output.size(0) != labels.size(0):
                continue  # skip if shape mismatch due to token alignment

            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        except Exception as e:
            print(f"Skipped example {i} due to error: {e}")
            continue

    avg_loss = total_loss / len(train_dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

# Save model
torch.save(kingbert_model.state_dict(), 'model_state.pth')
print('Training complete.')


Using device: cuda


Epoch 01:  11%|█         | 2012/18651 [01:29<12:43, 21.80it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 01:  34%|███▎      | 6266/18651 [04:47<11:07, 18.57it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 01:  57%|█████▋    | 10645/18651 [08:05<07:17, 18.28it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 01:  82%|████████▏ | 15368/18651 [11:45<02:24, 22.69it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 01:  94%|█████████▍| 17593/18651 [13:26<00:49, 21.27it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 01: 100%|██████████| 18651/18651 [14:15<00:00, 21.81it/s]


Epoch [1/20], Loss: 3.0585


Epoch 02:  11%|█         | 2009/18651 [01:29<14:46, 18.78it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 02:  34%|███▎      | 6267/18651 [04:47<10:35, 19.48it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 02:  57%|█████▋    | 10645/18651 [08:11<07:50, 17.03it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 02:  82%|████████▏ | 15368/18651 [11:48<02:14, 24.34it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 02:  94%|█████████▍| 17592/18651 [13:30<00:56, 18.90it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 02: 100%|██████████| 18651/18651 [14:22<00:00, 21.63it/s]


Epoch [2/20], Loss: 3.0585


Epoch 03:  11%|█         | 2011/18651 [01:35<13:29, 20.55it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 03:  34%|███▎      | 6267/18651 [04:56<11:33, 17.86it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 03:  57%|█████▋    | 10646/18651 [08:27<07:15, 18.40it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 03:  82%|████████▏ | 15366/18651 [12:15<02:31, 21.73it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 03:  94%|█████████▍| 17593/18651 [14:02<00:53, 19.77it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 03: 100%|██████████| 18651/18651 [14:51<00:00, 20.91it/s]


Epoch [3/20], Loss: 3.0585


Epoch 04:  11%|█         | 2011/18651 [01:34<13:33, 20.46it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 04:  34%|███▎      | 6266/18651 [05:00<12:35, 16.39it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 04:  57%|█████▋    | 10645/18651 [08:28<07:59, 16.71it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 04:  82%|████████▏ | 15368/18651 [12:16<02:23, 22.89it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 04:  94%|█████████▍| 17593/18651 [14:04<00:55, 19.02it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 04: 100%|██████████| 18651/18651 [14:56<00:00, 20.80it/s]


Epoch [4/20], Loss: 3.0585


Epoch 05:  11%|█         | 2011/18651 [01:37<13:35, 20.40it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 05:  34%|███▎      | 6266/18651 [05:01<11:32, 17.89it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 05:  57%|█████▋    | 10645/18651 [08:33<07:16, 18.32it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 05:  82%|████████▏ | 15368/18651 [12:24<02:23, 22.86it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 05:  94%|█████████▍| 17592/18651 [14:10<00:51, 20.38it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 05: 100%|██████████| 18651/18651 [14:59<00:00, 20.73it/s]


Epoch [5/20], Loss: 3.0584


Epoch 06:  11%|█         | 2012/18651 [01:35<13:30, 20.52it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 06:  34%|███▎      | 6266/18651 [04:59<12:08, 17.01it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 06:  57%|█████▋    | 10646/18651 [08:28<08:23, 15.89it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 06:  82%|████████▏ | 15366/18651 [12:16<02:34, 21.27it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 06:  94%|█████████▍| 17593/18651 [14:02<00:57, 18.51it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 06: 100%|██████████| 18651/18651 [14:55<00:00, 20.84it/s]


Epoch [6/20], Loss: 3.0584


Epoch 07:  11%|█         | 2011/18651 [01:35<13:31, 20.52it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 07:  34%|███▎      | 6266/18651 [04:58<11:22, 18.14it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 07:  57%|█████▋    | 10645/18651 [08:31<07:19, 18.22it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 07:  82%|████████▏ | 15366/18651 [12:19<02:34, 21.28it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 07:  94%|█████████▍| 17592/18651 [14:05<00:52, 20.02it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 07: 100%|██████████| 18651/18651 [14:54<00:00, 20.84it/s]


Epoch [7/20], Loss: 3.0584


Epoch 08:  11%|█         | 2010/18651 [01:37<13:30, 20.54it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 08:  34%|███▎      | 6266/18651 [05:02<12:11, 16.93it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 08:  57%|█████▋    | 10645/18651 [08:31<09:08, 14.60it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 08:  82%|████████▏ | 15366/18651 [12:22<02:30, 21.84it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 08:  94%|█████████▍| 17591/18651 [14:09<00:54, 19.57it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 08: 100%|██████████| 18651/18651 [15:01<00:00, 20.68it/s]


Epoch [8/20], Loss: 3.0584


Epoch 09:  11%|█         | 2011/18651 [01:36<13:57, 19.88it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 09:  34%|███▎      | 6267/18651 [04:58<11:21, 18.18it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 09:  57%|█████▋    | 10645/18651 [08:30<07:22, 18.10it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 09:  82%|████████▏ | 15368/18651 [12:19<02:33, 21.35it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 09:  94%|█████████▍| 17592/18651 [14:04<00:50, 20.98it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 09: 100%|██████████| 18651/18651 [14:54<00:00, 20.85it/s]


Epoch [9/20], Loss: 3.0584


Epoch 10:  11%|█         | 2009/18651 [01:36<14:10, 19.58it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 10:  34%|███▎      | 6266/18651 [05:00<11:45, 17.57it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 10:  57%|█████▋    | 10645/18651 [08:30<07:35, 17.58it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 10:  82%|████████▏ | 15368/18651 [12:16<02:26, 22.46it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 10:  94%|█████████▍| 17593/18651 [14:02<00:53, 19.79it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 10: 100%|██████████| 18651/18651 [14:54<00:00, 20.84it/s]


Epoch [10/20], Loss: 3.0584


Epoch 11:  11%|█         | 2011/18651 [01:37<13:30, 20.53it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 11:  34%|███▎      | 6266/18651 [05:02<11:20, 18.21it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 11:  57%|█████▋    | 10645/18651 [08:35<07:00, 19.04it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 11:  82%|████████▏ | 15368/18651 [12:23<02:33, 21.46it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 11:  94%|█████████▍| 17592/18651 [14:10<00:55, 19.12it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 11: 100%|██████████| 18651/18651 [15:00<00:00, 20.72it/s]


Epoch [11/20], Loss: 3.0584


Epoch 12:  11%|█         | 2011/18651 [01:38<14:03, 19.72it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 12:  34%|███▎      | 6266/18651 [05:04<12:24, 16.64it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 12:  57%|█████▋    | 10645/18651 [08:32<08:19, 16.04it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 12:  82%|████████▏ | 15366/18651 [12:20<02:32, 21.49it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 12:  94%|█████████▍| 17593/18651 [14:07<00:53, 19.68it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 12: 100%|██████████| 18651/18651 [14:59<00:00, 20.74it/s]


Epoch [12/20], Loss: 3.0584


Epoch 13:  11%|█         | 2012/18651 [01:36<14:09, 19.59it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 13:  34%|███▎      | 6267/18651 [05:02<11:15, 18.33it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 13:  57%|█████▋    | 10646/18651 [08:37<07:37, 17.48it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 13:  82%|████████▏ | 15366/18651 [12:25<02:38, 20.72it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 13:  94%|█████████▍| 17592/18651 [14:12<00:52, 20.05it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 13: 100%|██████████| 18651/18651 [15:02<00:00, 20.67it/s]


Epoch [13/20], Loss: 3.0583


Epoch 14:  11%|█         | 2012/18651 [01:36<12:56, 21.43it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 14:  34%|███▎      | 6267/18651 [05:02<11:39, 17.70it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 14:  57%|█████▋    | 10645/18651 [08:29<07:51, 16.97it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 14:  82%|████████▏ | 15369/18651 [12:17<02:29, 22.00it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 14:  94%|█████████▍| 17593/18651 [14:03<00:53, 19.72it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 14: 100%|██████████| 18651/18651 [14:53<00:00, 20.86it/s]


Epoch [14/20], Loss: 3.0583


Epoch 15:  11%|█         | 2009/18651 [01:36<14:26, 19.21it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 15:  34%|███▎      | 6266/18651 [04:59<11:49, 17.45it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 15:  57%|█████▋    | 10645/18651 [08:29<07:48, 17.09it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 15:  82%|████████▏ | 15366/18651 [12:14<02:45, 19.84it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 15:  94%|█████████▍| 17593/18651 [14:01<00:56, 18.74it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 15: 100%|██████████| 18651/18651 [14:52<00:00, 20.90it/s]


Epoch [15/20], Loss: 3.0583


Epoch 16:  11%|█         | 2009/18651 [01:35<15:34, 17.80it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 16:  34%|███▎      | 6266/18651 [05:01<12:07, 17.03it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 16:  57%|█████▋    | 10645/18651 [08:29<07:14, 18.43it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 16:  82%|████████▏ | 15368/18651 [12:18<02:22, 23.02it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 16:  94%|█████████▍| 17592/18651 [14:03<00:50, 21.13it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 16: 100%|██████████| 18651/18651 [14:54<00:00, 20.86it/s]


Epoch [16/20], Loss: 3.0583


Epoch 17:  11%|█         | 2009/18651 [01:36<14:45, 18.80it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 17:  34%|███▎      | 6267/18651 [05:01<11:22, 18.13it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 17:  57%|█████▋    | 10646/18651 [08:32<07:36, 17.55it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 17:  82%|████████▏ | 15366/18651 [12:18<02:45, 19.80it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 17:  94%|█████████▍| 17593/18651 [14:06<00:52, 20.15it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 17: 100%|██████████| 18651/18651 [14:55<00:00, 20.82it/s]


Epoch [17/20], Loss: 3.0583


Epoch 18:  11%|█         | 2010/18651 [01:35<14:06, 19.66it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 18:  34%|███▎      | 6267/18651 [05:00<12:26, 16.59it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 18:  57%|█████▋    | 10645/18651 [08:30<08:40, 15.38it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 18:  82%|████████▏ | 15368/18651 [12:19<02:17, 23.88it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 18:  94%|█████████▍| 17593/18651 [14:04<00:52, 19.99it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 18: 100%|██████████| 18651/18651 [14:53<00:00, 20.88it/s]


Epoch [18/20], Loss: 3.0583


Epoch 19:  11%|█         | 2011/18651 [01:36<13:46, 20.13it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 19:  34%|███▎      | 6267/18651 [05:02<10:54, 18.92it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 19:  57%|█████▋    | 10645/18651 [08:33<07:42, 17.30it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 19:  82%|████████▏ | 15366/18651 [12:19<02:40, 20.46it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 19:  94%|█████████▍| 17592/18651 [14:07<00:53, 19.83it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 19: 100%|██████████| 18651/18651 [14:57<00:00, 20.77it/s]


Epoch [19/20], Loss: 3.0582


Epoch 20:  11%|█         | 2011/18651 [01:36<13:32, 20.48it/s]

Skipped example 2007 due to error: The size of tensor a (70) must match the size of tensor b (69) at non-singleton dimension 0


Epoch 20:  34%|███▎      | 6267/18651 [04:59<11:32, 17.88it/s]

Skipped example 6263 due to error: The size of tensor a (262) must match the size of tensor b (250) at non-singleton dimension 0


Epoch 20:  57%|█████▋    | 10645/18651 [08:28<07:42, 17.32it/s]

Skipped example 10642 due to error: The size of tensor a (345) must match the size of tensor b (268) at non-singleton dimension 0


Epoch 20:  82%|████████▏ | 15366/18651 [12:16<02:17, 23.95it/s]

Skipped example 15364 due to error: The size of tensor a (98) must match the size of tensor b (97) at non-singleton dimension 0


Epoch 20:  94%|█████████▍| 17592/18651 [14:02<00:50, 20.98it/s]

Skipped example 17588 due to error: The size of tensor a (274) must match the size of tensor b (273) at non-singleton dimension 0


Epoch 20: 100%|██████████| 18651/18651 [14:52<00:00, 20.90it/s]


Epoch [20/20], Loss: 3.0582
Training complete.


In [4]:
model = KingBert(distilbert_tuned=distilbert_tuned, albert_tuned=albert_tuned)
state_dict = torch.load('model_state.pth')
model.load_state_dict(state_dict)
model.eval()

  state_dict = torch.load('model_state.pth')


KingBert(
  (distilbert): DistilBertForTokenClassification(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): DistilBertSdpaAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn

In [11]:
from sklearn.metrics import classification_report, confusion_matrix
from seqeval.metrics import classification_report as seqeval_classification_report
import torch
from tqdm import tqdm
import numpy as np

def evaluate_ensemble_model(model, dataset, device="cuda", label_pad_token_id=-100):
    model.eval()
    model.to(device)

    all_preds = []
    all_labels = []

    for i in tqdm(range(len(dataset)), desc="Evaluating"):
        try:
            item = dataset[i]

            # Convert and move to device
            d_ids  = torch.tensor(item["distilbert_inputids"]).unsqueeze(0).to(device)
            a_ids  = torch.tensor(item["albert_inputids"]).unsqueeze(0).to(device)
            d_msk  = torch.tensor(item["distilbert_attention_masks"]).unsqueeze(0).to(device)
            a_msk  = torch.tensor(item["albert_attention_masks"]).unsqueeze(0).to(device)
            d_wids = torch.tensor([-100] + item["distilbert_wordids"][1:-1] + [-100]).unsqueeze(0).to(device)
            a_wids = torch.tensor([-100] + item["albert_wordids"][1:-1] + [-100]).unsqueeze(0).to(device)
            labels = torch.tensor(item["spacy_labels"]).unsqueeze(0).to(device)

            # Run forward pass
            with torch.no_grad():
                log_probs = model(d_ids, a_ids, d_msk, a_msk, d_wids, a_wids)  # (1, W, C)

            preds = log_probs.argmax(-1)  # (1, W)
            preds = preds.squeeze(0)
            labels = labels.squeeze(0)

            # Truncate to match shortest sequence
            min_len = min(len(preds), len(labels))
            preds = preds[:min_len]
            labels = labels[:min_len]

            all_preds.append(preds.cpu().numpy().tolist())
            all_labels.append(labels.cpu().numpy().tolist())

        except Exception as e:
            print(f"Skipped example {i} due to error: {e}")
            continue

    # ---- Clean and prepare for evaluation ----
    flat_preds = [p for sublist in all_preds for p in sublist if p != label_pad_token_id]
    flat_labels = [l for sublist in all_labels for l in sublist if l != label_pad_token_id]

    print("Classification report:")
    print(classification_report(flat_labels, flat_preds, zero_division=0))

    cm = confusion_matrix(flat_labels, flat_preds)
    return {
        "precision": np.mean(np.diag(cm) / (cm.sum(0) + 1e-10)),
        "recall": np.mean(np.diag(cm) / (cm.sum(1) + 1e-10)),
        "f1": 2 * np.mean(np.diag(cm) / (cm.sum(1) + 1e-10)) * np.mean(np.diag(cm) / (cm.sum(0) + 1e-10)) / (
            np.mean(np.diag(cm) / (cm.sum(1) + 1e-10)) + np.mean(np.diag(cm) / (cm.sum(0) + 1e-10)) + 1e-10
        ),
        "accuracy": np.mean(np.array(flat_preds) == np.array(flat_labels)),
        "confusion_matrix": cm
    }


In [13]:
test_dataset = aux.json_to_Dataset_ensemble('datasets/ensemble_test.json')


res = evaluate_ensemble_model(model, test_dataset)
print(res)


  stacked_tensors1 = torch.stack([torch.tensor(i) for i in output1])
  stacked_tensors2 = torch.stack([torch.tensor(i) for i in output2])
Evaluating:  45%|████▌     | 1048/2320 [00:48<01:06, 19.24it/s]

Skipped example 1043 due to error: The size of tensor a (363) must match the size of tensor b (275) at non-singleton dimension 0


Evaluating:  86%|████████▋ | 2002/2320 [01:30<00:17, 18.47it/s]

Skipped example 1997 due to error: The size of tensor a (365) must match the size of tensor b (267) at non-singleton dimension 0


Evaluating: 100%|██████████| 2320/2320 [01:45<00:00, 22.05it/s]


Classification report:
              precision    recall  f1-score   support

           0       0.79      0.73      0.76       533
           1       0.69      0.59      0.64       536
           2       0.72      0.80      0.76       772
           3       0.74      0.22      0.34       446
           4       0.63      0.72      0.67       347
           5       0.64      0.53      0.58      1021
           6       0.44      0.44      0.44       837
           7       0.48      0.80      0.60       654
           8       0.65      0.68      0.67       374
           9       0.55      0.74      0.63       771
          10       0.58      0.46      0.51       887
          11       0.83      0.39      0.53       548
          12       0.58      0.59      0.58       429
          13       0.81      0.77      0.79       204
          14       0.46      0.41      0.44       759
          15       0.73      0.83      0.78      1631
          16       0.00      0.00      0.00         5
    