In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer
from aux import json_to_Dataset_ensemble
import aux
from tqdm import tqdm
import traceback

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer
from aux import ensembler, json_to_Dataset_ensemble
from tqdm import tqdm

class KingBert(nn.Module):
    """Ensemble of two frozen token classifiers blended by a learned per-label weight.

    Both wrapped models are frozen in ``__init__``; the only trainable parameter
    is ``alpha``, a vector with one mixing weight per label, initialised to 0.5.

    Args:
        distilbert_tuned: Model called with ``input_ids`` / ``attention_mask``
            whose output exposes ``['logits']``.
        albert_tuned: Second model with the same calling convention.
        num_labels: Size of the label axis (and of ``alpha``). Defaults to 47,
            the value that was previously hard-coded.
    """

    def __init__(self, distilbert_tuned, albert_tuned, num_labels=47):
        super().__init__()
        self.distilbert = distilbert_tuned
        self.albert = albert_tuned

        # Freeze both base models so that only `alpha` receives gradients.
        for base_model in (self.distilbert, self.albert):
            for frozen_param in base_model.parameters():
                frozen_param.requires_grad = False

        self.num_labels = num_labels
        # Here we have an alpha for each label
        self.alpha = nn.Parameter(0.5 * torch.ones(num_labels), requires_grad=True)
        # Normalise over the label axis (the last one). For the 2-D tensors this
        # model handles, this is the same axis as dim=1.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, distilbert_input_ids, albert_input_ids, distil_attention_mask, alb_attention_mask, distilbert_word_ids, albert_word_ids):
        """Blend the two models' label distributions for a single sequence.

        Every argument is expected to carry a leading batch axis of size 1,
        which is removed before the logits are handed to ``ensembler``.

        Returns:
            The softmax of the alpha-weighted mix of both models' softmaxed
            outputs, as produced by ``ensembler``.
        """
        distilbert_output = self.distilbert(input_ids=distilbert_input_ids, attention_mask=distil_attention_mask)
        albert_output = self.albert(input_ids=albert_input_ids, attention_mask=alb_attention_mask)

        # squeeze(0) removes only the batch axis. A bare squeeze() would also
        # collapse a length-1 sequence axis and break the softmax below.
        # NOTE(review): ensembler presumably aligns both models' token-level
        # logits to a shared word-level grid -- confirm in aux.
        distilbert_fixed, albert_fixed = ensembler(
            distilbert_output['logits'].squeeze(0),
            albert_output['logits'].squeeze(0),
            distilbert_word_ids.squeeze(0),
            albert_word_ids.squeeze(0),
        )

        distilbert_probs = self.softmax(distilbert_fixed)
        albert_probs = self.softmax(albert_fixed)

        # `1.0 - self.alpha` stays on alpha's device and dtype; a fresh
        # torch.ones(...) would always be allocated on the CPU.
        final_output = distilbert_probs * self.alpha + albert_probs * (1.0 - self.alpha)

        # NOTE(review): final_output is already a mix of probabilities; this extra
        # softmax flattens it. Kept because callers may rely on the current output.
        return self.softmax(final_output)

# Load Huggingface dataset
train_dataset = json_to_Dataset_ensemble('ensemble_train.json')

# Load pre-trained models (distilbert and albert)
distilbert_tuned = AutoModelForTokenClassification.from_pretrained('distilbert_finetuned')
albert_tuned = AutoModelForTokenClassification.from_pretrained('albert_finetuned')

kingbert_model = KingBert(distilbert_tuned, albert_tuned)

# Loss and optimizer
# NOTE(review): KingBert.forward already ends in a softmax, while CrossEntropyLoss
# applies log-softmax to its input again. Consider feeding log-probabilities to
# nn.NLLLoss instead; left unchanged here so reported losses stay comparable.
criterion = nn.CrossEntropyLoss()
# Only the mixing weights are trainable; the base models are frozen by KingBert.
trainable_params = [param for param in kingbert_model.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=2e-5)

# Training loop
num_epochs = 2

for epoch in range(num_epochs):
    total_loss = 0.0
    processed_items = 0
    skipped_items = 0
    for item_idx in tqdm(range(len(train_dataset)), desc="Steps in epoch"):
        try:
            # Get the individual item from the dataset
            item = train_dataset[item_idx]

            # Each tensor gets a leading batch axis of size 1 (one item per step).
            distilbert_input_ids = torch.tensor(item['distilbert_inputids']).unsqueeze(0)
            albert_input_ids = torch.tensor(item['albert_inputids']).unsqueeze(0)
            distil_attention_mask = torch.tensor(item['distilbert_attention_masks']).unsqueeze(0)
            alb_attention_mask = torch.tensor(item['albert_attention_masks']).unsqueeze(0)
            # The first and last word ids are replaced by -100
            # (presumably the special-token positions -- confirm against aux).
            distilbert_word_ids = torch.tensor([-100] + item['distilbert_wordids'][1:-1] + [-100]).unsqueeze(0)
            albert_word_ids = torch.tensor([-100] + item['albert_wordids'][1:-1] + [-100]).unsqueeze(0)
            # Class-index targets, one per output row. The previous one-hot
            # construction iterated over a (1, seq) tensor and therefore only
            # ever filled row 0 of the target matrix.
            targets = torch.tensor(item['spacy_labels'], dtype=torch.long)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            output = kingbert_model(distilbert_input_ids, albert_input_ids, distil_attention_mask, alb_attention_mask, distilbert_word_ids, albert_word_ids)

            # Compute loss (raises if the row count and label count disagree)
            loss = criterion(output, targets)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            processed_items += 1

        except Exception as exc:
            # Best-effort training: skip the offending item, but say so.
            # KeyboardInterrupt is deliberately not caught.
            skipped_items += 1
            tqdm.write(f'Skipping item {item_idx}: {exc!r}')
            continue

    # Average over the items that actually contributed a loss value.
    avg_loss = total_loss / max(processed_items, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
    if skipped_items:
        print(f'Skipped {skipped_items} of {len(train_dataset)} items in this epoch.')

torch.save(kingbert_model.state_dict(), 'model_state.pth')

print('Training complete.')


  stacked_tensors1 = torch.stack([torch.tensor(i) for i in output1])
  stacked_tensors2 = torch.stack([torch.tensor(i) for i in output2])


torch.Size([125, 47])
torch.Size([85, 47])
torch.Size([97, 47])
torch.Size([77, 47])
torch.Size([87, 47])
torch.Size([89, 47])
torch.Size([66, 47])
torch.Size([54, 47])
torch.Size([85, 47])
torch.Size([68, 47])
torch.Size([109, 47])
torch.Size([86, 47])
torch.Size([117, 47])
torch.Size([75, 47])
torch.Size([69, 47])


Epoch:   0%|          | 0/1141 [00:01<?, ?it/s]

torch.Size([85, 47])
[125, 85, 98, 77, 87, 89, 66, 54, 85, 68, 109, 86, 117, 75, 69, 85]
torch.Size([1374, 47])
1375
Error during batch processing: Expected input batch_size (1374) to match target batch_size (16).
Traceback (most recent call last):
  File "/var/folders/05/8k53g1bs725dn8cs5310zydm0000gn/T/ipykernel_77993/1711267779.py", line 145, in <module>
    loss = criterion(output, targets)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modul


