In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import BertModel, BertTokenizer
import torch
import re
import numpy as np
from Ampep.toxic_func import toxic_feature
from Ampep.amp_func import amp_feature
import random


In [None]:

# Checkpoint identifier handed to `from_pretrained` for both the tokenizer and
# the model, so the two always come from the same place.
PROT_T5_CHECKPOINT = 'prot_t5_xl_bfd'

tokenizer = T5Tokenizer.from_pretrained(PROT_T5_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(PROT_T5_CHECKPOINT)
model = model.to('cuda')


In [None]:
from Bio import pairwise2

def similarity_score(seq1, seq2):
    """Return the negated, length-normalised global alignment score of two sequences.

    The score of ``pairwise2.align.globalxx(seq1, seq2)`` is divided by the
    length of the longer sequence and negated, so a more similar pair yields a
    lower (more negative) value. This makes it usable directly as a loss term.

    Args:
        seq1: First sequence (string).
        seq2: Second sequence (string).

    Returns:
        ``-score / max(len(seq1), len(seq2))``, or ``0`` when either sequence
        is empty or the score cannot be normalised.
    """
    # Nothing can be aligned against an empty sequence; the original code
    # reached the same result (0) through a caught IndexError.
    if not seq1 or not seq2:
        return 0

    lens = max(len(seq1), len(seq2))
    # Only the best score is needed, so skip building the alignment objects:
    # with score_only=True the call returns the score itself.
    best_score = pairwise2.align.globalxx(seq1, seq2, score_only=True)
    try:
        return -best_score / lens
    except Exception as e:
        # Best-effort fallback kept from the original: report and treat the
        # pair as having no similarity.
        print(e)
        print(best_score)
        print(seq1)
        print(seq2)
        return 0

In [None]:
def estimate_gradient(z, q, beta, criterion, sigma=100):
    """Estimate the gradient of a black-box loss at ``z`` from random perturbations.

    Draws ``q`` random directions ``u``, evaluates ``criterion`` at ``z`` and at
    ``z + beta * u``, and averages ``u * (f(z + beta*u) - f(z)) / beta`` scaled
    by ``z.shape[2]``.

    Args:
        z: Tensor indexed as ``(batch, d0, d1)``; ``z.shape[1:]`` sets the
            shape of each perturbation. Callers in this file pass a single
            embedding (batch of 1) -- TODO confirm larger batches are intended
            to work.
        q: Number of random directions.
        beta: Perturbation step size (smoothing parameter).
        criterion: Callable mapping a tensor of embeddings to a 1-D NumPy
            array of losses, one per leading-dimension entry.
        sigma: Standard deviation of the Gaussian the directions are drawn
            from (before normalisation).

    Returns:
        Tensor of shape ``(1, d0, d1)`` with the dtype and device of ``z``.
    """
    z_dim = z.shape[1:]
    u = np.random.normal(0, sigma, size=(q, z_dim[0], z_dim[1])).astype('float32')
    # NOTE(review): the norm is taken over axis=1 only (each column along d0 is
    # normalised separately) and the scale factor below is z_dim[1], not the
    # total dimension d0*d1. Kept as-is; confirm this is the intended estimator.
    u = torch.from_numpy(u / np.linalg.norm(u, axis=1, keepdims=True))

    f_0 = criterion(z)
    # Move the directions to wherever z lives instead of assuming 'cuda'.
    f_tmp = criterion(z + beta * u.to(device=z.device))
    print('Loss now: %f'%(f_0[0]))

    # Convert the per-direction loss differences to a tensor explicitly rather
    # than relying on implicit torch * numpy broadcasting. The averaging stays
    # on the CPU in float64, as before, and is cast once at the end.
    delta = torch.as_tensor(np.asarray(f_tmp - f_0, dtype=np.float64)).reshape(-1, 1, 1)
    grad = torch.mean(z_dim[1] * u * delta / beta, dim=0, keepdim=True)
    return grad.to(dtype=z.dtype, device=z.device)

In [None]:
def model_encode(seq):
    """Encode a sequence into the encoder's last hidden state, with random masking.

    15% of the positions (rounded down) are picked at random and replaced by
    ``<extra_id_N>`` tokens; the residues are then joined with spaces,
    tokenised with the module-level ``tokenizer`` and passed through
    ``model.encoder``.

    Args:
        seq: Input sequence string, one character per residue.

    Returns:
        The encoder's ``last_hidden_state`` tensor for the single masked
        input, on the model's device and detached from autograd.
    """
    mask_rate = 0.15
    len_seq = len(seq)
    mask_len = int(len_seq * mask_rate)
    # Uses the global `random` state, so the result differs between calls
    # unless the caller seeds it.
    mask_idx = random.sample(range(len_seq), mask_len)

    raw_input = list(seq)
    # NOTE(review): sentinel numbers follow the sampling order, not the
    # left-to-right position in the sequence -- confirm this is intended.
    for i, idx in enumerate(mask_idx):
        raw_input[idx] = '<extra_id_%d>'%i
    raw_input = [' '.join(raw_input)]
    # Follow the model's device instead of assuming 'cuda'.
    inputs = tokenizer(raw_input, return_tensors='pt')['input_ids'].to(model.device)
    # The search that consumes this embedding is gradient-free, so do not
    # record an autograd graph (which would keep encoder activations alive).
    with torch.no_grad():
        r1 = model.encoder(inputs)['last_hidden_state']
    return r1

def model_decode(emb):
    """Greedily decode an encoder embedding back into a sequence string.

    Starting from token id 0, the decoder is re-run on the growing prefix and
    the highest-scoring next token is appended each step. Decoding stops when
    token id 1 is produced, when the step index reaches
    ``int(emb.shape[1] * 1.2)``, or after 100 steps.

    Args:
        emb: Encoder hidden states, used as ``encoder_hidden_states``;
            assumed to be ``(1, seq_len, hidden)`` -- TODO confirm.

    Returns:
        The decoded tokens joined into one string with whitespace removed.
    """
    outputs = [0]
    max_steps = int(emb.shape[1] * 1.2)
    # Inference only: avoid recording autograd state on every decoder step.
    with torch.no_grad():
        for i in range(0, 100):
            # Build the prefix on the embedding's device instead of assuming 'cuda'.
            prefix = torch.tensor([outputs], device=emb.device)
            hidden = model.decoder(prefix, encoder_hidden_states=emb)['last_hidden_state']
            logits = model.lm_head(hidden)
            # argmax over the logits picks the same token as argmax over their
            # softmax, so the softmax is skipped.
            next_id = int(torch.argmax(logits[0, -1], dim=-1))
            outputs.append(next_id)
            if next_id == 1:  # presumably the end-of-sequence id -- verify against the tokenizer
                break
            # Ids outside 3..22 are replaced by 23 (presumably: anything that
            # is not a standard residue becomes the unknown-residue token).
            if next_id < 3 or next_id > 22:
                outputs[-1] = 23
            if i >= max_steps:
                break
    # Drop the start token and the final token.
    # NOTE(review): when decoding stops at the length cap rather than on id 1,
    # the final token is a regular residue and is still dropped -- confirm.
    seq = tokenizer.decode(outputs[1:-1])
    seq = ''.join(seq.split())
    return seq

In [None]:
def loss_function(z, model, origin_seq, weight=1, score=None, constraints=[],
                  weight_constraint=False):
    """Compute one scalar loss per embedding in ``z``.

    Each ``z[i]`` is decoded with ``model_decode``; the loss combines a
    property term ``score(decoded, origin_seq)`` (0 when ``score`` is falsy)
    with the sum of ``c(decoded)`` over ``constraints``.

    Args:
        z: Batch of embeddings; iterated along its first dimension.
        model: Accepted for interface compatibility; not referenced here.
        origin_seq: Reference sequence passed to ``score``.
        weight: Multiplier applied to one of the two terms (see below).
        score: Optional callable ``(decoded, origin_seq) -> number``.
        constraints: Callables ``decoded -> number``; their values are summed.
        weight_constraint: If true, ``weight`` scales the constraint term;
            otherwise it scales the property term.

    Returns:
        NumPy array with one loss value per entry of ``z``.
    """
    def _single_loss(latent):
        decoded = model_decode(torch.unsqueeze(latent, dim=0))

        property_term = score(decoded, origin_seq) if score else 0
        constraint_term = sum(check(decoded) for check in constraints)

        if weight_constraint:
            return property_term + constraint_term * weight
        return property_term * weight + constraint_term

    return np.array([_single_loss(z[row]) for row in range(z.shape[0])])

In [None]:
from functools import partial
import time
from tqdm import tqdm
def optimize(model, seq, q=100, base_lr=0.1, max_iter=1000, num_restarts=1,
             weight=0.1, beta=1, use_adam=False, early_stop=False, score=similarity_score,
             constraints=[amp_feature, toxic_feature], writer=None, run_str=None, results_dir='results',
             init_best={}, write_log=None, flip_weight=False):
    """Search the embedding space around ``seq`` for a sequence meeting all constraints.

    The sequence is encoded once with ``model_encode``. Each iteration
    estimates a gradient of the loss with ``estimate_gradient``, updates the
    embedding, decodes it with ``model_decode`` and records the decoded
    sequence as the new best if every constraint returns 0 and its score beats
    the best so far.

    Args:
        model: Forwarded to ``loss_function``.
        seq: Starting sequence.
        q: Number of random directions per gradient estimate.
        base_lr: Base learning rate (decayed per iteration unless ``use_adam``).
        max_iter: Iterations per restart.
        num_restarts: Number of restarts from the initial embedding.
        weight: Loss weight forwarded to ``loss_function``.
        beta: Perturbation size forwarded to ``estimate_gradient``.
        use_adam: Update with ``torch.optim.Adam`` instead of the decayed
            plain gradient step.
        early_stop: Stop at the first accepted candidate.
        score: Callable ``(candidate, seq) -> number`` where lower is better,
            or ``None`` to accept any candidate that meets the constraints.
        constraints: Callables ``candidate -> number``; a candidate is
            accepted only if all of them return 0.
        writer, run_str, results_dir, write_log: Unused by this function.
        init_best: Entries merged into the initial ``best`` record.
        flip_weight: Forwarded to ``loss_function`` as ``weight_constraint``.

    Returns:
        dict with ``score``, ``found`` and ``early_stop``, plus ``step``,
        ``z``, ``z_0``, ``seq`` and ``run`` once a candidate has been accepted.
    """
    z_0 = model_encode(seq)     # embedding of the input sequence
    print(seq)
    # Use the caller's constraints for the loss too, so the search and the
    # acceptance test below agree on what is being enforced.
    loss = partial(loss_function, model=model, origin_seq=seq, weight=weight, score=score,
                   constraints=constraints, weight_constraint=flip_weight)
    best = {'score': -np.inf, 'found': False, 'early_stop': False}
    best.update(init_best)

    for k in range(num_restarts):
        if best['early_stop']:
            break
        # Detach so z is a plain leaf tensor: the search is gradient-free, and
        # an optimizer can only manage leaf tensors.
        z = z_0.detach().clone()
        optimizer = torch.optim.Adam([z], lr=base_lr) if use_adam else None
        for i in range(max_iter):
            print('start itr %d'%i)
            grad = estimate_gradient(z, q, beta, loss)  # zeroth-order (QMO) gradient estimate
            if use_adam:
                z.grad = grad
                optimizer.step()  # without this step the embedding never changes
            else:
                lr = ((1 - i/max_iter)**0.5) * base_lr
                z -= grad * lr

            mol = model_decode(z)   # map the updated embedding back to a sequence
            print('After optim ', mol)
            # score is lower-is-better, so negate it for the best-so-far comparison.
            mol_score = -score(mol, seq) if score else None
            if mol_score is not None:
                print('score is %f'%mol_score)
            # Evaluate each constraint once and reuse the values for both the
            # log line and the acceptance test; works for any number of constraints.
            constraint_values = [c(mol) for c in constraints]
            print(', '.join('%s: %s' % (getattr(c, '__name__', repr(c)), v)
                            for c, v in zip(constraints, constraint_values)))
            improved = mol_score is None or mol_score > best['score']
            if improved and all(v == 0 for v in constraint_values):
                print('Bingo!')
                # Store a copy of z: it is updated in place on later iterations.
                best.update(dict(step=i, z=z.clone(), z_0=z_0, seq=mol,
                                 score=mol_score, found=True, run=k,
                                 early_stop=early_stop))

                print('PASSED!')

                if early_stop:
                    break
            print()

    if not best['found']:
        print('Search failed!')
    return best


In [None]:
# Example run: optimise this sequence with the default settings of `optimize`.
# The call is the cell's last expression, so its return value (the `best`
# dict) is what the notebook shows as the cell output.
seq = 'WFHHIFRGIVHVGKTIHRLVTG'
optimize(model,seq)

In [None]:
# Second example run with a different starting sequence and the default
# settings of `optimize`; rebinds the global `seq` used by the previous cell.
seq = 'AKKVFKRLGIGAVLWVLTTG'
optimize(model,seq)