In [None]:
# Execute local_prot_vec.py inside this notebook session; with IPython's %run the
# variables the script defines become available to the cells below.
# NOTE(review): presumably this script (or main.py, run in the next cell) is what
# defines `disk_model` / `hf_model` used further down — confirm.
%run local_prot_vec.py

In [None]:
# Run main.py in this session, passing yamls/protein_vec.yaml as its --yaml_path argument.
# NOTE(review): what main.py does with that config is not visible from this notebook —
# confirm its side effects before relying on any names it leaves behind.
%run main.py --yaml_path yamls/protein_vec.yaml

In [None]:
import torch

def are_models_equal(model1, model2):
    """
    Report whether two models carry identical state dicts.

    The comparison stops at the first mismatch and prints a one-word tag
    naming the failed check: 'Length' (different number of entries),
    'Key' (an entry of ``model1`` is absent from ``model2``) or 'Equal'
    (a tensor pair fails ``torch.equal``).

    Args:
        model1: object exposing ``state_dict()``.
        model2: object exposing ``state_dict()``.

    Returns:
        True when every entry of ``model1`` exists in ``model2`` with an
        equal tensor and both state dicts have the same size, else False.
    """
    params_a, params_b = model1.state_dict(), model2.state_dict()

    # Cheap structural check first: differing entry counts can never match.
    if len(params_a) != len(params_b):
        print('Length')
        return False

    for name, tensor_a in params_a.items():
        if name not in params_b:
            print('Key')
            return False

        if torch.equal(tensor_a, params_b[name]):
            continue

        # Same name on both sides, but shape or values differ.
        print('Equal')
        return False

    return True


# Compare the two checkpoints after calling .cpu() on each, then report the verdict.
# NOTE(review): `disk_model` and `hf_model` are not defined in this notebook; they are
# presumably left in the namespace by the %run cells above — confirm.
print(
    "The models are equal."
    if are_models_equal(disk_model.cpu(), hf_model.cpu())
    else "The models are not equal."
)

In [None]:
from data.load_data import get_datasets_test_triplet
from transformers import T5Tokenizer

In [None]:
# Settings passed to get_datasets_test_triplet.
# NOTE(review): '[CC]' appears three times in `domains` — looks like it may be a
# copy-paste slip; confirm the intended domain tokens.
args = dict(
    data_paths=['lhallee/triplets'],
    domains=['[EC]', '[CC]', '[MF]', '[BP]', '[CC]', '[CC]', '[IP]'],
    new_special_tokens=False,
    max_length=512,
    p_col='positives',
    a_col='anchors',
    n_col='negatives',
    label_col='aspects',
    model_type='ProteinVec',
)

In [None]:
# Load the T5 tokenizer published under the 'lhallee/ProteinVec' identifier.
# NOTE(review): from_pretrained presumably needs network access or a populated local cache — confirm.
tokenizer = T5Tokenizer.from_pretrained('lhallee/ProteinVec')

In [None]:
# Build the test triplet datasets from the settings in `args`, using `tokenizer`.
# NOTE(review): the return type is not visible here; the next cell indexes it by position.
triplet_datasets = get_datasets_test_triplet(args, tokenizer)

In [None]:
# Take the first entry of `triplet_datasets` and display it as the cell output.
# NOTE(review): the name suggests this is the '[EC]' domain (first item of
# args['domains']) — confirm the ordering returned by get_datasets_test_triplet.
ec = triplet_datasets[0]
ec

In [None]:
# Inspect position 4 of the item at index 1 of `ec`.
# NOTE(review): the per-item layout comes from get_datasets_test_triplet and is not
# visible here — confirm which field index 4 refers to.
ec[1][4]

In [None]:
import torch
import torch.nn as nn
from models.modeling_moesm import EsmExpert


class SentenceEnforcedSwitchMoeBlock(nn.Module): ### Test
    """
    Sentence-level mixture-of-experts block with externally supplied routing.

    Every sequence in the batch is processed by exactly one expert, and the
    expert is chosen by the caller through ``router_labels`` rather than by a
    learned router.
    """

    def __init__(self, config, expert):
        """
        Sentence level MoE, single expert chosen

        Args:
            config: object exposing ``hidden_size`` and ``num_experts``; it is
                also passed unchanged to every ``expert(config)`` call.
            expert: callable that builds one expert module from ``config``.
        """
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_experts
        self.experts = nn.ModuleList([expert(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor, router_labels: torch.Tensor) -> torch.Tensor:
        """
        Send each sequence through the expert named by its label.

        Args:
            hidden_states: (batch, seq_len, hidden_size) activations.
            router_labels: (batch,) integer expert ids, from 0 to num_experts-1.

        Returns:
            Expert outputs stacked along the batch dimension, restored to the
            original batch order: (batch, sequence_length, hidden_dim).
        """
        sorted_indices = torch.argsort(router_labels) # sort in order of expert idx
        hidden_states = hidden_states[sorted_indices] # apply sort

        # Take the ids and their group sizes from the same call so they always
        # line up. torch.bincount would also report a zero-sized bucket for every
        # unused id below the largest label, which shifts the pairing between
        # experts and groups and drops trailing groups whenever a batch skips an
        # expert. torch.unique returns the ids in ascending order, matching the
        # layout of the sorted batch.
        expert_idxs, group_sizes = torch.unique(router_labels, return_counts=True)
        grouped_hidden_states = torch.split(hidden_states, group_sizes.tolist()) # split sorted hidden_states

        # Send each batched group to its expert.
        expert_outputs = [
            self.experts[idx](group)
            for idx, group in zip(expert_idxs.tolist(), grouped_hidden_states)
        ]

        concatenated_outputs = torch.cat(expert_outputs, dim=0)
        # argsort of a permutation is its inverse, which undoes the sort above.
        final_hidden_states = concatenated_outputs[torch.argsort(sorted_indices)] # put back to original order
        return final_hidden_states  # (batch, sequence_length, hidden_dim)
    
class Config:
    """
    Minimal stand-in configuration for exercising the MoE block in this notebook.

    ``hidden_size`` and ``num_experts`` are read by
    ``SentenceEnforcedSwitchMoeBlock.__init__``.
    NOTE(review): ``intermediate_size`` and ``hidden_dropout_prob`` are presumably
    consumed by ``EsmExpert`` — confirm against models/modeling_moesm.py.
    """

    def __init__(self):
        # Stored as plain instance attributes, in this order.
        defaults = {
            'hidden_size': 128,
            'num_experts': 4,
            'intermediate_size': 256,
            'hidden_dropout_prob': 0.0,
        }
        for field, value in defaults.items():
            setattr(self, field, value)


# Instantiate the toy config and build the MoE block with EsmExpert as the expert
# constructor: the block calls EsmExpert(config) once per expert, config.num_experts times.
config = Config()
moe_block = SentenceEnforcedSwitchMoeBlock(config, EsmExpert)




In [None]:
# Smoke test: push random inputs through the MoE block and check the output shape.
# Generate random tensors for testing
batch_size = 8
seq_length = 16
hidden_size = config.hidden_size

hidden_states = torch.randn(batch_size, seq_length, hidden_size)
# One expert id per sequence, drawn from [0, num_experts).
router_labels = torch.randint(0, config.num_experts, (batch_size,))

# Test the forward pass of the SentenceEnforcedSwitchMoeBlock
output = moe_block(hidden_states, router_labels)

# Print the shapes of input and output tensors
print("Input hidden states shape:", hidden_states.shape)
print("Router labels shape:", router_labels.shape)
print("Output shape:", output.shape)

# Printing alone lets a wrong shape go unnoticed; the block is expected to return
# (batch, sequence_length, hidden_dim), i.e. the same shape it was given.
assert output.shape == hidden_states.shape, (
    f"expected output shape {tuple(hidden_states.shape)}, got {tuple(output.shape)}"
)

In [None]:
# Round-trip demo: sort the labels, then undo the sort with the inverse permutation.
# Both prints should show the same tensor.
router_labels = torch.tensor([2, 0, 1, 2, 1])
print(router_labels)

# Permutation that orders the labels ascending, and the labels in that order.
sorted_indices = router_labels.argsort()
sorted_labels = router_labels.index_select(0, sorted_indices)

# argsort of a permutation yields its inverse.
unsorted_indices = sorted_indices.argsort()

router_labels = sorted_labels.index_select(0, unsorted_indices)
print(router_labels)