In [11]:
import copy
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer, BertForSequenceClassification

In [91]:
def f1_max(pred, target):
    """
    F1 score with the optimal threshold, adapted from TorchDrug.

    Every prediction value is tried as a threshold (predictions at or above it count
    as positive; ties are split by sort order). For each threshold, precision is
    averaged over the rows that have at least one positive prediction, and recall is
    averaged over all rows. The best F1 obtained from those averages is returned.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`, or :math:`(N,)` for one row
        target (Tensor): binary targets with the same shape as ``pred``

    Returns:
        tuple[float, float]: the maximal F1 score and the prediction value
        (threshold) at which it is reached.
    """
    if pred.dim() == 1:  # allow a single, un-batched row
        pred, target = pred.unsqueeze(0), target.unsqueeze(0)

    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # Mark, in original positions, the top-ranked prediction of every row.
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    all_order = pred.flatten().argsort(descending=True)
    # Prediction values in global descending order; entry k is the threshold that
    # corresponds to all_f1[k] below. Captured before all_order is re-mapped.
    sorted_pred = pred.flatten()[all_order]
    order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1]
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # Per-row increments of precision/recall, accumulated in global threshold order.
    all_precision = precision[all_order] - \
                    torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1])
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - \
                 torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1])
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    max_index = all_f1.argmax()
    max_f1 = all_f1[max_index]
    # max_index is a position in the globally sorted order, not in pred.flatten().
    max_threshold = sorted_pred[max_index]
    return max_f1.item(), max_threshold.item()  # (F1 max, threshold that achieves it)

(0.7291666865348816, 0.5378358364105225)

In [None]:
class MiniMoELoadWeights:
    """
    Build a Mixture-of-Experts Bert model and seed it from a pretrained Bert checkpoint.

    Shared (non-expert) parameters are copied by name from the pretrained backbone. Every
    expert's `intermediate_up` / `intermediate_down` linear is replaced by a copy of the
    pretrained layer's `intermediate.dense` / `output.dense`.

    `args` must provide: model_path, model_type, num_local_experts, num_experts_per_tok,
    use_router_loss, output_hidden_states, num_labels.

    NOTE(review): each encoder layer of the constructed model is assumed to expose
    `moe_block.experts[j]` (as in BertMoeBlock / BertExpert). The model classes are
    imported from `transformers`; confirm they are MoE-patched versions, since a stock
    Bert layer exposes `intermediate` rather than `moe_block`.
    """

    def __init__(self, args):
        self.model_path = args.model_path
        self.model_type = args.model_type
        self.num_local_experts = args.num_local_experts
        self.num_experts_per_tok = args.num_experts_per_tok
        self.use_router_loss = args.use_router_loss
        self.output_hidden_states = args.output_hidden_states
        self.num_labels = args.num_labels
        self.Bert_base = None  # pretrained dense BertModel, set by get_seeded_model
        self.config = None     # config shared by Bert_base and the MoE model

    def get_seeded_model(self):
        """
        Build the model selected by `model_type`, seed it from `model_path`, print
        load statistics and return it.

        Raises:
            ValueError: if `model_type` is not one of the supported options.
        """
        start_time = time.time()
        from transformers import (BertForMaskedLM, BertForSequenceClassification,
                                  BertForTokenClassification, BertModel)
        model_classes = {
            'Model': BertModel,
            'MaskedLM': BertForMaskedLM,
            'SequenceClassification': BertForSequenceClassification,
            'TokenClassification': BertForTokenClassification,
        }
        if self.model_type not in model_classes:
            # Fail fast, before the (slow) checkpoint download, instead of continuing without a model.
            raise ValueError(f'You entered {self.model_type}\nValid options are:\nModel , MaskedLM , SequenceClassification , TokenClassification')
        # Only the classification heads need num_labels on the shared config.
        load_kwargs = ({'num_labels': self.num_labels}
                       if self.model_type in ('SequenceClassification', 'TokenClassification') else {})
        self.Bert_base = BertModel.from_pretrained(self.model_path, **load_kwargs)
        self.config = self.get_config(self.Bert_base)
        model = model_classes[self.model_type](config=self.config)
        model = self.match_weights(model)
        end_time = time.time()
        print('Model loaded in ', round((end_time - start_time) / 60, 2), 'minutes')
        total, effective, mem = self.count_parameters(model)
        print(f'{total} million total parameters')
        print(f'{effective} million effective parameters')
        print(f'Approximately {mem} GB of memory in fp32\n')
        return model

    def get_config(self, model):
        """Write the MoE settings into `model.config` (in place) and return that config."""
        config = model.config
        config.num_local_experts = self.num_local_experts
        # BertMoeBlock reads `num_experts`; keep both names in sync.
        config.num_experts = self.num_local_experts
        config.num_experts_per_tok = self.num_experts_per_tok
        config.use_router_loss = self.use_router_loss
        config.output_router_logits = self.use_router_loss
        config.output_hidden_states = self.output_hidden_states
        return config

    @staticmethod
    def _backbone(model):
        """Return the bare Bert model inside a task-head model, or `model` itself."""
        # transformers' PreTrainedModel.base_model resolves the backbone attribute
        # (e.g. `bert`) of task heads and is the model itself for BertModel; modules
        # without `base_model` fall back to themselves.
        return getattr(model, 'base_model', model)

    def _weights_match(self, model):
        """True if shared parameters and every expert linear equal the pretrained weights."""
        backbone = self._backbone(model)
        model_state = backbone.state_dict()
        for name, param in self.Bert_base.named_parameters():  # shared parameters
            if name in model_state and not torch.equal(param.data, model_state[name]):
                return False
        base_layers = self._backbone(self.Bert_base).encoder.layer
        for i in range(self.config.num_hidden_layers):  # experts
            for j in range(self.config.num_local_experts):
                expert = backbone.encoder.layer[i].moe_block.experts[j]
                if not torch.equal(expert.intermediate_up.weight, base_layers[i].intermediate.dense.weight):
                    return False
                if not torch.equal(expert.intermediate_down.weight, base_layers[i].output.dense.weight):
                    return False
        return True

    def check_for_match(self, model):  # Test for matching parameters
        """Print whether `model`'s shared and expert weights equal the pretrained ones."""
        print('All weights match' if self._weights_match(model) else 'Some weights differ')

    def match_weights(self, model):  # Seeds MoBert experts with linear layers of Bert
        """
        Copy the pretrained weights into `model` and return it.

        A weight comparison is printed before and after seeding.
        """
        self.check_for_match(model)
        backbone = self._backbone(model)
        # Parameter names of the backbone line up with those of the pretrained BertModel,
        # regardless of which task head wraps it.
        model_params = dict(backbone.named_parameters())
        with torch.no_grad():
            for name, param in self.Bert_base.named_parameters():
                if name in model_params:
                    model_params[name].copy_(param)

        base_layers = self._backbone(self.Bert_base).encoder.layer
        for i in range(self.config.num_hidden_layers):
            for j in range(self.config.num_local_experts):
                expert = backbone.encoder.layer[i].moe_block.experts[j]
                # Independent copies so experts can diverge during training.
                expert.intermediate_up = copy.deepcopy(base_layers[i].intermediate.dense)
                expert.intermediate_down = copy.deepcopy(base_layers[i].output.dense)
        self.check_for_match(model)
        return model

    def count_parameters_in_layer(self, layer):
        """Counts parameters in a regular layer."""
        return sum(p.numel() for p in layer.parameters())

    def count_parameters(self, model):
        """
        Return (total params in millions, effective params in millions, fp32 size in GB).

        "Effective" subtracts the linears of (num_local_experts - num_experts_per_tok)
        experts per layer, presumably the experts not used for a given token.
        """
        total_params = sum(p.numel() for p in model.parameters())
        layers = self._backbone(model).encoder.layer
        inactive_experts = self.config.num_local_experts - self.config.num_experts_per_tok
        non_effective_params = 0
        for i in range(self.config.num_hidden_layers):
            for j in range(inactive_experts):
                expert = layers[i].moe_block.experts[j]
                non_effective_params += self.count_parameters_in_layer(expert.intermediate_up)
                non_effective_params += self.count_parameters_in_layer(expert.intermediate_down)
        effective_params = total_params - non_effective_params
        memory_bytes = total_params * 4  # 4 bytes for 32-bit floats
        memory_gig = round(memory_bytes / (1024 ** 3), 2)
        return round(total_params / 1e6, 1), round(effective_params / 1e6, 1), memory_gig

In [92]:
class BertExpert(nn.Module):
    """
    One feed-forward expert for the MoE block.

    Computes Linear(hidden -> intermediate), then GELU, then
    Linear(intermediate -> hidden), then dropout. The two linears have the
    same shapes as a Bert layer's `intermediate.dense` and `output.dense`, so
    they can be seeded from a pretrained Bert layer.
    """
    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.intermediate_up = nn.Linear(hidden, inner)    # shape of BertIntermediate.dense
        self.intermediate_down = nn.Linear(inner, hidden)  # shape of BertOutput.dense
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.act = nn.GELU()

    def forward(self, hidden_states):
        expanded = self.act(self.intermediate_up(hidden_states))
        return self.dropout(self.intermediate_down(expanded))


class BertMoeBlock(nn.Module):
    """
    Sequence-level top-1 Mixture-of-Experts feed-forward block.

    The gate scores every token, the scores are averaged over the sequence
    dimension, and each sequence in the batch is routed in full to the single
    expert with the highest averaged score.

    NOTE(review): the expert output is not scaled by the router probability, so
    the gate receives no gradient through the returned hidden states; it can only
    learn through a loss on the returned router logits. Confirm this is intended.
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        # Accept `num_experts` (set directly in this notebook) or fall back to
        # `num_local_experts` (the name MiniMoELoadWeights.get_config writes).
        num_experts = getattr(config, 'num_experts', None)
        self.num_experts = num_experts if num_experts is not None else config.num_local_experts
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([BertExpert(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: (batch, sequence_length, hidden_dim)

        Returns:
            final_hidden_states: (batch, sequence_length, hidden_dim), expert outputs
            router_logits: (batch, num_experts), gate logits averaged over the sequence
        """
        router_logits = self.gate(hidden_states).mean(dim=1)  # (batch, num_experts)
        # softmax is monotonic, so taking the argmax of the logits directly picks the same expert.
        router_choice = router_logits.argmax(dim=-1)  # (batch,)
        if hidden_states.shape[0] == 0:
            return torch.zeros_like(hidden_states), router_logits

        # Run each expert once on all sequences routed to it, rather than once per sequence.
        outputs, positions = [], []
        for expert_idx, expert in enumerate(self.experts):
            idx = (router_choice == expert_idx).nonzero(as_tuple=True)[0]
            if idx.numel() == 0:
                continue
            outputs.append(expert(hidden_states[idx]))
            positions.append(idx)
        # Undo the grouping so row i of the output corresponds to row i of the input.
        restore = torch.cat(positions).argsort()
        final_hidden_states = torch.cat(outputs)[restore]
        return final_hidden_states, router_logits

In [12]:
# Pretrained MiniLM checkpoint used for both the model and its tokenizer.
CHECKPOINT = 'sentence-transformers/all-MiniLM-L6-v2'
model = BertForSequenceClassification.from_pretrained(CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(CHECKPOINT)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
# Work on a copy so the loaded model's own config is not modified by this experiment.
config = copy.deepcopy(model.config)
config.num_experts = 2  # number of experts BertMoeBlock will build

BertMoeBlock(
  (gate): Linear(in_features=384, out_features=2, bias=False)
  (experts): ModuleList(
    (0-1): 2 x BertExpert(
      (intermediate_up): Linear(in_features=384, out_features=1536, bias=True)
      (intermediate_down): Linear(in_features=1536, out_features=384, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (act): GELU(approximate='none')
    )
  )
)

In [93]:
# Smoke test: route a random batch through a freshly initialised MoE block.
torch.manual_seed(0)  # reproducible input, weights, dropout and routing
ex = torch.rand(8, 16, config.hidden_size)  # (batch, sequence_length, hidden_size)
block = BertMoeBlock(config)
test = block(ex)
assert test[0].shape == ex.shape
assert test[1].shape == (ex.shape[0], block.num_experts)

Router choice:  torch.Size([8]) tensor([1, 0, 0, 0, 0, 1, 0, 0])


In [95]:
test[0].size()

torch.Size([8, 16, 384])

In [19]:
model.get_submodule('bert.encoder.layer.0.intermediate')

BertIntermediate(
  (dense): Linear(in_features=384, out_features=1536, bias=True)
  (intermediate_act_fn): GELUActivation()
)