In [None]:
import sys
import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
sys.path.append('/home/yaosenmin/SFM_framework/')

from pathlib import Path
from finetune_pfm_v2 import init_model, load_batched_dataset, multi_label_transform
from sfm.utils.move_to_device import move_to_device
from sfm.data.prot_data.dataset import FoundationModelDataset
from sfm.data.prot_data.collater import pad_1d_unsqueeze
from sfm.data.prot_data.vocalubary import Alphabet
from Bio import SeqIO
from tqdm import tqdm
from joblib import Parallel, delayed
import random

%matplotlib inline
%config InlineBackend.figure_format = "svg"

In [27]:
class Namespace:
    """Bare attribute container: every keyword argument becomes an attribute.

    Used in this notebook as a lightweight stand-in for a parsed
    command-line arguments object.
    """

    def __init__(self, **kwargs):
        # Write straight into the instance dictionary, one entry at a time.
        for key, value in kwargs.items():
            self.__dict__[key] = value
# Full run configuration, built in a single constructor call.
# The alternative checkpoints tried earlier are kept here for reference:
#   /home/yaosen/blob/pfmexp/output/bfm650m_maskspan1_ddp4e5d8mask020drop1L1536_pairv4_bert2_128V100_adam2/checkpoints/checkpoint_E45.pt
#   /blob/pfmexp/output/bfm650m_maskspan3_ddp4e5d16mask020drop1L1536B2k_bpepairv4_bert2_128A100_adam2/checkpoints/checkpoint_E63.pt
args = Namespace(
    # --- task ---
    task_name="solubility",
    max_length=2048,
    base_model="pfm",
    # --- model ---
    encoder_layers=33,
    encoder_embed_dim=1280,
    encoder_ffn_embed_dim=5120,
    encoder_attention_heads=20,
    num_pred_attn_layer=2,
    atom_loss_coeff=1.0,
    pos_loss_coeff=1.0,
    sandwich_ln=True,
    ft=True,
    dropout=0,
    fp16=True,
    attn_dropout=0,
    act_dropout=0,
    weight_decay=0.0,
    droppath_prob=0.0,
    max_num_aa=2048,
    noise_mode="diff",
    noise_scale=0.2,
    mask_ratio=0.2,
    # The original `args.mode_prob=1.0,0.0,0.0` stores a 3-tuple of floats;
    # the tuple is spelled out here.
    # NOTE(review): confirm a tuple (rather than a comma-separated string)
    # is what the model code expects.
    mode_prob=(1.0, 0.0, 0.0),
    d_tilde=1.0,
    max_lr="1e-5",
    strategy="DDP",
    pipeline_model_parallel_size=0,
    train_batch_size=64,
    val_batch_size=61,
    max_tokens=6400,
    # --- data ---
    train_data_path="None",
    valid_data_path="None",
    data_basepath="/blob/data/bfm_benchmark",
    # --- checkpoint loaded by embed_fn ---
    loadcheck_path="/blob/pfmexp/output/finetune/finetune-solubility_seed21_lr2e-5_E50_bs32_bfm650m_maskspan1_ddp4e5d8mask020drop1L1536_pairv4_bert2_128V100_adam2/checkpoint_E3.pt",
    # --- remaining options ---
    add_3d=False,
    num_3d_bias_kernel=0,
    no_2d=False,
    rank=0,
    num_residues=32,
    t_timesteps=0,
    ddpm_schedule="cosine",
    ddpm_beta_start=0.0,
    ddpm_beta_end=0.0,
    head_dropout=0.0,
    early_stopping_patience=0,
    early_stopping_metric="loss",
    early_stopping_mode="min",
    grad_scaler_init=1.0,
)


In [28]:
class FastaDataset(FoundationModelDataset):
    """Dataset that tokenizes every record of a FASTA file.

    Each item is a dict with the record index (``id``), the FASTA
    description line (``name``) and the token ids of the upper-cased
    sequence (``aa``), with BOS/EOS tokens added when the vocabulary asks
    for them.

    Args:
        args: configuration object; must carry ``fasta_file``. ``max_length``
            is defaulted to 2048 when absent.

    Raises:
        ValueError: if ``args`` has no ``fasta_file`` attribute.
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.args = self.set_default_args(args)
        self.vocab = Alphabet()
        self.seqs = []
        self.names = []
        for record in SeqIO.parse(self.args.fasta_file, "fasta"):
            self.names.append(record.description)
            self.seqs.append(str(record.seq).upper())

    def set_default_args(self, args):
        """Fill in defaults, validate required fields and return ``args``.

        The original version returned nothing, which left ``self.args`` set
        to ``None``.
        """
        if not hasattr(args, "max_length"):
            args.max_length = 2048
        if not hasattr(args, "fasta_file"):
            raise ValueError("Please specify fasta_file")
        return args

    def __getitem__(self, index: int) -> dict:
        """Return ``{"id", "name", "aa"}`` for one record; ``aa`` is an int64 array."""
        item = {"id": index, 'name': self.names[index], "aa": self.seqs[index]}
        # A residue letter missing from the vocabulary raises KeyError here.
        tokens = [self.vocab.tok_to_idx[tok] for tok in item["aa"]]
        if self.vocab.prepend_bos:
            tokens.insert(0, self.vocab.cls_idx)
        if self.vocab.append_eos:
            tokens.append(self.vocab.eos_idx)
        # NOTE(review): args.max_length is never applied; sequences are not
        # truncated in this class.
        item["aa"] = np.array(tokens, dtype=np.int64)
        return item

    def __len__(self) -> int:
        return len(self.seqs)

    def size(self, index: int) -> int:
        """Number of residues in record ``index`` (special tokens excluded)."""
        return len(self.seqs[index])

    def num_tokens(self, index: int) -> int:
        """Number of tokens ``__getitem__`` yields for record ``index``.

        Counts BOS/EOS only when the vocabulary adds them, so the result
        agrees with ``__getitem__`` (equal to ``len + 2`` when both are on).
        """
        n_special = int(bool(self.vocab.prepend_bos)) + int(bool(self.vocab.append_eos))
        return len(self.seqs[index]) + n_special

    def num_tokens_vec(self, indices):
        raise NotImplementedError()

    def collate(self, samples: list) -> dict:
        """Pad a list of items to a common length and stack them into a batch.

        Returns a dict with ``id`` and ``naa`` (long tensors), ``name`` (list
        of descriptions) and ``x`` (padded token ids, one row per sample).
        """
        max_tokens = max(len(s["aa"]) for s in samples)
        batch = dict()

        batch["id"] = torch.tensor([s["id"] for s in samples], dtype=torch.long)
        batch["naa"] = torch.tensor([len(s["aa"]) for s in samples], dtype=torch.long)
        batch["name"] = [s["name"] for s in samples]
        # (Nres+2,) -> (B, Nres+2)
        batch["x"] = torch.cat(
            [
                pad_1d_unsqueeze(
                    torch.from_numpy(s["aa"]), max_tokens, 0, self.vocab.padding_idx
                )
                for s in samples
            ]
        )
        return batch


In [29]:
def embed_fn(rank, world_size, args, load_ckpt, batches):
    """Run the model over ``batches`` and collect its outputs with the record names.

    Args:
        rank: worker index; the CUDA device used is ``rank % world_size``.
        world_size: number of workers/devices.
        args: configuration passed to ``init_model``; ``args.loadcheck_path``
            is the checkpoint loaded on top of the initialised model.
        load_ckpt: forwarded to ``init_model``.
        batches: iterable of collated batches; each must carry ``"name"``.

    Returns:
        Tuple ``(outputs, names)``: the per-batch model outputs and the
        per-batch name lists, each concatenated along axis 0.
    """
    device = f"cuda:{rank % world_size}"
    model = init_model(args, load_ckpt=load_ckpt)

    # load downstream ckpt
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoints_state = torch.load(args.loadcheck_path, map_location="cpu")
    if "model" in checkpoints_state:
        checkpoints_state = checkpoints_state["model"]
    elif "module" in checkpoints_state:
        checkpoints_state = checkpoints_state["module"]

    load_result = model.load_state_dict(checkpoints_state, strict=False)
    print(f"checkpoint: {args.loadcheck_path} is loaded")
    # Report the actual key names. The previous message printed only the
    # field names of the result ('missing_keys', 'unexpected_keys').
    print(
        "Following keys are incompatible: "
        f"missing_keys={list(load_result.missing_keys)}, "
        f"unexpected_keys={list(load_result.unexpected_keys)}"
    )
    # end

    model.to(device)
    model.eval()
    outputs, names = [], []
    with torch.no_grad():
        for batch in tqdm(batches, ncols=80, desc=f"Rank {rank}"):
            batch = move_to_device(batch, device)
            # Alternative kept for reference: take token-level embeddings via
            # model.model.ft_forward(batch) and mean-pool them over the
            # positions that are not token ids 0/1/2.
            output = model(batch)
            outputs.append(output.cpu().numpy())
            names.append(batch["name"])
    return np.concatenate(outputs, axis=0), np.concatenate(names, axis=0)

# SCOPe embedding

In [None]:
# Run the model over every record of the SCOPe FASTA file, then visualise the
# outputs in 2-D, coloured by the class letter parsed from each description.
args.fasta_file = '/blob/data/astral-scopedom-seqres-sel-gs-bib-40-2.08.fa'
scope = FastaDataset(args)
scope_loader = torch.utils.data.DataLoader(
    scope,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn=scope.collate,
    pin_memory=True,
    drop_last=False,
)
world_size = 1
# Disabled multi-worker variant (materialise batches, split into chunks, run
# embed_fn per chunk through joblib), kept for reference:
# all_batches = [batch for batch in tqdm(scope_loader, ncols=80, total=len(scope_loader))]

# chunks = lambda l, n: [l[i : i + n] for i in range(0, len(l), n)]
# batch_chunks = list(chunks(all_batches, len(all_batches) // world_size + 1))
# print(f"{len(all_batches)}, {sum([len(c) for c in batch_chunks])}")

# result = Parallel(n_jobs=world_size)(
#     delayed(embed_fn)(rank, world_size, args, True, batch)
#     for rank, batch in enumerate(scope_loader)
# )

# Single-process run on device 0; the second return value holds the FASTA
# description of every record.
embeddings, descriptions = embed_fn(0, 1, args, True, scope_loader)

# embeddings = np.concatenate([r[0] for r in result], axis=0)
# descriptions = np.concatenate([r[1] for r in result], axis=0)
# np.savez('scope_embed.npz', embeddings=embeddings, descriptions=descriptions)


# Map the first character of the second space-separated description field to
# a readable class name. A letter outside a-g raises KeyError.
# NOTE(review): presumably that field is the SCOP classification code
# (e.g. "a.1.1.1") - confirm against the FASTA headers.
classes = []
classes_dict = {
    "a": "a: All alpha",
    "b": "b: All beta",
    "c": "c: Alpha & beta (a/b)",
    "d": "d: Alpha & beta (a+b)",
    "e": "e: Multi-domain",
    "f": "f: Membrane, cell surface",
    "g": "g: Small proteins",
}
for d in descriptions:
    item = d.split(" ")[1]
    classes.append(classes_dict[item[0]])

from collections import Counter
print(Counter(classes))
print(embeddings.shape)

# Disabled: restrict the plot to classes a, b, c and g.
# mask = [c[0] in ['a', 'b', 'c', 'g'] for c in classes]
# sub_embed = embeddings[mask]
# sub_classes = [c for idx, c in enumerate(classes) if mask[idx]]

import umap
import umap.plot
from sklearn.manifold import TSNE
# 2-D t-SNE of the model outputs; no random_state is set, so the layout can
# differ between runs.
tsne = TSNE(n_components=2, verbose=1, metric='nan_euclidean', )#perplexity=10, early_exaggeration=100,)#perplexity=50)
tsne_embed = tsne.fit_transform(embeddings)
# A UMAP model is fitted as well because umap.plot.points takes a fitted UMAP
# object as its first argument; umap_embed is not used again in this cell.
reducer = umap.UMAP(n_jobs=64, n_neighbors=15, min_dist=0.0, repulsion_strength=5.0,)
mapper = reducer.fit(embeddings)
umap_embed = reducer.transform(embeddings)

# The coordinates actually drawn are the t-SNE ones, passed via `points=`.
ax = umap.plot.points(mapper, points=tsne_embed, labels=np.array(classes), color_key_cmap='tab10',)


ax.set_title(f'TSNE clustering SCOPe', fontsize=12)
# Hide the first text artist on the axes.
# NOTE(review): presumably the parameter annotation umap.plot adds - confirm.
ax.texts[0].set_visible(False)

# solubility embedding

In [35]:
# Run the solubility checkpoint over the sequences of a FASTA file.
args.fasta_file = '/home/yaosenmin/high.fa'
solu = FastaDataset(args)
solu_loader = torch.utils.data.DataLoader(
    solu,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn=solu.collate,
    pin_memory=True,
)

# embed_fn returns (model outputs, FASTA descriptions); here `label` therefore
# holds the description string of every record, not a numeric target.
pred, label = embed_fn(0, 1, args, True, solu_loader)
# Drop singleton dimensions so both arrays can become DataFrame columns later.
pred, label = pred.squeeze(), label.squeeze()


# Disabled alternative: evaluate on the benchmark test split instead of a
# FASTA file, sharded over 16 workers.
# train_dataset, valid_dataset, test_dataset_dict = load_batched_dataset(args)
# test_loader = torch.utils.data.DataLoader(
#     test_dataset_dict['test'],
#     batch_size=16,
#     shuffle=False,
#     num_workers=0,
#     collate_fn=test_dataset_dict['test'].collate,
#     pin_memory=True,
# )

# world_size = 16
# all_batches = [batch for batch in tqdm(test_loader, ncols=80, total=len(test_loader))]

# chunks = lambda l, n: [l[i : i + n] for i in range(0, len(l), n)]
# batch_chunks = list(chunks(all_batches, len(all_batches) // world_size + 1))
# print(f"{len(all_batches)}, {sum([len(c) for c in batch_chunks])}")

# pred = np.concatenate([r[0] for r in result], axis=0).squeeze()
# label = np.concatenate([r[1] for r in result], axis=0)

if finetune: True
checkpoint: /blob/pfmexp/output/finetune/finetune-solubility_seed21_lr2e-5_E50_bs32_bfm650m_maskspan1_ddp4e5d8mask020drop1L1536_pairv4_bert2_128V100_adam2/checkpoint_E3.pt is loaded
Following keys are incompatible: dict_keys(['missing_keys', 'unexpected_keys'])


Rank 0: 100%|█████████████████████████████████████| 7/7 [00:13<00:00,  1.95s/it]


In [45]:
# Dump the training split to a two-column CSV (label, sequence).
train_dataset, valid_dataset, test_dataset_dict = load_batched_dataset(args)

# Invert the vocabulary so token ids can be turned back into letters.
idx2tok = {v: k for k, v in train_dataset.vocab.tok_to_idx.items()}
with open('/home/yaosenmin/solubility_train.csv', 'w') as f:
    print('label,seq', file=f)
    for item in train_dataset:
        target = item['target'].item()
        # Token ids 0, 1 and 2 are left out of the sequence string.
        # NOTE(review): presumably these are the special tokens - confirm
        # against the vocabulary.
        seq = ''.join([idx2tok[idx] for idx in item['aa'] if idx not in [0, 1, 2]])
        print(f'{target},{seq}', file=f)




In [36]:
# Box plot of the sigmoid-transformed model output, grouped by label.
# Depends on `pred` and `label` from the solubility embedding cell.
import seaborn as sns
import pandas as pd
# Logistic function mapping a real-valued score into (0, 1).
sigmoid = lambda z: 1/(1 + np.exp(-z))
prob = sigmoid(pred)
df = pd.DataFrame({"pred": prob, "label": label})
sns.boxplot(x="label", y="pred", data=df)

In [40]:
# Summary statistics of the predicted probability for rows labelled 'high'.
df[df['label'] == 'high'].describe()

Unnamed: 0,pred
count,100.0
mean,0.457914
std,0.142577
min,0.061834
25%,0.362893
50%,0.455072
75%,0.558226
max,0.763785


In [41]:
# Summary statistics of the predicted probability for rows labelled 'low'.
# NOTE(review): the stored output of this cell is identical to the stored
# output of the 'high' cell, which suggests one of them is stale - re-run both
# cells on a fresh kernel to confirm.
df[df['label'] == 'low'].describe()

Unnamed: 0,pred
count,100.0
mean,0.457914
std,0.142577
min,0.061834
25%,0.362893
50%,0.455072
75%,0.558226
max,0.763785


# AA embedding

In [None]:
# Run the model over the records of aa.fa and store one output row per
# residue letter in aa.npz.
args.fasta_file = '/home/yaosenmin/aa.fa'
aa_data = FastaDataset(args)
aa_loader = torch.utils.data.DataLoader(
    aa_data,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn=aa_data.collate,
    pin_memory=True,
)
world_size = 1
# Materialise all batches so they can be split into per-worker chunks.
all_batches = [batch for batch in tqdm(aa_loader, ncols=80, total=len(aa_loader))]

# Split a list into consecutive pieces of at most n elements.
chunks = lambda l, n: [l[i : i + n] for i in range(0, len(l), n)]
batch_chunks = list(chunks(all_batches, len(all_batches) // world_size + 1))
print(f"{len(all_batches)}, {sum([len(c) for c in batch_chunks])}")

# One embed_fn call per chunk; with world_size = 1 there is a single chunk.
result = Parallel(n_jobs=world_size)(
    delayed(embed_fn)(rank, world_size, args, True, batch)
    for rank, batch in enumerate(batch_chunks)
)

# Only the first chunk's outputs (result[0][0]) are used; rows are paired with
# the letters below purely by position.
# NOTE(review): assumes aa.fa lists the residues in exactly this order and that
# everything fits in the first chunk - confirm before raising world_size.
aa_dict = {a: arr for a, arr in zip("ARNDBCEQZGHILKMFPSTWYVXOU", result[0][0])}
np.savez("aa.npz", **aa_dict)

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 14 22:59:05 2019

@author: mheinzinger
"""
from pathlib import Path
import numpy as np
import seaborn as sns; sns.set(); sns.set_style("whitegrid")
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def _bucket_marker_size( mass ):
    """Map a residue mass onto the scatter marker area of its size bucket.

    The buckets follow the legend text: below 130 is small, above 150 is big,
    and everything in between (both bounds included) is medium.
    """
    if mass < 130:
        return 60
    if mass <= 150:
        return 150
    return 300


def plot_tsne( data, fig_path, SEQ, MARKER, CLASSES, SIZES ):
    """Scatter-plot 2-D amino-acid coordinates and save the figure as PDF.

    Args:
        data: array indexed as ``data[i, 0]`` / ``data[i, 1]`` giving the x/y
            coordinates of the i-th letter of ``SEQ``.
        fig_path: destination of the PDF file.
        SEQ: string of amino-acid letters, one per row of ``data``.
        MARKER: string of matplotlib marker symbols, aligned with ``SEQ``.
        CLASSES: dict mapping a marker symbol to ``(class index, class label)``.
        SIZES: dict mapping an amino-acid letter to its mass, which selects
            the marker size bucket.

    Returns:
        None. The figure is shown, written to ``fig_path`` and closed.
    """
    # Only every second palette entry is used, so the palette must hold at
    # least two entries per class even when SEQ is short.
    n_palette = max( len(SEQ), 2 * len(CLASSES) )
    colors = sns.color_palette("Paired", n_palette)
    colors = [ colors[i] for i in range(1, n_palette, 2) ]

    def _class_style( marker ):
        """Return (legend label, 1xN colour array) for a marker symbol."""
        class_idx, class_label = CLASSES[ marker ]
        return class_label, np.expand_dims( np.asarray( colors[class_idx] ), axis=0 )

    fig, ax = plt.subplots()
    ax.grid(False)

    for idx, AA in enumerate( SEQ ):
        # x,y coords for t-sne and marker
        x = data[idx, 0]
        y = data[idx, 1]
        size = _bucket_marker_size( SIZES[AA] )

        if AA == 'C':
            # Cysteine is drawn as two half-size markers, one below and one
            # above its position, using the 'v' and '^' classes.
            # The +/-1 offset needs to be adjusted if labels are off. Originally 13
            for half_marker, y_shift in ( ('v', -1), ('^', 1) ):
                label, color = _class_style( half_marker )
                sns.scatterplot( x=[x], y=[y + y_shift], marker=half_marker, label=label, s=size/2, color=color, ax=ax, linewidth=0 )
        else:
            mark = MARKER[idx]
            label, color = _class_style( mark )
            sns.scatterplot( x=[x], y=[y], marker=mark, label=label, s=size, color=color, ax=ax )
        ax.text(x+1, y+1, AA, fontsize=14) # TODO: needs to be adjusted if labels are off. Originally, 15

    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    handles, labels = ax.get_legend_handles_labels()

    # manually define legend entries for the three marker sizes
    legend_elements = [ Line2D([0], [0], marker='o', color='black', label='Small',  markersize= 60**(1/2), linestyle='None'),
                        Line2D([0], [0], marker='o', color='black', label='Medium', markersize=150**(1/2), linestyle='None'),
                        Line2D([0], [0], marker='o', color='black', label='Big',    markersize=300**(1/2), linestyle='None') ]

    #handles is a list, so append manual entries
    handles += legend_elements
    labels  += ['Small (<130 Dalton)','Medium','Big (>150 Dalton)']

    # de-duplicate: one handle per label (the last one seen wins)
    by_label = dict( zip(labels, handles) )
    label_sorting = [ 'Hydrophobic (aromatic)',
                      'Hydrophobic (aliphatic)',
                      'Positive',
                      'Negative',
                      'Polar neutral',
                      'Special cases',
                      'Small (<130 Dalton)',
                      'Medium',
                      'Big (>150 Dalton)',
                      ]

    # Keep the fixed ordering but skip classes that do not occur in this plot
    # instead of failing with KeyError.
    labels  = [ label for label in label_sorting if label in by_label ]
    handles = [ by_label[label] for label in labels ]

    lgd = ax.legend( handles, labels,
                    loc='upper left', bbox_to_anchor=(0., -0.03 ), ncol=2, frameon=False,
                         borderaxespad=0., markerscale=1 )
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles; support both.
    legend_handles = getattr( lgd, 'legend_handles', None )
    if legend_handles is None:
        legend_handles = lgd.legendHandles
    for lh in legend_handles:  # give scatter entries a uniform size in the legend
        try:
            lh.set_sizes([100.0])
        except AttributeError:
            # Line2D entries have no set_sizes; leave them untouched.
            continue

    plt.show()

    # write figures as PDF to disk
    fig.savefig( fig_path, format='pdf', bbox_inches='tight' )
    plt.close(fig) # close figure handle
    return None


def get_tsne_rep( data, perp, n_iter ):
    """Project ``data`` to 2-D with t-SNE on pairwise cosine distances.

    Args:
        data: 2-D array, one row per sample.
        perp: t-SNE perplexity.
        n_iter: maximum number of optimisation iterations.

    Returns:
        Array with one 2-D coordinate per input row.
    """
    import inspect
    from sklearn.manifold import TSNE
    from sklearn.metrics import pairwise_distances

    distance_matrix = pairwise_distances( data, data, metric='cosine', n_jobs=-1)

    # scikit-learn 1.5 renamed the `n_iter` argument of TSNE to `max_iter` and
    # later releases drop the old name, so pick whichever this version accepts.
    tsne_params = inspect.signature( TSNE ).parameters
    iter_kwarg  = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'

    trafo_data = TSNE(    n_components= 2,
                          perplexity  = perp,
                          init        = 'random',
                          random_state= 42,
                          verbose     = 1,
                          metric      = 'precomputed',
                          **{ iter_kwarg: n_iter }
                          ).fit_transform(distance_matrix)

    return trafo_data


def remove_ambigious(aa_embd, SEQ, MARKER):
    """Drop the entries for the letters Z, B, O and U from all three inputs.

    ``SEQ`` and ``MARKER`` are position-aligned sequences and ``aa_embd`` has
    one row per position. Each letter is looked up in ``SEQ`` and that
    position is removed everywhere. A letter missing from ``SEQ`` raises
    ``ValueError``.

    Returns:
        The reduced ``(aa_embd, SEQ, MARKER)`` triple.
    """
    for letter in "ZBOU":
        pos = SEQ.index(letter)
        head, tail = slice(None, pos), slice(pos + 1, None)
        SEQ = SEQ[head] + SEQ[tail]
        MARKER = MARKER[head] + MARKER[tail]
        aa_embd = np.vstack( (aa_embd[head], aa_embd[tail]) )
    return aa_embd, SEQ, MARKER
    
    
def aa_plot( root_dir, fig_path, perp, n_iter, use_standard_aas=True ):
    """Load single-residue embeddings from ``aa.npz``, run t-SNE and plot them.

    Args:
        root_dir: currently unused; the archive path is hard-coded to
            ``aa.npz`` relative to the working directory.
        fig_path: destination of the PDF figure.
        perp: t-SNE perplexity.
        n_iter: number of t-SNE iterations.
        use_standard_aas: when true, the letters Z, B, O and U are removed
            before the projection.

    Returns:
        None.
    """
    SEQ     = "ARNDBCEQZGHILKMFPSTWYVXOU" # 20 standard AAs + rare + ambigious (B,Z)
    MARKER  = "vPoXodXoodPvvPv^doo^^vooo"
    CLASSES = { 'v': ( 0, 'Hydrophobic (aliphatic)' ),
                '^': ( 1, 'Hydrophobic (aromatic)'),
                'P': ( 2, 'Positive'),
                'X': ( 3, 'Negative'),
                'o': ( 4, 'Polar neutral'),
                'd': ( 5, 'Special cases')
                }

    # residue letter -> mass used to pick the marker size bucket
    SIZES = { 'A' : 89,
              'R' : 174,
              'N' : 132,
              'D' : 133,
              'B' : 133,
              'C' : 121,
              'E' : 147,
              'Q' : 146,
              'Z' : 133,
              'G' : 75,
              'H' : 155,
              'I' : 131,
              'L' : 131,
              'K' : 146,
              'M' : 149,
              'F' : 165,
              'P' : 115,
              'S' : 105,
              'T' : 119,
              'W' : 204,
              'Y' : 181,
              'V' : 117,
              'X' : 133,
              'O' : 133,
              'U' : 133
              }

    #npz_path = root_dir / 'single_aas.npz'
    npz_path = Path('aa.npz') # TODO: insert path to npz with single amino acids as keys and embeddings as values here
    # Read every array while the archive is open, then close the file handle
    # (the previous version left the NpzFile open).
    with np.load( npz_path, mmap_mode='r') as npz_file:
        aa_embd = dict(npz_file)

    if len(aa_embd.keys()) > 1:
        # One entry per residue letter: stack them in SEQ order.
        tmp_embd = list()
        for aa in SEQ:
            tmp_embd.append( aa_embd[aa] )
        aa_embd = dict()
        aa_embd['>single_aas'] = np.vstack(tmp_embd)

    # Single-entry archive: take its only array and keep the first len(SEQ) rows.
    aa_embd  = next(iter(aa_embd.values()))[:len(SEQ),:]
    if use_standard_aas:
        aa_embd, SEQ, MARKER = remove_ambigious(aa_embd, SEQ, MARKER)

    print(aa_embd.shape)
    aa_tsne  = get_tsne_rep( aa_embd, perp, n_iter )
    print(aa_tsne.shape)
    plot_tsne( aa_tsne, fig_path, SEQ, MARKER, CLASSES, SIZES )
    return None


def main():
    """Create the amino-acid t-SNE figure in the current working directory.

    Uses perplexity 3 and 15000 iterations; both values are also embedded
    in the output file name.
    """
    perplexity, iterations = 3, 15000

    # The figure is written next to wherever the notebook is running.
    working_dir = Path.cwd()
    file_name = 'tsne_aa_perp{}_niter{}_withX_single_aas_bert_BFD.pdf'.format(perplexity, iterations)

    aa_plot( working_dir, working_dir / file_name, perplexity, iterations, use_standard_aas=True )


# Run the plot; this reads aa.npz from the working directory, so the AA
# embedding cell must have been run first.
main()