In [22]:
import random
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel, EvalPrediction
from models.modeling_moebert import MoEBertForSentenceSimilarity, BertForSentenceSimilarity
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, hamming_loss, confusion_matrix
from models.load_model import MoEBertLoadWeights
from data_zoo import *
from utils import get_yaml, log_metrics
from trainer import HF_trainer

In [5]:
args = {
    'domains': ['TEST'],
    'new_special_tokens': True,
    'num_experts': 4,
    'topk': 2,
    'token_moe':False,
    'moe_type':'topk',
    'num_tasks':2,
    'wBAL':0.1,
    'wMI':0.1
}


tokenizer = BertTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
base_model = BertModel.from_pretrained('allenai/scibert_scivocab_uncased')
#loader = MoEBertLoadWeights(args, base_model=base_model, tokenizer=tokenizer)
#base_model, tokenizer = loader.get_seeded_model()
#model = MoEBertForSentenceSimilarity(args, base_model)
model = BertForSentenceSimilarity(base_model)
model

BertForSentenceSimilarity(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31090, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, ele

In [6]:
yargs = get_yaml('yamls/base.yaml')['training_args']

In [4]:
for i in range(20, 410, 20):
    print(i)

20
40
60
80
100
120
140
160
180
200
220
240
260
280
300
320
340
360
380
400


In [3]:
batch_sentences = [
    'Hello, how are you?',
    'I am doing well, thank you!',
    'The weather is nice today.',
]

ids = tokenizer(batch_sentences, return_tensors='pt', padding=True, truncation=True)
router_labels = torch.tensor([1, 0, 1])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [7]:
def get_datasets_train(data_paths, tokenizer, domains, max_length=512):
    train_a, train_b, train_c_label, train_r_label = [], [], [], []
    valid_a, valid_b, valid_c_label, valid_r_label = [], [], [], []
    test_a, test_b, test_c_label, test_r_label = [], [], [], []
    for i, data_path in enumerate(data_paths):
        dataset = load_dataset(data_path)
        train = dataset['train']
        valid = dataset['valid']
        test = dataset['test']
        train_a.extend(train['a'])
        train_b.extend(train['b'])
        train_c_label.extend(train['label'])
        train_r_label.extend([i] * len(train['label']))
        valid_a.extend(valid['a'])
        valid_b.extend(valid['b'])
        valid_c_label.extend(valid['label'])
        valid_r_label.extend([i] * len(valid['label']))
        test_a.extend(test['a'])
        test_b.extend(test['b'])
        test_c_label.extend(test['label'])
        test_r_label.extend([i] * len(test['label']))
    random.shuffle(valid_c_label)
    train_dataset = TextDataset(train_a, train_b, train_c_label, train_r_label, tokenizer, domains, max_length)
    valid_dataset = TextDataset(valid_a[:40], valid_b[:40], valid_c_label[:40], valid_r_label[:40], tokenizer, domains, max_length)
    test_dataset = TextDataset(test_a, test_b, test_c_label, test_r_label, tokenizer, domains, max_length)
    return train_dataset, valid_dataset, test_dataset

In [8]:
train_dataset, valid_dataset, test_dataset = get_datasets_train(['lhallee/abstract_domain_skincancer'], tokenizer, ['TEST'], 64)

In [9]:
def data_collator(features):
    batch = {key: torch.stack([f[key] for f in features]) for key in features[0]}
    return batch


def calculate_max_metrics(ss, labels, cutoff):
    ss, labels = ss.float(), labels.float()
    tp = torch.sum((ss >= cutoff) & (labels == 1.0))
    fp = torch.sum((ss >= cutoff) & (labels == 0.0))
    fn = torch.sum((ss < cutoff) & (labels == 1.0))
    precision_denominator = tp + fp
    precision = torch.where(precision_denominator != 0, tp / precision_denominator, torch.tensor(0.0))
    recall_denominator = tp + fn
    recall = torch.where(recall_denominator != 0, tp / recall_denominator, torch.tensor(0.0))
    f1 = torch.where((precision + recall) != 0, (2 * precision * recall) / (precision + recall), torch.tensor(0.0))
    return f1, precision, recall


def max_metrics(ss, labels, increment=0.01):
    ss = torch.clamp(ss, -1.0, 1.0)
    min_val = ss.min().item()
    max_val = 1
    if min_val >= max_val:
        min_val = 0
    cutoffs = torch.arange(min_val, max_val, increment)
    metrics = [calculate_max_metrics(ss, labels, cutoff.item()) for cutoff in cutoffs]
    f1s = torch.tensor([metric[0] for metric in metrics])
    precs = torch.tensor([metric[1] for metric in metrics])
    recalls = torch.tensor([metric[2] for metric in metrics])
    valid_f1s = torch.where(torch.isnan(f1s), torch.tensor(-1.0), f1s)  # Replace NaN with -1 to ignore them in argmax
    max_index = torch.argmax(valid_f1s)
    return f1s[max_index].item(), precs[max_index].item(), recalls[max_index].item(), cutoffs[max_index].item()


def compute_metrics_sentence_similarity(p: EvalPrediction):
    preds = p.predictions
    labels = p.label_ids[-1]

    print(preds[0].shape, preds[1].shape)
    print(labels)

    emb_a, emb_b = preds[0], preds[1]
    # Convert embeddings to tensors
    emb_a_tensor = torch.tensor(emb_a)
    emb_b_tensor = torch.tensor(emb_b)
    labels_tensor = torch.tensor(labels)

    # Compute cosine similarity between the embeddings
    cosine_sim = F.cosine_similarity(emb_a_tensor, emb_b_tensor)
    # Compute max metrics
    f1, prec, recall, thres = max_metrics(cosine_sim, labels_tensor)
    # Compute accuracy based on the threshold found
    predictions = (cosine_sim > thres).float()
    acc = accuracy_score(predictions.flatten().numpy(), labels.flatten())
    # Compute the mean absolute difference between cosine similarities and labels
    dist = torch.mean(torch.abs(cosine_sim - labels_tensor)).item()
    # Return a dictionary of the computed metrics
    return {
        'accuracy': acc,
        'f1_max': f1,
        'precision_max': prec,
        'recall_max': recall,
        'threshold': thres,
        'distance': dist
    }


In [20]:
trainer = HF_trainer(model, train_dataset, valid_dataset,
                     compute_metrics=compute_metrics_sentence_similarity, data_collator=data_collator, **yargs)
predictions, label_ids, metrics = trainer.predict(valid_dataset)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


  0%|          | 0/2 [00:00<?, ?it/s]

(40, 768) (40, 768)
[1. 0. 0. 1. 1. 0. 0. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0.
 1. 1. 0. 1. 1. 1. 0. 0. 0. 1. 0. 0. 0. 1. 1. 0.]
0.596491277217865 0.42500001192092896 1.0 0.46585744619369507 0.4 0.5347017645835876


In [23]:
log_metrics('test.txt', metrics)

In [24]:
metrics

{'test_loss': 52.1240119934082,
 'test_accuracy': 0.4,
 'test_f1_max': 0.596491277217865,
 'test_precision_max': 0.42500001192092896,
 'test_recall_max': 1.0,
 'test_threshold': 0.46585744619369507,
 'test_distance': 0.5347017645835876,
 'test_runtime': 0.6754,
 'test_samples_per_second': 59.22,
 'test_steps_per_second': 2.961}

In [3]:
import torch
import torch.nn as nn

In [1]:
class config:
    hidden_size = 4
    intermediate_size = 8
    hidden_dropout_prob = 0.0

In [4]:
class BertExpert1(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.intermediate_up = nn.Linear(config.hidden_size, config.intermediate_size) # BertIntermediate dense
        self.intermediate_down = nn.Linear(config.intermediate_size, config.hidden_size) # BertOutput dense
        self.new_linear = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.act = nn.GELU()
    
    def forward(self, hidden_states):
        hidden_states = self.act(self.intermediate_up(hidden_states)) * self.new_linear(hidden_states)
        hidden_states = self.dropout(self.intermediate_down(hidden_states))
        return hidden_states
    

class BertExpert2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.intermediate_up = nn.Linear(config.hidden_size, config.intermediate_size) # BertIntermediate dense
        self.intermediate_down = nn.Linear(config.intermediate_size, config.hidden_size) # BertOutput dense
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.act = nn.GELU()

        self.new_linear = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
        self.new_linear.weight.data.zero_()
        self.new_linear.bias.data.fill_(1.0)
    
    def forward(self, hidden_states):
        hidden_states = self.act(self.intermediate_up(hidden_states)) * self.new_linear(hidden_states)
        hidden_states = self.dropout(self.intermediate_down(hidden_states))
        return hidden_states

In [9]:
hidden_states = torch.rand(2, 16, config.hidden_size)

In [10]:
intermediate_up = nn.Linear(config.hidden_size, config.intermediate_size) # BertIntermediate dense
intermediate_down = nn.Linear(config.intermediate_size, config.hidden_size) # BertOutput dense
new_linear = nn.Linear(config.hidden_size, config.intermediate_size)
new_linear.weight.data.zero_()
new_linear.bias.data.fill_(1.0)
out1 = intermediate_up(hidden_states) * new_linear(hidden_states)
out2 = intermediate_up(hidden_states)

In [11]:
out1 == out2

tensor([[[True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True]],

        [[True, True, True, True, True, True, True, Tr