In [1]:
### This notebook explores using the FFN weights of BERT to initialize the FFN experts of a Switch Transformer
### It builds on switch_transformer_snli_transfer_exp_bert.ipynb

In [34]:
from transformers import AutoTokenizer, SwitchTransformersModel, BertModel

In [35]:
### load both pretrained checkpoints
### (they are loaded in their own cell in case of network issues)

BERT_CHECKPOINT = "bert-base-uncased"
SWITCH_CHECKPOINT = "google/switch-base-8"

## the bert model that will later receive a switch transformer layer
switch_bert_model = BertModel.from_pretrained(BERT_CHECKPOINT)
## the switch transformer that provides the layer with experts
switch_base_model = SwitchTransformersModel.from_pretrained(SWITCH_CHECKPOINT)



In [36]:
### pick the pair of layers (index 1 in both encoders) involved in the modification

## the bert encoder layer that will eventually be replaced
recipient_base = switch_bert_model.encoder.layer[1]

## the switch transformer encoder block at the same index; its `layer[1].mlp`
## holds the experts (8 experts plus a router, per the original note)
transfer_base = switch_base_model.encoder.block[1]

In [None]:
### Based on the structures shown in the cell below:
### each switch expert uses a ReLU activation, with wi of (in=768, out=3072, bias=False) and wo of (in=3072, out=768, bias=False),
### while each bert layer has intermediate with (in=768, out=3072, bias=True) and GELU, and output with (in=3072, out=768, bias=True).
### The weight shapes match, so these two architectures seem fit for direct transfer (note that the bias terms and the activation differ).

In [37]:
### show both structures side by side
## a bare expression is only rendered when it is the last line of a cell, so
## the switch block is printed explicitly; the bert layer stays as the cell output
print(transfer_base)
recipient_base

BertLayer(
  (attention): BertAttention(
    (self): BertSdpaSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [40]:
### list components for transfer

import copy

## NOTE: the roles are deliberately swapped here -- for the weight transfer the
## bert layer is the donor and the switch block is the recipient
component_transfer_base, component_recipient_base = recipient_base, transfer_base

## donor side: the two dense layers of the bert FFN
component_transfer_base_wi = component_transfer_base.intermediate.dense
component_transfer_base_wo = component_transfer_base.output.dense

## recipient side: references to each expert's wi / wo, captured before they are overwritten below
component_recipient_base_wi = {k: v.wi for k, v in component_recipient_base.layer[1].mlp.experts.items()}
component_recipient_base_wo = {k: v.wo for k, v in component_recipient_base.layer[1].mlp.experts.items()}

### do the actual transfer
## every expert gets its OWN copy of the bert dense layers. Assigning the same
## module object to all experts would tie their parameters together (and to the
## bert layer), and that tying among experts would survive the deepcopy below,
## so the experts could never diverge from one another.
## The bert dense layers are transplanted as whole modules, so their bias terms
## (bias=True in the printed BertLayer) come along with the weights.
recipient_experts = component_recipient_base.layer[1].mlp.experts
for k in recipient_experts:
    recipient_experts[k].wi = copy.deepcopy(component_transfer_base_wi)
    recipient_experts[k].wo = copy.deepcopy(component_transfer_base_wo)

### and then get this component onto the recipient
## a deep copy keeps the block inside the bert model independent of switch_base_model
switch_bert_model.encoder.layer[1] = copy.deepcopy(transfer_base)

In [42]:
### smoke test: check that a forward pass runs through the modified bert

## the switch transformer's tokenizer is used for both models
## NOTE(review): its ids come from a different vocabulary than the one bert's
## embeddings were trained with -- confirm this is intended and that every id
## is within bert's embedding table
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")

sample_text = "Studies have been shown that owning a dog is good for you"
encoded = tokenizer(sample_text, return_tensors="pt")
input_ids = encoded.input_ids  # Batch size 1

## verify bert
switch_bert_outputs = switch_bert_model(input_ids=input_ids)

