In [None]:
import torch
import esm
from Bio import SeqIO
import numpy as np

In [None]:
# Read the clustered human scaffold FASTA and split the records by length:
# sequences shorter than 1000 residues go to `data`, everything else to
# `data_long`. Each entry is an (id, sequence) tuple.
# The sequence is stored as a plain `str` rather than a Bio.Seq object because
# `data` is later handed to the ESM batch converter, whose documented input is
# (label, sequence-string) pairs; `len()` is the same for both types.
data = []
data_long = []
for record in SeqIO.parse("../../fig1/result/drllps_scaffold_clstr_Homo_sapiens.fasta", "fasta"):
    entry = (record.id, str(record.seq))
    if len(record.seq) < 1000:
        data.append(entry)
    else:
        data_long.append(entry)

In [None]:
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap

In [None]:
# Init the distributed world with world_size 1 (a single process, rank 0),
# using the NCCL backend and a local TCP rendezvous address.
url = "tcp://localhost:23456"
# Guarded so the cell can be re-run in a live kernel: initialising the default
# process group a second time raises an error.
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(backend="nccl", init_method=url, world_size=1, rank=0)

In [None]:
# Name of the pretrained checkpoint to fetch; the helper returns the model
# payload together with the regression payload, both of which are passed to
# the model constructor in a later cell.
model_name = "esm2_t36_3B_UR50D"
(model_data, regression_data) = esm.pretrained._download_model_and_regression_data(
    model_name
)

In [None]:
# Build the model inside a fairscale wrapping context so that every wrap()
# call below produces an FSDP module configured with these settings.
fsdp_params = {
    "mixed_precision": True,
    "flatten_parameters": True,
    "state_dict_device": torch.device("cpu"),  # reduce GPU mem usage
    "cpu_offload": True,  # enable cpu offloading
}
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
    model, vocab = esm.pretrained.load_model_and_alphabet_core(
        model_name, model_data, regression_data
    )
    batch_converter = vocab.get_batch_converter()
    model.eval()

    # Wrap every child of the `layers` submodule on its own (if the model has
    # such a submodule), then wrap the model as a whole.
    layer_stack = dict(model.named_children()).get("layers")
    if layer_stack is not None:
        # Snapshot the children first because they are replaced in place.
        for layer_name, layer in list(layer_stack.named_children()):
            setattr(layer_stack, layer_name, wrap(layer))
    model = wrap(model)

In [None]:
# Convert the (id, sequence) pairs in `data` into labels, sequence strings and
# a token tensor, then place the token tensor on the first CUDA device.
(batch_labels, batch_strs, batch_tokens) = batch_converter(data)
batch_tokens = batch_tokens.to(device="cuda:0")

In [None]:
# Run the model on one row of `batch_tokens` at a time (each row keeps a
# leading batch dimension of 1) and collect the layer-36 representations.
results = []
with torch.no_grad():  # inference only, no autograd graph needed
    for i in range(batch_tokens.shape[0]):
        # `return_contacts` is left at its default (off): only the
        # representations are kept, so requesting contact prediction made the
        # model collect attention maps that were immediately discarded.
        out = model(batch_tokens[i][None], repr_layers=[36])
        results.append(out["representations"][36])
        if (i + 1) % 10 == 0:
            print(i + 1, "done")
# NOTE(review): every row is presumably still padded to the longest sequence in
# the batch, and the collected tensors stay on the model's output device —
# worth revisiting if memory or runtime becomes a problem.

In [None]:
# One pooled vector per sequence: average the rows at positions
# 1 .. len(sequence) of the matching result, which leaves out position 0 and
# everything after the end of the sequence.
sequence_representations = [
    results[idx][0, 1 : len(residues) + 1].mean(0)
    for idx, (_label, residues) in enumerate(data)
]

In [None]:
# Map each sequence label to its pooled representation as a NumPy array on the
# host. Labels are paired with representations by position; a repeated label
# keeps only its last representation.
output = {}
for label, representation in zip(batch_labels, sequence_representations):
    output[label] = representation.cpu().numpy()

In [None]:
import numpy as np
# `output` is a dict, not an ndarray, so np.save stores it as a pickled
# object array; read it back with
# np.load(path, allow_pickle=True).item().
np.save(file="../esm2_3b_human_scaffold_short.npy", arr=output)