In [1]:
### This script explores using switch transformer layers to do snli
### its goal is to make a Perceiver IO using switch transformers as substitutes
### for the attention layers of the hidden states after cross attention

In [2]:
from transformers import AutoTokenizer, SwitchTransformersModel, BertModel

In [3]:
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersModel.from_pretrained("google/switch-base-8")

input_ids = tokenizer(
    "Studies have been shown that owning a dog is good for you", return_tensors="pt"
).input_ids  # Batch size 1
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

# preprocess: Prepend decoder_input_ids with start token which is pad token for SwitchTransformersModel.
# This is not needed for torch's SwitchTransformersForConditionalGeneration as it does this internally using labels arg.
decoder_input_ids = model._shift_right(decoder_input_ids)

# forward pass
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
last_hidden_states = outputs.last_hidden_state



In [4]:
### Based on the cell output below, a SwitchTransformerBlock has many switch transformer layers, a pattern of normal, expert, normal, expert exists
### where the normal layer have only 1 expert and the expert layer has 8 experts and a router
### the plan is to replace a bert layer with a single switch transformer layer and see if the forward function works

In [1]:
model.encoder.block

NameError: name 'model' is not defined

In [6]:
switch_transformer_model = SwitchTransformersModel.from_pretrained("google/switch-base-8")
bert_model = BertModel.from_pretrained("bert-base-uncased")



In [7]:
### verify forward works

## use switch transformer's tokenizer for both models
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")

## verfy switch transformer
input_ids = tokenizer(
    "Studies have been shown that owning a dog is good for you", return_tensors="pt"
).input_ids  # Batch size 1
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
decoder_input_ids = model._shift_right(decoder_input_ids)
switch_transformer_outputs = switch_transformer_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

## verify bert
bert_outputs = bert_model(input_ids=input_ids)

In [8]:
### modify a bert to have a switch transformer layer
### just put these models here in case of network issues

switch_bert_model = BertModel.from_pretrained("bert-base-uncased")
switch_base_model = SwitchTransformersModel.from_pretrained("google/switch-base-8")



In [9]:
### modify a bert to have a switch transformer layer

## single out a switch transformer self-attention layer containing 8 experts and a router
transfer_base = switch_base_model.encoder.block[1]

## identify a recipient on bert model
recipient_base = switch_bert_model.encoder.layer[1]

import copy

## do a simple transfer experiment
switch_bert_model.encoder.layer[1] = copy.deepcopy(transfer_base)



In [11]:
### see if input still works

switch_bert_outputs = switch_bert_model(input_ids=input_ids)

In [12]:
switch_bert_outputs

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.4282, -0.6921, -0.2139,  ..., -0.3058,  0.6199, -0.6951],
         [-0.3807, -0.6976, -0.2409,  ..., -0.3350,  0.6382, -0.8151],
         [-0.3288, -0.8514, -0.1912,  ..., -0.2857,  0.6270, -0.8177],
         ...,
         [-0.3837, -0.7118, -0.2385,  ..., -0.3238,  0.6277, -0.7626],
         [-0.4135, -0.6921, -0.2606,  ..., -0.3207,  0.6316, -0.7488],
         [-0.3137, -0.8288, -0.1897,  ..., -0.2950,  0.6248, -0.8114]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-0.7599, -0.4874,  0.8528,  0.7450,  0.0988, -0.2571,  0.8354,  0.4930,
          0.7938, -0.8667,  0.2156, -0.0879,  0.9749, -0.7692,  0.9851, -0.1776,
          0.1712, -0.3811, -0.0149, -0.8024,  0.9498,  0.6904,  0.5803,  0.3136,
          0.2781,  0.6882, -0.5676,  0.9817,  0.9506,  0.9434, -0.6978,  0.0972,
         -0.9955,  0.0678,  0.6911, -0.8933, -0.0234, -0.6484, -0.0695, -0.0486,
         -0.9667, -0.1398,  0.96