In [None]:
%run local_prot_vec.py

In [None]:
%run main.py --yaml_path yamls/protein_vec.yaml

In [None]:
import torch

def are_models_equal(model1, model2):
    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()
    
    if len(state_dict1) != len(state_dict2):
        print('Length')
        return False
    
    for key in state_dict1:
        if key not in state_dict2:
            print('Key')
            return False
        
        if not torch.equal(state_dict1[key], state_dict2[key]):
            print('Equal')
            return False
    
    return True


if are_models_equal(disk_model.cpu(), hf_model.cpu()):
    print("The models are equal.")
else:
    print("The models are not equal.")

In [None]:
from data.load_data import get_datasets_test_triplet
from transformers import T5Tokenizer

In [None]:
args = {
    'data_paths': ['lhallee/triplets'],
    'domains': ['[EC]', '[CC]', '[MF]', '[BP]', '[CC]', '[CC]', '[IP]'],
    'new_special_tokens': False,
    'max_length':512, 
    'p_col': 'positives',
    'a_col': 'anchors',
    'n_col': 'negatives',
    'label_col': 'aspects',
    'model_type': 'ProteinVec'
}

In [None]:
tokenizer = T5Tokenizer.from_pretrained('lhallee/ProteinVec')

In [None]:
triplet_datasets = get_datasets_test_triplet(args, tokenizer)

In [None]:
ec = triplet_datasets[0]
ec

In [None]:
ec[1][4]

In [None]:
import torch
import torch.nn as nn
from models.modeling_moesm import EsmExpert


class SentenceEnforcedSwitchMoeBlock(nn.Module): ### Test
    def __init__(self, config, expert):
        """
        Sentence level MoE, single expert chosen
        """
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_experts
        self.experts = nn.ModuleList([expert(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor, router_labels: torch.tensor) -> torch.Tensor:
        # (batch, seq_len, hidden_size), (batch,) -> from 0 to num_experts-1
        sorted_indices = torch.argsort(router_labels) # sort in order of expert idx
        hidden_states = hidden_states[sorted_indices] # apply sort
        router_labels = router_labels[sorted_indices] # apply sort
        expert_idxs = torch.unique(router_labels) # find all experts needed
        grouped_hidden_states = torch.split(hidden_states, tuple(torch.bincount(router_labels))) # split sorted hidden_states

        expert_outputs = []
        for idx, group in zip(expert_idxs, grouped_hidden_states):
            expert_output = self.experts[idx](group) # sne batched groups to their experts
            expert_outputs.append(expert_output)

        concatenated_outputs = torch.cat(expert_outputs, dim=0)
        final_hidden_states = concatenated_outputs[torch.argsort(sorted_indices)] # put back to original order
        return final_hidden_states  # (batch, sequence_length, hidden_dim)
    
class Config:
    def __init__(self):
        self.hidden_size = 128
        self.num_experts = 4
        self.intermediate_size = 256
        self.hidden_dropout_prob = 0.0


config = Config()
moe_block = SentenceEnforcedSwitchMoeBlock(config, EsmExpert)




In [None]:
# Generate random tensors for testing
batch_size = 8
seq_length = 16
hidden_size = config.hidden_size

hidden_states = torch.randn(batch_size, seq_length, hidden_size)
router_labels = torch.randint(0, config.num_experts, (batch_size,))

# Test the forward pass of the SentenceEnforcedSwitchMoeBlock
output = moe_block(hidden_states, router_labels)

# Print the shapes of input and output tensors
print("Input hidden states shape:", hidden_states.shape)
print("Router labels shape:", router_labels.shape)
print("Output shape:", output.shape)

In [None]:
router_labels = torch.tensor([2, 0, 1, 2, 1])
print(router_labels)
sorted_indices = torch.argsort(router_labels)

sorted_labels = router_labels[sorted_indices]

unsorted_indices = torch.argsort(sorted_indices)

router_labels = sorted_labels[unsorted_indices]
print(router_labels)

In [4]:
from itertools import combinations_with_replacement

batch_size = 100
for i, j in combinations_with_replacement(range(batch_size), 2):
    print(i, j)

0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
0 8
0 9
0 10
0 11
0 12
0 13
0 14
0 15
0 16
0 17
0 18
0 19
0 20
0 21
0 22
0 23
0 24
0 25
0 26
0 27
0 28
0 29
0 30
0 31
0 32
0 33
0 34
0 35
0 36
0 37
0 38
0 39
0 40
0 41
0 42
0 43
0 44
0 45
0 46
0 47
0 48
0 49
0 50
0 51
0 52
0 53
0 54
0 55
0 56
0 57
0 58
0 59
0 60
0 61
0 62
0 63
0 64
0 65
0 66
0 67
0 68
0 69
0 70
0 71
0 72
0 73
0 74
0 75
0 76
0 77
0 78
0 79
0 80
0 81
0 82
0 83
0 84
0 85
0 86
0 87
0 88
0 89
0 90
0 91
0 92
0 93
0 94
0 95
0 96
0 97
0 98
0 99
1 1
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
1 10
1 11
1 12
1 13
1 14
1 15
1 16
1 17
1 18
1 19
1 20
1 21
1 22
1 23
1 24
1 25
1 26
1 27
1 28
1 29
1 30
1 31
1 32
1 33
1 34
1 35
1 36
1 37
1 38
1 39
1 40
1 41
1 42
1 43
1 44
1 45
1 46
1 47
1 48
1 49
1 50
1 51
1 52
1 53
1 54
1 55
1 56
1 57
1 58
1 59
1 60
1 61
1 62
1 63
1 64
1 65
1 66
1 67
1 68
1 69
1 70
1 71
1 72
1 73
1 74
1 75
1 76
1 77
1 78
1 79
1 80
1 81
1 82
1 83
1 84
1 85
1 86
1 87
1 88
1 89
1 90
1 91
1 92
1 93
1 94
1 95
1 96
1 97
1 98
1 99
2 2
2 3
2 4
2 5
2 6
2 7
