### GPU setting

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:

  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
# Expose only the first physical GPU to CUDA. This has to run before torch initialises CUDA
# (torch is imported in the next cell); changing it afterwards has no effect on this kernel.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import os
import sys
import json
import shutil
import pickle
import random
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
from typing import List
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import requests
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR, StepLR
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
sys.path.append('../../code/Common_modules')
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from Utils import set_randomness
pd.set_option('display.max_colwidth', 999)  # show full peptide sequences in DataFrame output
# Seed RNGs via the project helper.
# NOTE(review): SEED is not passed to set_randomness(); confirm the helper's default seed is 2021.
set_randomness()
SEED = 2021

### Tokenizer

In [None]:
from typing import List
from dataclasses import dataclass, asdict
from transformers import EsmTokenizer, EsmModel
from Tokenize_modules import Vocabulary, PeptideTokenizer, locate_specials, locate_non_standard_AA

# Project vocabulary and peptide tokenizer used to build the GPT generator.
vocab = Vocabulary(file_name=os.path.join('../../data', 'vocab/vocab.txt'))
peptide_tokenizer = PeptideTokenizer(vocab)
# ESM-2 (35M) tokenizer, used when scoring peptides with the AMP classifier (classify_AMP).
ESM_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

In [None]:
from RL_GPT_modules import GPTConfig
gpt_conf = GPTConfig(voc=vocab)
data_path = gpt_conf.data_path
# Short names of the seen species. They are used in RL checkpoint paths and matched
# position-by-position with genome_feature_index_list in the training cell.
Total_species_list = [
    'Abaumannii', 'Bsubtilis', 'Ecoli', 'Efaecalis',
    'Enterobactercloacae', 'Kpneumoniae', 'Paeruginosa', 'Saureus',
]

### AMP dataset

In [None]:
# Load the multi-species AMP training split and drop peptides containing U, Z, B or X residue codes.
train_amp_csv = pd.read_csv(os.path.join(data_path, "multi_train_35_0.8.csv"))
has_nonstandard_residue = train_amp_csv['sequence'].str.contains('U|Z|B|X')
remove_idx = train_amp_csv[has_nonstandard_residue]
AMP_train_df = train_amp_csv.drop(index=remove_idx.index)
# Unique training sequences; later used as the reference ("database") amino-acid composition.
AMP_finetune_seq = AMP_train_df['sequence'].unique().tolist()
len(AMP_finetune_seq)

### PrefixProt MIC data

In [None]:
# PrefixProt MIC table; only the E. coli column is kept.
prefix_MIC = pd.read_csv('../../data/Prefix_list.csv')
Prefix_Pred = prefix_MIC.loc[:, 'Escherichia coli'].tolist()

### HemoDL hemolysis predictor (runs on GPU)

In [None]:
import sys
from features import fs_encode
import lightgbm as lgb
import numpy as np
import torch
from transformers import T5Tokenizer, T5Model,T5EncoderModel
import re
from Bio import SeqIO
import argparse

# HemoDL hemolysis-predictor components: ESM-2 650M, ProtT5-XL encoder and two LightGBM boosters.
# NOTE(review): none of these objects are passed to classify_Hemo in this notebook (it is called with
# sequences only). Confirm they are needed here; the ProtT5-XL encoder alone takes several GB of GPU memory.
# NOTE(review): the T5 tokenizer is bound to the generic name `tokenizer`, which the RL cells later pass
# to estimate_and_update; confirm that is intended rather than ESM_tokenizer.
model_esm, alphabet = torch.hub.load("facebookresearch/esm:main", "esm2_t33_650M_UR50D")
# fair-esm returns the module in training mode; eval() disables dropout so inference is deterministic.
model_esm = model_esm.to(device).eval()

tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_uniref50', do_lower_case=False)
# transformers' from_pretrained already puts the model in eval mode.
model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50").to(device)

model_fs = lgb.Booster(model_file="../../Hemolysis_predictor/source/models/model.fs")
model_tr = lgb.Booster(model_file="../../Hemolysis_predictor/source/models/model.transformer")

### Species selection

In [None]:
from MIC_predictor import get_features, RegressionModel
from Utils import get_classify, classify_AMP, ClassificationModel
species_35 = pd.read_csv(gpt_conf.species_path)
species = species_35['species'].unique()
genome_features = torch.load(gpt_conf.genome_feature_path)
# Default genome features: the first species (the RL loops re-select them per species).
genome_feats = get_features([species[0]], genome_features)

### Regression & classification models

In [None]:
# AMP classifier used to score generated peptides (classify_AMP / estimate_and_update).
cls_model = ClassificationModel(hidden_feat=256, pooling='CLS')
cls_state_dict = torch.load('../../AMP_classifier/LMPred.pth', map_location=device)
# strict=False silently ignores key mismatches; report them so parameters left at their
# initial values do not go unnoticed.
cls_load_info = cls_model.load_state_dict(cls_state_dict, strict=False)
if cls_load_info.missing_keys or cls_load_info.unexpected_keys:
    print(f'cls_model: {len(cls_load_info.missing_keys)} missing / '
          f'{len(cls_load_info.unexpected_keys)} unexpected state_dict keys ignored (strict=False)')
# The classifier is only used for inference: disable dropout.
cls_model.eval()
cls_model.to(device)

In [None]:
# MIC regression model; provides the MIC reward through get_reward_logp.
reg_model = RegressionModel(hidden_feat=256, pooling='CLS')
reg_state_dict = torch.load('../../MIC_predictor/pepESM_90_35_species_500.pth', map_location=device)
# strict=False silently ignores key mismatches; report them so parameters left at their
# initial values do not go unnoticed.
reg_load_info = reg_model.load_state_dict(reg_state_dict, strict=False)
if reg_load_info.missing_keys or reg_load_info.unexpected_keys:
    print(f'reg_model: {len(reg_load_info.missing_keys)} missing / '
          f'{len(reg_load_info.unexpected_keys)} unexpected state_dict keys ignored (strict=False)')
# The reward model is only used for inference: disable dropout.
reg_model.eval()
reg_model.to(device)

### **Reinforcement learning for Seen species**

In [None]:
from tqdm import tqdm_notebook
from matplotlib import pyplot as plt
from collections import Counter
from dataclasses import dataclass, asdict
from transformers import EsmTokenizer, EsmModel
from tqdm import tqdm
from Levenshtein import distance
from Levenshtein import ratio

from Utils import count_parameters, estimate_and_update, get_num_amino_acid
from Tokenize_modules import sequence_to_input
from MIC_predictor import get_reward_logp
from Hemolysis_predictor import classify_Hemo
from RL_GPT_modules import GPTGenerator, GPTGeneratorConfig, BaseGPTWrapper
from RL_modules import GradientTracker, ExperienceMemory, Reinforcement

# Seen-species RL sweep. For every selected species, replay-buffer depth (`times`) and KL
# coefficient, the cell:
#   - builds a frozen prior and a LoRA-trainable generator from the AMP fine-tuned checkpoint;
#   - runs `n_iterations` policy-gradient rounds;
#   - saves the LoRA weights after every round;
#   - plots training curves and sample-quality diagnostics (length, composition, uniqueness,
#     diversity, novelty).
AMP_finetuned_path = '../../ckpt/Pretrain/AMP_pretrain/Finetune_Pareto_ckpt20.ckpt'
# Row index into `species` (species_35['species'].unique()) for each entry of Total_species_list.
genome_feature_index_list = [7, 4, 0, 8, 13, 5, 2, 1]

# Sweep configuration (previously hard-coded inside the loops).
RL_SPECIES_INDICES = range(0, 2)         # positions in Total_species_list to train
RL_REPLAY_TIMES = [1, 2, 3, 4]           # `times` argument of ExperienceMemory.get_buffer
RL_KL_COEFS = [0.0001, 0.01, 0.05, 0.1]  # KL coefficients passed to policy_gradient
RL_N_ITERATIONS = 40
RL_FRESH_SUBSAMPLE = 200                 # freshly generated peptides kept per round
RL_BUFFER_SAMPLE = 100                   # sample_size drawn from the replay buffer
RL_EVAL_SAMPLES = 1000                   # peptides sampled by estimate_and_update for diagnostics
RL_lr = 5e-4
sample_num = 2000
mic_value = 1.5
hemo_value = 0.5

# Loop-invariant reference data for the novelty / composition diagnostics, computed once.
train_seq_set = set(train_amp_csv.sequence.tolist())
# dict.fromkeys de-duplicates in first-seen order (set() order of str varies between kernel restarts).
database_composition = get_num_amino_acid(list(dict.fromkeys(AMP_finetune_seq)))
database_composition.pop(' ', None)

for spc_idx in RL_SPECIES_INDICES:
    species_name = Total_species_list[spc_idx]
    # Genome features depend only on the species, so look them up once per species.
    genome_feats = get_features([species[genome_feature_index_list[spc_idx]]], genome_features)

    for times in RL_REPLAY_TIMES:
        for kl_coef in RL_KL_COEFS:
            set_randomness()
            gpt_conf.reward_thres_reg = mic_value
            gpt_conf.reward_thres_hemo = hemo_value
            gpt_conf.gen_samples = sample_num
            gpt_conf.n_iterations = RL_N_ITERATIONS
            gpt_conf.rein_opt_lr = RL_lr
            gpt_conf.RL_ckpt_path = f"../../ckpt/RL/{species_name}/Replay_buffer_Top_{times}_KL_coef_{kl_coef}/"
            os.makedirs(gpt_conf.RL_ckpt_path, exist_ok=True)
            print(f'===============MIC_threshold : {gpt_conf.reward_thres_reg}, Hemo_threshold : {gpt_conf.reward_thres_hemo}, Learning_rate : {RL_lr}, Species_name : {species_name}===============')

            # Reference policy (prior) and policy being optimised (generator), both from the same checkpoint.
            conf_1 = GPTGeneratorConfig(gpt_conf=GPTConfig(voc=vocab), lr_mult=0.95)
            basegpt_1 = BaseGPTWrapper(conf_1.gpt_conf)
            prior = GPTGenerator(basegpt_1, conf_1)
            prior = prior.construct_by_ckpt_dict(torch.load(AMP_finetuned_path), vocab)

            conf_2 = GPTGeneratorConfig(gpt_conf=GPTConfig(voc=vocab), lr_mult=0.95)
            basegpt_2 = BaseGPTWrapper(conf_2.gpt_conf)
            generator = GPTGenerator(basegpt_2, conf_2)
            generator = generator.construct_by_ckpt_dict(torch.load(AMP_finetuned_path), vocab)

            n_to_generate = gpt_conf.gen_samples
            n_iterations = gpt_conf.n_iterations

            # Freeze the prior entirely; in the generator, only the LoRA adapters stay trainable.
            for param in prior.base_gpt.gpt.parameters():
                param.requires_grad = False
            for param in generator.base_gpt.gpt.parameters():
                param.requires_grad = False
            for target_module in generator.base_gpt.gpt.lora_layers:
                for param in generator.base_gpt.gpt.lora_layers[target_module].parameters():
                    param.requires_grad = True

            for name, param in generator.base_gpt.gpt.named_parameters():
                print(f'{name} Grad : {param.requires_grad}')

            RL_logp = Reinforcement(generator, reg_model, get_reward_logp, classify_Hemo, rein_opt_lr=gpt_conf.rein_opt_lr, genome_feats=genome_feats)
            gradient_tracker = GradientTracker(RL_logp.generator.base_gpt.gpt)

            total_params, trainable_params = count_parameters(RL_logp.generator.base_gpt.gpt)
            trainable_percent = trainable_params / total_params
            print(f"trainable params: {trainable_params:,} || all params: {total_params:,} || trainable%: {trainable_percent:.4f}")

            rewards = []
            MIC_rewards = []
            Hemo_rewards = []
            rl_losses = []

            memory = ExperienceMemory(capacity=5000)

            for i in range(n_iterations):
                print(f'===============Current Epoch : {i}===============')
                generate = np.array(RL_logp.generator.sample_decode(ssize=n_to_generate, msl=50, bs=128))
                print(generate)
                print(f'generate number : {len(list(set(generate)))}')
                generate = list(generate)
                mic_preds, _ = get_reward_logp(generate, reg_model, genome_feats)
                hemo_preds, _ = classify_Hemo(generate)

                # Mix a random subset of fresh samples with the best peptides from the replay buffer.
                memory.add_sequences(generate, mic_preds, hemo_preds)
                new_gen, _, _ = memory.get_buffer(sample_size=RL_BUFFER_SAMPLE, times=times)
                generate = random.sample(generate, RL_FRESH_SUBSAMPLE) + new_gen

                classes, _ = classify_AMP(cls_model, generate, ESM_tokenizer)
                classes = np.array([1 if cls >= gpt_conf.reward_thres_cls else 0 for cls in classes])
                print(f'The ratio of generated AMP samples : {classes.sum() / len(classes)}')

                # De-duplicate in first-seen order. set() iteration order for strings depends on the
                # per-process hash seed, which made batch composition irreproducible across restarts.
                train_data_loader = DataLoader(list(dict.fromkeys(generate)), batch_size=gpt_conf.batch_size,
                                               shuffle=True, drop_last=True, collate_fn=None)

                cur_loss, cur_reward, MIC_mean_reward, Hemo_mean_reward = RL_logp.policy_gradient(train_data_loader, prior, epoch=i, kl_coef=kl_coef, gradient_tracker=gradient_tracker, grad_clipping=1.0, gamma=gpt_conf.gamma)
                rewards.append(cur_reward)
                MIC_rewards.append(MIC_mean_reward)
                Hemo_rewards.append(Hemo_mean_reward)
                rl_losses.append(cur_loss)

                RL_LoRA_weight_path = gpt_conf.RL_ckpt_path + f"{species_name}_AMP_RL_LoRA_Weight_ES_ME_40_{gpt_conf.reward_thres_hemo}_{gpt_conf.reward_thres_reg}_{gpt_conf.rein_opt_lr}_{gpt_conf.gen_samples}_{i}.pth"
                RL_logp.generator.base_gpt.gpt.save_lora_weights(RL_LoRA_weight_path)

                # Print max gradients
                gradient_tracker.print_max_grads()

                for curve, ylabel in ((rewards, 'Average reward'), (MIC_rewards, 'Average MIC reward'),
                                      (Hemo_rewards, 'Average Hemo reward'), (rl_losses, 'Loss')):
                    plt.plot(curve)
                    plt.xlabel('Training iteration')
                    plt.ylabel(ylabel)
                    plt.show()

                # NOTE(review): `tokenizer` is the ProtT5 tokenizer from the HemoDL cell, whereas cls_model is
                # scored with ESM_tokenizer above; confirm which one estimate_and_update expects.
                result = estimate_and_update(RL_logp.generator, reg_model, cls_model, tokenizer=tokenizer, n_to_generate=RL_EVAL_SAMPLES, genome_feat=genome_feats)
                eval_seqs = result[0].tolist()

                length_counts = Counter(len(seq) for seq in eval_seqs)
                plt.bar(list(length_counts.keys()), list(length_counts.values()))
                plt.xlabel('Length distribution of Peptides')
                plt.ylabel('Number of Samples by length')
                plt.show()

                finetune_composition = get_num_amino_acid(list(dict.fromkeys(eval_seqs)))
                finetune_composition.pop(' ', None)

                plt.bar(height=finetune_composition.values(), x=finetune_composition.keys(), label='optimized', alpha=0.3)
                plt.bar(height=database_composition.values(), x=database_composition.keys(), label='database', alpha=0.3)
                plt.legend()
                plt.title('Amino acid composition')
                plt.show()

                # Diversity: mean of 1 - Levenshtein ratio over pairs (r > c), excluding zero distances
                # (identical sequences). This gives the same values, in the same order, as the former
                # full-matrix + np.tril(...).nonzero() computation, but with half the ratio() calls.
                n_eval = len(eval_seqs)
                pair_dists = np.array([1 - ratio(eval_seqs[r], eval_seqs[c]) for r in range(n_eval) for c in range(r)])
                diversity_values = pair_dists[pair_dists != 0].mean()

                # Novelty: unique generated sequences absent from the training CSV (set lookup).
                novelty_values = sum(seq not in train_seq_set for seq in set(eval_seqs))

                # Uniqueness
                unique_values = len(set(eval_seqs))
                print(f'Uniqueness : {unique_values}, Diversity : {diversity_values} , Novelty : {novelty_values}')

### **Inference**

In [None]:
# Directory of the selected ("best") seen-species LoRA checkpoints; every *.pth file in it is evaluated below.
Best_ckpt_path = '../../ckpt/RL/Best_ckpt/Seen_species'

In [None]:
# Base generator checkpoint onto which each LoRA checkpoint is loaded (same file as AMP_finetuned_path used in RL training).
AMP_pretrained_path = '../../ckpt/Pretrain/AMP_pretrain/Finetune_Pareto_ckpt20.ckpt'

In [None]:
from tqdm import tqdm_notebook
from matplotlib import pyplot as plt
from collections import Counter
from dataclasses import dataclass, asdict
from transformers import EsmTokenizer, EsmModel
from tqdm import tqdm
from glob import glob
from Levenshtein import distance
from Levenshtein import ratio
from MIC_predictor import get_reward_logp
from Hemolysis_predictor import classify_Hemo
from Utils import classify_AMP, plot_hist, calculate_overlap_ratio
from RL_GPT_modules import GPTGenerator, GPTGeneratorConfig, BaseGPTWrapper
from Tokenize_modules import Vocabulary, PeptideTokenizer, locate_specials, locate_non_standard_AA

# Sample peptides from every best seen-species LoRA checkpoint. For each checkpoint, report the
# AMP, non-hemolytic and low-MIC ratios plus uniqueness, diversity and novelty, and save the samples.
from Utils import get_num_amino_acid  # previously only in scope if the training cell had run

N_INFERENCE_SAMPLES = 1000
npy_path = '../../Generated_samples'
os.makedirs(npy_path, exist_ok=True)

# Genome-feature row in species_35['species'].unique() for each seen species. This mirrors
# Total_species_list / genome_feature_index_list in the training cell, so inference does not
# depend on that cell having run.
seen_genome_index = {
    'Abaumannii': 7, 'Bsubtilis': 4, 'Ecoli': 0, 'Efaecalis': 8,
    'Enterobactercloacae': 13, 'Kpneumoniae': 5, 'Paeruginosa': 2, 'Saureus': 1,
}
seen_species_table = species_35['species'].unique()

# Loop-invariant objects, previously rebuilt on every iteration.
ESM_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
vocab = Vocabulary(file_name=os.path.join('../../data', 'vocab/vocab.txt'))
peptide_tokenizer = PeptideTokenizer(vocab)
train_seq_set = set(train_amp_csv.sequence.tolist())
database_composition = get_num_amino_acid(list(dict.fromkeys(AMP_finetune_seq)))
database_composition.pop(' ', None)

# Sorted so checkpoints are processed in a reproducible order (glob order is filesystem-dependent).
ckpt_path_list = sorted(glob(os.path.join(Best_ckpt_path, '*.pth')))

for path in ckpt_path_list:
    save_species_name = os.path.basename(path).split('_')[0]
    # Use the genome features of the checkpoint's own species. Previously one stale genome_feats
    # (whatever was set last) was used for every checkpoint.
    if save_species_name in seen_genome_index:
        genome_feats = get_features([seen_species_table[seen_genome_index[save_species_name]]], genome_features)
    else:
        print(f"WARNING: no genome-feature mapping for {save_species_name!r}; reusing the current genome_feats")

    conf = GPTGeneratorConfig(gpt_conf=GPTConfig(voc=vocab), lr_mult=0.95)
    basegpt = BaseGPTWrapper(conf.gpt_conf)
    generator = GPTGenerator(basegpt, conf)
    generator = generator.construct_by_ckpt_dict(torch.load(AMP_pretrained_path), vocab)
    generator.base_gpt.gpt.load_lora_weights(path)

    generator.base_gpt.gpt.eval()
    with torch.no_grad():
        sampled = generator.sample_decode(ssize=N_INFERENCE_SAMPLES, msl=50, bs=20)
        display(sampled)
    generate = list(sampled)

    # Number of sampled peptides (was gpt_conf.gen_samples, i.e. the RL sample size, not this sample size).
    n_to_generate = N_INFERENCE_SAMPLES

    seq_lengths = [len(seq) for seq in generate]
    print(f'min number of generation : {min(seq_lengths)}')
    preds, _ = get_reward_logp(generate, reg_model, genome_feature=genome_feats)
    classes, _ = classify_AMP(cls_model, generate, ESM_tokenizer)
    hemo_pred, _ = classify_Hemo(generate)
    # NOTE(review): training thresholds AMP scores with gpt_conf.reward_thres_cls; confirm 0.5 matches it.
    classes = np.array([1 if score >= 0.5 else 0 for score in classes])
    micpreds = np.array([1 if mic <= gpt_conf.reward_thres_reg else 0 for mic in preds])
    nonhemo = np.array([1 if hemo <= gpt_conf.reward_thres_hemo else 0 for hemo in hemo_pred])
    plot_hist(preds, n_to_generate)
    print(f'The ratio of generated AMP : {classes.sum()/len(classes)}')
    print(f'The ratio of non-hemolytic generated peptides : {nonhemo.sum()/len(nonhemo)}')
    both_ratio = calculate_overlap_ratio(micpreds, nonhemo)
    print(f'The ratio of Both (Low MIC adn Low Hemolysis) peptides : {both_ratio}')
    print(f'The ratio of non-redundant peptides : {len(list(set(generate)))/len(generate)}')

    finetune_composition = get_num_amino_acid(list(dict.fromkeys(generate)))
    finetune_composition.pop(' ', None)

    plt.bar(height=finetune_composition.values(), x=finetune_composition.keys(), label='optimized', alpha=0.3)
    plt.bar(height=database_composition.values(), x=database_composition.keys(), label='database', alpha=0.3)
    plt.legend()
    plt.title('Amino acid composition')
    plt.show()

    # Diversity: mean non-zero (1 - Levenshtein ratio) over pairs r > c. This equals the former
    # full-matrix np.tril(...).nonzero() mean but uses half the ratio() calls.
    n_gen = len(generate)
    pair_dists = np.array([1 - ratio(generate[r], generate[c]) for r in range(n_gen) for c in range(r)])
    diversity_values = pair_dists[pair_dists != 0].mean()

    # Novelty: unique generated sequences absent from the training CSV (set lookup).
    novelty_values = sum(seq not in train_seq_set for seq in set(generate))

    # Uniqueness
    unique_values = len(set(generate))
    print(f'Uniqueness : {unique_values}, Diversity : {diversity_values} , Novelty : {novelty_values}')

    np.save(os.path.join(npy_path, f'{save_species_name}_{N_INFERENCE_SAMPLES}.npy'), np.array(sampled))

### **Reinforcement learning for Unseen species**

#### **Unseen dataset and genome features**

In [None]:
# Unseen-species evaluation table (must contain a 'species' column, used below).
unseen_csv_path = os.path.join('../../data', 'Unseen_data', 'external.csv')
external_data = pd.read_csv(unseen_csv_path)
external_data

In [None]:
# species_10 is an alias of external_data (not a copy).
species_10 = external_data
species = species_10['species'].unique()
genome_features = torch.load(gpt_conf.genome_feature_path)
# Default genome features: first unseen species (the training loop re-selects them per species).
genome_feats = get_features([species[0]], genome_features)

#### **Training on unseen species**

In [None]:
from tqdm import tqdm_notebook
from matplotlib import pyplot as plt
from collections import Counter
from dataclasses import dataclass, asdict
from transformers import EsmTokenizer, EsmModel
from tqdm import tqdm
from Levenshtein import distance
from Levenshtein import ratio

from Utils import count_parameters, estimate_and_update, get_num_amino_acid
from Tokenize_modules import sequence_to_input
from MIC_predictor import get_reward_logp
from Hemolysis_predictor import classify_Hemo
from RL_GPT_modules import GPTGenerator, GPTGeneratorConfig, BaseGPTWrapper
from RL_modules import GradientTracker, ExperienceMemory, Reinforcement

# Unseen-species RL sweep (same procedure as the seen-species cell). For every selected species,
# replay-buffer depth (`times`) and KL coefficient, the cell:
#   - builds a frozen prior and a LoRA-trainable generator from the AMP fine-tuned checkpoint;
#   - runs `n_iterations` policy-gradient rounds;
#   - saves the LoRA weights after every round;
#   - plots training curves and sample-quality diagnostics.
AMP_finetuned_path = '../../ckpt/Pretrain/AMP_pretrain/Finetune_Pareto_ckpt20.ckpt'
# Unseen species names. `species` also comes from species_10['species'].unique(),
# so the genome-feature index is the identity.
Total_species_list = species_10['species'].unique()
genome_feature_index_list = range(len(Total_species_list))

# Sweep configuration (previously hard-coded inside the loops).
RL_SPECIES_INDICES = range(0, 2)         # positions in Total_species_list to train
RL_REPLAY_TIMES = [1, 2, 3, 4]           # `times` argument of ExperienceMemory.get_buffer
RL_KL_COEFS = [0.0001, 0.01, 0.05, 0.1]  # KL coefficients passed to policy_gradient
RL_N_ITERATIONS = 40
RL_FRESH_SUBSAMPLE = 200                 # freshly generated peptides kept per round
RL_BUFFER_SAMPLE = 100                   # sample_size drawn from the replay buffer
RL_EVAL_SAMPLES = 1000                   # peptides sampled by estimate_and_update for diagnostics
RL_lr = 5e-4
sample_num = 2000
mic_value = 1.5
hemo_value = 0.5

# Loop-invariant reference data for the novelty / composition diagnostics, computed once.
train_seq_set = set(train_amp_csv.sequence.tolist())
# dict.fromkeys de-duplicates in first-seen order (set() order of str varies between kernel restarts).
database_composition = get_num_amino_acid(list(dict.fromkeys(AMP_finetune_seq)))
database_composition.pop(' ', None)

for spc_idx in RL_SPECIES_INDICES:
    species_name = Total_species_list[spc_idx]
    # Genome features depend only on the species, so look them up once per species.
    genome_feats = get_features([species[genome_feature_index_list[spc_idx]]], genome_features)

    for times in RL_REPLAY_TIMES:
        for kl_coef in RL_KL_COEFS:
            set_randomness()
            gpt_conf.reward_thres_reg = mic_value
            gpt_conf.reward_thres_hemo = hemo_value
            gpt_conf.gen_samples = sample_num
            gpt_conf.n_iterations = RL_N_ITERATIONS
            gpt_conf.rein_opt_lr = RL_lr
            gpt_conf.RL_ckpt_path = f"../../ckpt/RL/{species_name}/Replay_buffer_Top_{times}_KL_coef_{kl_coef}/"
            os.makedirs(gpt_conf.RL_ckpt_path, exist_ok=True)
            print(f'===============MIC_threshold : {gpt_conf.reward_thres_reg}, Hemo_threshold : {gpt_conf.reward_thres_hemo}, Learning_rate : {RL_lr}, Species_name : {species_name}===============')

            # Reference policy (prior) and policy being optimised (generator), both from the same checkpoint.
            conf_1 = GPTGeneratorConfig(gpt_conf=GPTConfig(voc=vocab), lr_mult=0.95)
            basegpt_1 = BaseGPTWrapper(conf_1.gpt_conf)
            prior = GPTGenerator(basegpt_1, conf_1)
            prior = prior.construct_by_ckpt_dict(torch.load(AMP_finetuned_path), vocab)

            conf_2 = GPTGeneratorConfig(gpt_conf=GPTConfig(voc=vocab), lr_mult=0.95)
            basegpt_2 = BaseGPTWrapper(conf_2.gpt_conf)
            generator = GPTGenerator(basegpt_2, conf_2)
            generator = generator.construct_by_ckpt_dict(torch.load(AMP_finetuned_path), vocab)

            n_to_generate = gpt_conf.gen_samples
            n_iterations = gpt_conf.n_iterations

            # Freeze the prior entirely; in the generator, only the LoRA adapters stay trainable.
            for param in prior.base_gpt.gpt.parameters():
                param.requires_grad = False
            for param in generator.base_gpt.gpt.parameters():
                param.requires_grad = False
            for target_module in generator.base_gpt.gpt.lora_layers:
                for param in generator.base_gpt.gpt.lora_layers[target_module].parameters():
                    param.requires_grad = True

            for name, param in generator.base_gpt.gpt.named_parameters():
                print(f'{name} Grad : {param.requires_grad}')

            RL_logp = Reinforcement(generator, reg_model, get_reward_logp, classify_Hemo, rein_opt_lr=gpt_conf.rein_opt_lr, genome_feats=genome_feats)
            gradient_tracker = GradientTracker(RL_logp.generator.base_gpt.gpt)

            total_params, trainable_params = count_parameters(RL_logp.generator.base_gpt.gpt)
            trainable_percent = trainable_params / total_params
            print(f"trainable params: {trainable_params:,} || all params: {total_params:,} || trainable%: {trainable_percent:.4f}")

            rewards = []
            MIC_rewards = []
            Hemo_rewards = []
            rl_losses = []

            memory = ExperienceMemory(capacity=5000)

            for i in range(n_iterations):
                print(f'===============Current Epoch : {i}===============')
                generate = np.array(RL_logp.generator.sample_decode(ssize=n_to_generate, msl=50, bs=128))
                print(generate)
                print(f'generate number : {len(list(set(generate)))}')
                generate = list(generate)
                mic_preds, _ = get_reward_logp(generate, reg_model, genome_feats)
                hemo_preds, _ = classify_Hemo(generate)

                # Mix a random subset of fresh samples with the best peptides from the replay buffer.
                memory.add_sequences(generate, mic_preds, hemo_preds)
                new_gen, _, _ = memory.get_buffer(sample_size=RL_BUFFER_SAMPLE, times=times)
                generate = random.sample(generate, RL_FRESH_SUBSAMPLE) + new_gen

                classes, _ = classify_AMP(cls_model, generate, ESM_tokenizer)
                classes = np.array([1 if cls >= gpt_conf.reward_thres_cls else 0 for cls in classes])
                print(f'The ratio of generated AMP samples : {classes.sum() / len(classes)}')

                # De-duplicate in first-seen order. set() iteration order for strings depends on the
                # per-process hash seed, which made batch composition irreproducible across restarts.
                train_data_loader = DataLoader(list(dict.fromkeys(generate)), batch_size=gpt_conf.batch_size,
                                               shuffle=True, drop_last=True, collate_fn=None)

                cur_loss, cur_reward, MIC_mean_reward, Hemo_mean_reward = RL_logp.policy_gradient(train_data_loader, prior, epoch=i, kl_coef=kl_coef, gradient_tracker=gradient_tracker, grad_clipping=1.0, gamma=gpt_conf.gamma)
                rewards.append(cur_reward)
                MIC_rewards.append(MIC_mean_reward)
                Hemo_rewards.append(Hemo_mean_reward)
                rl_losses.append(cur_loss)

                RL_LoRA_weight_path = gpt_conf.RL_ckpt_path + f"{species_name}_AMP_RL_LoRA_Weight_ES_ME_40_{gpt_conf.reward_thres_hemo}_{gpt_conf.reward_thres_reg}_{gpt_conf.rein_opt_lr}_{gpt_conf.gen_samples}_{i}.pth"
                RL_logp.generator.base_gpt.gpt.save_lora_weights(RL_LoRA_weight_path)

                # Print max gradients
                gradient_tracker.print_max_grads()

                for curve, ylabel in ((rewards, 'Average reward'), (MIC_rewards, 'Average MIC reward'),
                                      (Hemo_rewards, 'Average Hemo reward'), (rl_losses, 'Loss')):
                    plt.plot(curve)
                    plt.xlabel('Training iteration')
                    plt.ylabel(ylabel)
                    plt.show()

                # NOTE(review): `tokenizer` is the ProtT5 tokenizer from the HemoDL cell, whereas cls_model is
                # scored with ESM_tokenizer above; confirm which one estimate_and_update expects.
                result = estimate_and_update(RL_logp.generator, reg_model, cls_model, tokenizer=tokenizer, n_to_generate=RL_EVAL_SAMPLES, genome_feat=genome_feats)
                eval_seqs = result[0].tolist()

                length_counts = Counter(len(seq) for seq in eval_seqs)
                plt.bar(list(length_counts.keys()), list(length_counts.values()))
                plt.xlabel('Length distribution of Peptides')
                plt.ylabel('Number of Samples by length')
                plt.show()

                finetune_composition = get_num_amino_acid(list(dict.fromkeys(eval_seqs)))
                finetune_composition.pop(' ', None)

                plt.bar(height=finetune_composition.values(), x=finetune_composition.keys(), label='optimized', alpha=0.3)
                plt.bar(height=database_composition.values(), x=database_composition.keys(), label='database', alpha=0.3)
                plt.legend()
                plt.title('Amino acid composition')
                plt.show()

                # Diversity: mean of 1 - Levenshtein ratio over pairs (r > c), excluding zero distances
                # (identical sequences). This gives the same values, in the same order, as the former
                # full-matrix + np.tril(...).nonzero() computation, but with half the ratio() calls.
                n_eval = len(eval_seqs)
                pair_dists = np.array([1 - ratio(eval_seqs[r], eval_seqs[c]) for r in range(n_eval) for c in range(r)])
                diversity_values = pair_dists[pair_dists != 0].mean()

                # Novelty: unique generated sequences absent from the training CSV (set lookup).
                novelty_values = sum(seq not in train_seq_set for seq in set(eval_seqs))

                # Uniqueness
                unique_values = len(set(eval_seqs))
                print(f'Uniqueness : {unique_values}, Diversity : {diversity_values} , Novelty : {novelty_values}')

### **Inference**

In [None]:
# Directory of the selected ("best") unseen-species LoRA checkpoints; every *.pth file in it is evaluated below.
Best_ckpt_path = '../../ckpt/RL/Best_ckpt/Unseen_species'

In [None]:
# Base generator checkpoint onto which each LoRA checkpoint is loaded (same file as AMP_finetuned_path used in RL training).
AMP_pretrained_path = '../../ckpt/Pretrain/AMP_pretrain/Finetune_Pareto_ckpt20.ckpt'

In [None]:
from tqdm import tqdm_notebook
from matplotlib import pyplot as plt
from collections import Counter
from dataclasses import dataclass, asdict
from transformers import EsmTokenizer, EsmModel
from tqdm import tqdm
from glob import glob
from Levenshtein import distance
from Levenshtein import ratio
from MIC_predictor import get_reward_logp
from Hemolysis_predictor import classify_Hemo
from Utils import classify_AMP, plot_hist, calculate_overlap_ratio
from RL_GPT_modules import GPTGenerator, GPTGeneratorConfig, BaseGPTWrapper
from Tokenize_modules import Vocabulary, PeptideTokenizer, locate_specials, locate_non_standard_AA

# Sample peptides from every best unseen-species LoRA checkpoint. For each checkpoint, report the
# AMP, non-hemolytic and low-MIC ratios plus uniqueness, diversity and novelty, and save the samples.
from Utils import get_num_amino_acid  # previously only in scope if a training cell had run

N_INFERENCE_SAMPLES = 1000
npy_path = '../../Generated_samples'
os.makedirs(npy_path, exist_ok=True)

# Unseen species names; training saved checkpoints as f"{species}_AMP_RL_...".
unseen_species_table = species_10['species'].unique()

# Loop-invariant objects, previously rebuilt on every iteration.
ESM_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
vocab = Vocabulary(file_name=os.path.join('../../data', 'vocab/vocab.txt'))
peptide_tokenizer = PeptideTokenizer(vocab)
train_seq_set = set(train_amp_csv.sequence.tolist())
database_composition = get_num_amino_acid(list(dict.fromkeys(AMP_finetune_seq)))
database_composition.pop(' ', None)

# Sorted so checkpoints are processed in a reproducible order (glob order is filesystem-dependent).
ckpt_path_list = sorted(glob(os.path.join(Best_ckpt_path, '*.pth')))

for path in ckpt_path_list:
    ckpt_name = os.path.basename(path)
    save_species_name = ckpt_name.split('_')[0]
    # Use the genome features of the checkpoint's own species. Previously one stale genome_feats
    # (whatever was set last) was used for every checkpoint. The longest matching name prefix wins.
    species_matches = [sp for sp in unseen_species_table if ckpt_name.startswith(f'{sp}_')]
    if species_matches:
        genome_feats = get_features([max(species_matches, key=len)], genome_features)
    else:
        print(f"WARNING: no unseen species matches {ckpt_name!r}; reusing the current genome_feats")

    conf = GPTGeneratorConfig(gpt_conf=GPTConfig(voc=vocab), lr_mult=0.95)
    basegpt = BaseGPTWrapper(conf.gpt_conf)
    generator = GPTGenerator(basegpt, conf)
    generator = generator.construct_by_ckpt_dict(torch.load(AMP_pretrained_path), vocab)
    generator.base_gpt.gpt.load_lora_weights(path)

    generator.base_gpt.gpt.eval()
    with torch.no_grad():
        sampled = generator.sample_decode(ssize=N_INFERENCE_SAMPLES, msl=50, bs=20)
        display(sampled)
    generate = list(sampled)

    # Number of sampled peptides (was gpt_conf.gen_samples, i.e. the RL sample size, not this sample size).
    n_to_generate = N_INFERENCE_SAMPLES

    seq_lengths = [len(seq) for seq in generate]
    print(f'min number of generation : {min(seq_lengths)}')
    preds, _ = get_reward_logp(generate, reg_model, genome_feature=genome_feats)
    classes, _ = classify_AMP(cls_model, generate, ESM_tokenizer)
    hemo_pred, _ = classify_Hemo(generate)
    # NOTE(review): training thresholds AMP scores with gpt_conf.reward_thres_cls; confirm 0.5 matches it.
    classes = np.array([1 if score >= 0.5 else 0 for score in classes])
    micpreds = np.array([1 if mic <= gpt_conf.reward_thres_reg else 0 for mic in preds])
    nonhemo = np.array([1 if hemo <= gpt_conf.reward_thres_hemo else 0 for hemo in hemo_pred])
    plot_hist(preds, n_to_generate)
    print(f'The ratio of generated AMP : {classes.sum()/len(classes)}')
    print(f'The ratio of non-hemolytic generated peptides : {nonhemo.sum()/len(nonhemo)}')
    both_ratio = calculate_overlap_ratio(micpreds, nonhemo)
    print(f'The ratio of Both (Low MIC adn Low Hemolysis) peptides : {both_ratio}')
    print(f'The ratio of non-redundant peptides : {len(list(set(generate)))/len(generate)}')

    finetune_composition = get_num_amino_acid(list(dict.fromkeys(generate)))
    finetune_composition.pop(' ', None)

    plt.bar(height=finetune_composition.values(), x=finetune_composition.keys(), label='optimized', alpha=0.3)
    plt.bar(height=database_composition.values(), x=database_composition.keys(), label='database', alpha=0.3)
    plt.legend()
    plt.title('Amino acid composition')
    plt.show()

    # Diversity: mean non-zero (1 - Levenshtein ratio) over pairs r > c. This equals the former
    # full-matrix np.tril(...).nonzero() mean but uses half the ratio() calls.
    n_gen = len(generate)
    pair_dists = np.array([1 - ratio(generate[r], generate[c]) for r in range(n_gen) for c in range(r)])
    diversity_values = pair_dists[pair_dists != 0].mean()

    # Novelty: unique generated sequences absent from the training CSV (set lookup).
    novelty_values = sum(seq not in train_seq_set for seq in set(generate))

    # Uniqueness
    unique_values = len(set(generate))
    print(f'Uniqueness : {unique_values}, Diversity : {diversity_values} , Novelty : {novelty_values}')

    np.save(os.path.join(npy_path, f'{save_species_name}_{N_INFERENCE_SAMPLES}.npy'), np.array(sampled))