In [None]:
# Execute local_prot_vec.py in this kernel's namespace (IPython %run).
# NOTE(review): presumably this defines `disk_model` / `hf_model` used by the
# comparison cell further down — TODO confirm.
%run local_prot_vec.py

In [None]:
# Execute main.py in this kernel's namespace, pointing it at the ProteinVec YAML config.
%run main.py --yaml_path yamls/protein_vec.yaml

In [None]:
import torch

def are_models_equal(model1, model2):
    """Return True if two modules have identical state_dicts.

    Both state_dicts must contain exactly the same keys, and every tensor pair
    must match in shape and value (``torch.equal``). On the first mismatch a
    short diagnostic naming the offending key(s) is printed and False is returned.

    Args:
        model1: first ``nn.Module`` (anything exposing ``state_dict()``).
        model2: second ``nn.Module``.

    Returns:
        bool: True when all parameters/buffers are identical, else False.
    """
    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()

    keys1, keys2 = set(state_dict1), set(state_dict2)
    if keys1 != keys2:
        print(f'Key mismatch: only in model1={sorted(keys1 - keys2)}, '
              f'only in model2={sorted(keys2 - keys1)}')
        return False

    for key, tensor1 in state_dict1.items():
        # torch.equal raises on tensors living on different devices, so align first.
        tensor2 = state_dict2[key].to(tensor1.device)
        if not torch.equal(tensor1, tensor2):
            print(f'Tensor mismatch (shape or values) at {key!r}')
            return False

    return True


# Move both models to CPU (nn.Module.cpu() is in-place and returns the module)
# and compare their weights tensor by tensor.
models_match = are_models_equal(disk_model.cpu(), hf_model.cpu())
print("The models are equal." if models_match else "The models are not equal.")

In [None]:
from data.load_data import get_datasets_test_triplet
from transformers import T5Tokenizer

In [None]:
# Data/tokenization settings passed to get_datasets_test_triplet below.
# NOTE(review): '[CC]' appears three times in `domains` (positions 1, 4 and 5);
# confirm whether the repeats are intentional or should be other aspect tokens.
args = dict(
    data_paths=['lhallee/triplets'],
    domains=['[EC]', '[CC]', '[MF]', '[BP]', '[CC]', '[CC]', '[IP]'],
    new_special_tokens=False,
    max_length=512,
    p_col='positives',
    a_col='anchors',
    n_col='negatives',
    label_col='aspects',
    model_type='ProteinVec',
)

In [None]:
# Hugging Face repo id (or local path) holding the ProteinVec tokenizer files.
tokenizer_repo = 'lhallee/ProteinVec'
tokenizer = T5Tokenizer.from_pretrained(tokenizer_repo)

In [None]:
# Build the test triplet datasets from the `args` settings dict and the tokenizer.
# NOTE(review): the return type is defined in data.load_data; it is indexed with [0]
# in the next cell, so it is presumably a sequence of per-domain datasets — confirm.
triplet_datasets = get_datasets_test_triplet(args, tokenizer)

In [None]:
# Take the first returned dataset. It is named `ec` on the assumption that it
# corresponds to the first domain '[EC]' in args['domains'] — TODO confirm ordering.
ec = triplet_datasets[0]
ec  # last expression: display it

In [None]:
import torch

# torch.tensor(None) cannot infer a dtype and raises, so the original one-liner
# always errored. Capture that explicitly, then show the working pattern: a
# missing value has to be encoded as NaN before torch.isnan can detect it.
try:
    torch.isnan(torch.tensor(None))
    none_error = None
except (RuntimeError, TypeError) as err:
    none_error = err
print("torch.tensor(None) ->", repr(none_error))

nan_flag = torch.isnan(torch.tensor(float('nan')))
nan_flag

In [None]:
import torch
import torch.nn as nn
from models.modeling_moesm import EsmExpert


class SentenceEnforcedSwitchMoeBlock(nn.Module): ### Test
    """Sentence-level mixture of experts: each sequence is handled by exactly one expert.

    The expert for every sequence is given explicitly by ``router_labels`` (there
    is no learned router). Sequences are grouped by label so each expert runs once
    on a batched group, and the outputs are then restored to the original batch order.
    """

    def __init__(self, config, expert):
        """
        Args:
            config: object exposing ``hidden_size`` and ``num_experts``. It is also
                passed unchanged to each expert constructor.
            expert: callable (typically an ``nn.Module`` subclass) invoked as
                ``expert(config)`` once per expert.
        """
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_experts
        self.experts = nn.ModuleList([expert(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor, router_labels: torch.Tensor) -> torch.Tensor:
        """Send each sequence to the expert named by its label.

        Args:
            hidden_states: (batch, seq_len, hidden_size) tensor.
            router_labels: (batch,) integer tensor with values in [0, num_experts).

        Returns:
            Expert outputs stacked along dim 0, in the original batch order.
        """
        # Sort the batch by expert index so each expert's sequences are contiguous.
        sorted_indices = torch.argsort(router_labels)
        hidden_states = hidden_states[sorted_indices]
        router_labels = router_labels[sorted_indices]

        # unique(..., return_counts=True) lists only the experts actually present,
        # with counts aligned to them. Unlike bincount, it never yields zero-size
        # entries for absent experts, which would misalign groups and expert indices.
        expert_idxs, counts = torch.unique(router_labels, sorted=True, return_counts=True)
        grouped_hidden_states = torch.split(hidden_states, counts.tolist())

        expert_outputs = [
            self.experts[idx](group)  # send each batched group to its expert
            for idx, group in zip(expert_idxs.tolist(), grouped_hidden_states)
        ]

        concatenated_outputs = torch.cat(expert_outputs, dim=0)
        # argsort of a permutation is its inverse: put rows back in the original order.
        final_hidden_states = concatenated_outputs[torch.argsort(sorted_indices)]
        return final_hidden_states  # (batch, sequence_length, hidden_dim)
    
class Config:
    """Small stand-in config for exercising the MoE block in isolation."""

    def __init__(self):
        # Layer widths.
        self.hidden_size = 128
        self.intermediate_size = 256
        # Number of experts the MoE block instantiates.
        self.num_experts = 4
        # No dropout.
        self.hidden_dropout_prob = 0.0


# Toy config plus an MoE block whose experts are EsmExpert instances built from it.
config = Config()
moe_block = SentenceEnforcedSwitchMoeBlock(config=config, expert=EsmExpert)




In [None]:
# Generate random tensors for testing. Seeded so the hidden states and the
# routing pattern are reproducible across Restart & Run All.
torch.manual_seed(0)
batch_size = 8
seq_length = 16
hidden_size = config.hidden_size

hidden_states = torch.randn(batch_size, seq_length, hidden_size)
router_labels = torch.randint(0, config.num_experts, (batch_size,))

# Test the forward pass of the SentenceEnforcedSwitchMoeBlock
output = moe_block(hidden_states, router_labels)

# Print the shapes of input and output tensors
print("Input hidden states shape:", hidden_states.shape)
print("Router labels shape:", router_labels.shape)
print("Output shape:", output.shape)

# The block documents its output as (batch, sequence_length, hidden_dim), i.e. the input shape.
assert output.shape == hidden_states.shape, (output.shape, hidden_states.shape)

In [None]:
# Check that argsort(sorted_indices) is the inverse permutation that the MoE block
# uses to restore the original batch order after grouping by expert.
router_labels = torch.tensor([2, 0, 1, 2, 1])
print(router_labels)
sorted_indices = torch.argsort(router_labels)

sorted_labels = router_labels[sorted_indices]

unsorted_indices = torch.argsort(sorted_indices)

# Keep the round-tripped result under its own name so it can be compared to the input.
restored_labels = sorted_labels[unsorted_indices]
print(restored_labels)
assert torch.equal(restored_labels, router_labels)

In [1]:
import torch
from utils import get_yaml
from models.load_model import load_models

def calc_memory(model, inputs, device, name='Layer'):
    """Print (and return) the CUDA memory still allocated after one forward pass.

    The figure is ``torch.cuda.memory_allocated`` right after the forward minus its
    value right before it (inputs already on device). It therefore covers the output
    plus anything the forward keeps alive (possibly autograd buffers, since no
    ``torch.no_grad`` is used). It is not a peak measurement.

    Args:
        model: ``nn.Module`` to profile; moved to ``device`` and back to CPU afterwards.
        inputs: a single tensor (passed positionally) or an iterable of tensors
            (unpacked as positional arguments).
        device: CUDA device (``torch.device`` or string).
        name: label used in the printed report.

    Returns:
        The memory delta in bytes, or None when ``device`` is not a CUDA device
        (CUDA memory statistics are unavailable there).
    """
    device = torch.device(device)
    if device.type != 'cuda':
        print(f"Process {name} VRAM usage: skipped ({device} is not a CUDA device)")
        return None

    try:
        model.to(device)
        # Dispatch on the input type explicitly. A bare `except:` fallback would also
        # swallow genuine forward-pass errors (OOM, shape mismatch) and retry blindly.
        if isinstance(inputs, torch.Tensor):
            device_inputs = (inputs.to(device),)
        else:
            device_inputs = tuple(t.to(device) for t in inputs)
        memory_before = torch.cuda.memory_allocated(device)
        output = model(*device_inputs)
        memory_after = torch.cuda.memory_allocated(device)
        # Drop the forward results before empty_cache so their blocks can be released.
        del output, device_inputs
    finally:
        # Always hand the module back to CPU, even if the forward pass raised.
        model.cpu()
        torch.cuda.empty_cache()

    vram_usage = memory_after - memory_before
    print(f"Process {name} VRAM usage: {vram_usage / 1024 / 1024:.2f} MB")
    return vram_usage


class arguments:
    """Attribute bag standing in for parsed CLI args.

    ``yaml_path`` selects the config file; further fields are attached to an
    instance with ``setattr`` after the YAML is read.
    """

    yaml_path = 'yamls/MOE/moesm_double_35.yaml'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Start from the YAML config: copy both sections onto the args object, in order.
args = arguments()
yargs = get_yaml(args.yaml_path)
for section in ('general_args', 'training_args'):
    for key, value in yargs[section].items():
        setattr(args, key, value)

# Notebook-specific MoE overrides, applied on top of the YAML values.
moe_overrides = {
    'gated': True,
    'moe_type': 'topk',
    'token_moe': True,
    'topk': 2,
    'num_experts': 8,
}
for key, value in moe_overrides.items():
    setattr(args, key, value)

model, tokenizer = load_models(args) # if eval and skip, not needed

cuda


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Some weights differ
All weights match
Model loaded in  0.11 minutes
734.2 million total parameters
534.8 million effective parameters
Approximately 2.74 GB of memory in fp32

MoEsmVec(
  (base): T5EncoderModel(
    (shared): Embedding(144, 768)
    (encoder): T5Stack(
      (embed_tokens): Embedding(144, 768)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=768, out_features=768, bias=False)
                (k): Linear(in_features=768, out_features=768, bias=False)
                (v): Linear(in_features=768, out_features=768, bias=False)
                (o): Linear(in_features=768, out_features=768, bias=False)
                (relative_attention_bias): Embedding(64, 12)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (1): T5LayerFF(
          

In [2]:
# Pull out the sub-modules whose memory footprint is profiled individually below.
base, esm, base_adapter, esm_adapter, proj = (
    model.base,
    model.esm,
    model.base_adapter,
    model.esm_adapter,
    model.proj,
)

In [3]:
# Dummy inputs sized for each sub-module; the values are irrelevant, only memory is measured.
batch_size = 2
max_length = 1024
vocab_size = 20  # token IDs are sampled from [0, vocab_size), i.e. 0..19

# Hidden sizes / layer counts the adapters are fed with.
# NOTE(review): 768 matches the T5 hidden size in the printed model repr. 480 and the
# layer counts 49 / 13 are assumed to match the base / ESM hidden-state stacks — confirm.
base_hidden, esm_hidden = 768, 480
base_layers, esm_layers = 49, 13

base_input = torch.randint(0, vocab_size, (batch_size, max_length))
esm_input = torch.randint(0, vocab_size, (batch_size, max_length))
base_adapter_input = (torch.rand(batch_size, base_layers, max_length, base_hidden),
                      torch.rand(batch_size, max_length, esm_hidden))
esm_adapter_input = (torch.rand(batch_size, base_layers, max_length, base_hidden),
                     torch.rand(batch_size, esm_layers, max_length, esm_hidden))
proj_input = torch.rand(batch_size, base_hidden + esm_hidden)

names = ['ANKH', 'ESM', 'BASE ADAPT', 'ESM ADAPT', 'PROJ']
models = [base, esm, base_adapter, esm_adapter, proj]
inputs = [base_input, esm_input, base_adapter_input, esm_adapter_input, proj_input]
assert len(names) == len(models) == len(inputs)

In [4]:
# Profile each sub-module in turn. The loop names deliberately avoid `model` (rebinding it
# would clobber the loaded model, so the extraction cell could not be re-run) and the
# builtin `input`.
for submodule, sample_input, name in zip(models, inputs, names):
    calc_memory(submodule, sample_input, device, name)
torch.cuda.empty_cache()

Process ANKH VRAM usage: 11.12 MB


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Process ESM VRAM usage: 2544.31 MB
Process BASE ADAPT VRAM usage: 47.97 MB
Process ESM ADAPT VRAM usage: 58.50 MB
Process PROJ VRAM usage: 0.02 MB
