In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer
import aux
from tqdm import tqdm
import traceback

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer
from aux import ensembler, json_to_Dataset_ensemble
from tqdm import tqdm

class KingBert(nn.Module):
    def __init__(self, distilbert_tuned, albert_tuned):
        super().__init__()
        self.distilbert = distilbert_tuned
        self.albert = albert_tuned

        for distilbert_param in self.distilbert.parameters():
            distilbert_param.requires_grad = False

        for albert_param in self.albert.parameters():
            albert_param.requires_grad = False 
        
        self.alpha = nn.Parameter(0.5 * torch.ones(47), requires_grad=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, distilbert_input_ids, albert_input_ids, distil_attention_mask, alb_attention_mask, distilbert_word_ids, albert_word_ids):
        distilbert_output = self.distilbert(input_ids=distilbert_input_ids, attention_mask=distil_attention_mask)
        albert_output = self.albert(input_ids=albert_input_ids, attention_mask=alb_attention_mask)
        distilbert_fixed, albert_fixed = aux.ensembler(distilbert_output['logits'].squeeze(), albert_output['logits'].squeeze(), distilbert_word_ids.squeeze(), albert_word_ids.squeeze())

        distilbert_fixed = self.softmax(distilbert_fixed)
        albert_fixed = self.softmax(albert_fixed)

        final_output = distilbert_fixed * self.alpha + albert_fixed * (torch.ones(47) - self.alpha)

        return self.softmax(final_output)

train_dataset = aux.json_to_Dataset_ensemble('data/ensemble_train.json')

distilbert_tuned = AutoModelForTokenClassification.from_pretrained('distilbert_finetuned')
albert_tuned = AutoModelForTokenClassification.from_pretrained('albert_finetuned')

kingbert_model = KingBert(distilbert_tuned, albert_tuned)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(kingbert_model.parameters(), lr=2e-5)


num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0
    for i in tqdm(range(len(train_dataset)), desc="Steps in epoch"):
        try:

            item = train_dataset[i]
            
            distilbert_input_ids = torch.tensor(item['distilbert_inputids']).unsqueeze(0)
            albert_input_ids = torch.tensor(item['albert_inputids']).unsqueeze(0)
            distil_attention_mask = torch.tensor(item['distilbert_attention_masks']).unsqueeze(0)
            alb_attention_mask = torch.tensor(item['albert_attention_masks']).unsqueeze(0)
            distilbert_word_ids = torch.tensor([-100] + item['distilbert_wordids'][1:-1] + [-100]).unsqueeze(0)
            albert_word_ids = torch.tensor([-100] + item['albert_wordids'][1:-1] + [-100]).unsqueeze(0)
            targets = torch.tensor(item['spacy_labels']).unsqueeze(0)
            
            optimizer.zero_grad()
            
            output = kingbert_model(distilbert_input_ids, albert_input_ids, distil_attention_mask, alb_attention_mask, distilbert_word_ids, albert_word_ids)
            
            ohe_targets = torch.zeros(output.shape[0], output.shape[1])
            for i,j in enumerate(targets):
                ohe_targets[i][j] = 1

            loss = criterion(output, ohe_targets)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        except:
            continue
        
    avg_loss = total_loss / len(train_dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

torch.save(kingbert_model.state_dict(), 'model_state.pth')

print('Training complete.')


  stacked_tensors1 = torch.stack([torch.tensor(i) for i in output1])
  stacked_tensors2 = torch.stack([torch.tensor(i) for i in output2])
Steps in epoch: 100%|██████████| 18244/18244 [38:26<00:00,  7.91it/s] 


Epoch [1/5], Loss: 0.2237


Steps in epoch: 100%|██████████| 18244/18244 [40:15<00:00,  7.55it/s]


Epoch [2/5], Loss: 0.2237


Steps in epoch: 100%|██████████| 18244/18244 [42:33<00:00,  7.14it/s] 


Epoch [3/5], Loss: 0.2237


Steps in epoch: 100%|██████████| 18244/18244 [42:36<00:00,  7.14it/s] 


Epoch [4/5], Loss: 0.2237


Steps in epoch: 100%|██████████| 18244/18244 [43:26<00:00,  7.00it/s] 


Epoch [5/5], Loss: 0.2237
Training complete.
