In [1]:
import aux
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer, AutoTokenizer, DataCollatorForTokenClassification
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from kingbert import KingBert
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = AutoModelForTokenClassification.from_pretrained('distilbert_finetuned')
tokenizer = AutoTokenizer.from_pretrained('distilbert_finetuned')

In [3]:
model2 = AutoModelForTokenClassification.from_pretrained('albert_finetuned')
tokenizer2 = AutoTokenizer.from_pretrained('albert_finetuned')

In [4]:
data = aux.json_to_Dataset_ensemble("data/ensemble_train.json")

In [8]:
albert_logits = []
distilbert_logits = []
labels = []

for datum in tqdm(data):
    logits_albert, predictions, predicted_token_class, inputs = aux.inference(model2, torch.tensor([datum['albert_inputids']]), torch.tensor([datum['albert_attention_masks']]))
    logits_distilbert, predictions, predicted_token_class, inputs = aux.inference(model, torch.tensor([datum['distilbert_inputids']]), torch.tensor([datum['distilbert_attention_masks']]))
    albert_output, distilbert_output = aux.ensembler(logits_albert.squeeze(), logits_distilbert.squeeze(), datum['albert_wordids'], datum['distilbert_wordids'])
    albert_output = torch.softmax(albert_output, dim=1)
    distilbert_output = torch.softmax(distilbert_output, dim=1)
    if distilbert_output.shape != albert_output.shape:
        continue
    if albert_output.shape[0] != len(datum['spacy_labels']):
        continue
    albert_logits += [row.tolist() for row in albert_output]
    distilbert_logits += [row.tolist() for row in distilbert_output]
    spacy_labels = []
    for i in datum['spacy_labels']:
        ohe_labels = [0 for i in range(47)]
        ohe_labels[i] = 1
        spacy_labels.append(ohe_labels)
    labels += spacy_labels
    assert len(albert_logits) == len(distilbert_logits)
    assert len(labels) == len(albert_logits)

100%|██████████| 18244/18244 [41:53<00:00,  7.26it/s] 


In [12]:
len(labels) == len(distilbert_logits) == len(albert_logits)

True

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define your model
class Ensembler(nn.Module):
    def __init__(self):
        super(Ensembler, self).__init__()
        self.alpha = nn.Parameter(0.5 * torch.ones(47), requires_grad=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x1, x2):
        final_output = x1 * self.alpha + x2 * (torch.ones(47) - self.alpha)
        return self.softmax(final_output)

# Define your loss function
criterion = nn.CrossEntropyLoss()

# Define your model and optimizer
ensembler = Ensembler()
optimizer = optim.Adam(model.parameters(), lr=2e-5)

# Convert your data to torch tensors if they aren't already
albert_data = albert_logits
distilbert_data = distilbert_logits
labels_data = labels


# Convert them to PyTorch tensors
albert_inputs = torch.tensor(albert_logits, dtype=torch.float32)
distilbert_inputs = torch.tensor(distilbert_logits, dtype=torch.float32)
labels_outputs = torch.tensor(labels_data, dtype=torch.float32)

# Create a TensorDataset and DataLoader for batching
dataset = TensorDataset(albert_inputs, distilbert_inputs, labels_outputs)
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop with batches
epochs = 10
for epoch in range(epochs):
    for batch_albert_inputs, batch_distilbert_inputs, batch_labels in tqdm(dataloader):
        # Forward pass
        outputs = ensembler(batch_distilbert_inputs, batch_albert_inputs)
        loss = criterion(outputs, batch_labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')


100%|██████████| 82181/82181 [00:18<00:00, 4523.20it/s]


Epoch [1/10], Loss: 3.8216


100%|██████████| 82181/82181 [00:18<00:00, 4494.66it/s]


Epoch [2/10], Loss: 3.8189


100%|██████████| 82181/82181 [00:18<00:00, 4470.67it/s]


Epoch [3/10], Loss: 3.8215


100%|██████████| 82181/82181 [00:18<00:00, 4370.01it/s]


Epoch [4/10], Loss: 3.8216


100%|██████████| 82181/82181 [00:20<00:00, 4074.72it/s]


Epoch [5/10], Loss: 3.8216


100%|██████████| 82181/82181 [00:21<00:00, 3788.61it/s]


Epoch [6/10], Loss: 3.8157


100%|██████████| 82181/82181 [00:22<00:00, 3609.17it/s]


Epoch [7/10], Loss: 3.8215


100%|██████████| 82181/82181 [00:23<00:00, 3518.16it/s]


Epoch [8/10], Loss: 3.8157


100%|██████████| 82181/82181 [00:23<00:00, 3502.90it/s]


Epoch [9/10], Loss: 3.8157


100%|██████████| 82181/82181 [00:22<00:00, 3576.43it/s]

Epoch [10/10], Loss: 3.8157





In [28]:
ensembler.alpha

Parameter containing:
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000], requires_grad=True)