# Sequence generation

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import pandas as pd
import torch
from evoamp.models import EvoAMP
from evoamp.models._globals import END_TOKEN

# Seed torch's global RNG up front; the trailing `;` keeps the call's return
# value out of the cell output.
torch.manual_seed(0);

In [2]:
def prettify(sequence: list[str], end_token: "str | None" = None) -> str:
    """Join one sampled token list into a plain string.

    The first token is always dropped (presumably the start-of-sequence
    marker emitted by the sampler -- TODO confirm), and the last token is
    dropped only when it equals the end token.

    Args:
        sequence: Tokens of a single sample, in order.
        end_token: Token marking the end of a sequence. Defaults to the
            module-level ``END_TOKEN`` when omitted.

    Returns:
        The remaining tokens concatenated without a separator; an empty
        string when ``sequence`` is empty.
    """
    # An empty sample has no last token to inspect; indexing [-1] would raise.
    if not sequence:
        return ""
    if end_token is None:
        end_token = END_TOKEN
    stop = len(sequence) - 1 if sequence[-1] == end_token else len(sequence)
    return "".join(sequence[1:stop])

With the pretrained model, let's generate a few sequences.

In [21]:
# Both locations are relative paths, so they resolve against the kernel's
# working directory. Listed in the order the cells below consume them.
MODEL_PATH = "../outputs/2024-05-18/19-31-51/pretrained_model"  # passed to EvoAMP.load
DATA_PATH = "../data/processed/amp_8_35.csv"  # passed to pd.read_csv

In [22]:
# Restore the pretrained model from MODEL_PATH.
model = EvoAMP.load(MODEL_PATH)



In [23]:
# Read the dataset; later cells use its `sequence` and `is_amp` columns.
df = pd.read_csv(DATA_PATH)

Generate new sequences based on a reference sequence.

In [37]:
# Take the sequence and its label from the SAME row, so the AMP flag printed
# here and passed to the model belongs to the reference sequence. (The label
# was previously read from row 1 while the sequence came from row 0.)
reference_row = df.iloc[0]
reference_sequence = reference_row["sequence"]
is_amp = reference_row["is_amp"]

print(f"Reference sequence: {reference_sequence}")
print(f"Is AMP: {is_amp}")

Reference sequence: TWKKGFPHGTCSKCARE
Is AMP: 0


In [38]:
# Draw 5 samples, passing both the AMP label and the reference sequence.
# Each sample is turned into a string with `prettify` in the next cell.
sequences = model.sample(n_samples=5, is_amp=is_amp, reference_sequence=reference_sequence)

In [39]:
# Keep the raw samples in `sequences` and store the strings under a new name.
# Overwriting `sequences` here made the cell non-idempotent: running it a
# second time fed strings back into `prettify` and chopped off a character.
conditioned_sequences = [prettify(tokens) for tokens in sequences]
conditioned_sequences

['TWKKGFPHGTCSKCARE',
 'TWKKGFPHGTCSKCARE',
 'TWKKGFPHGTCSKCARE',
 'TWKKGFPHGTCSKCARE',
 'KVAIAMKKLEED']

Generate new sequences without a reference sequence (sampling is still conditioned on `is_amp`).

In [35]:
# Sample with no reference sequence and convert each sample to a string with
# `prettify` in a single pass.
sequences = [
    prettify(tokens)
    for tokens in model.sample(n_samples=5, is_amp=is_amp, reference_sequence=None)
]

In [36]:
# Show the prettified samples built in the previous cell.
sequences

['TWKKGFPHGTCSKCARE',
 'KVAIAMKKLEED',
 'KVAIAMKKLEED',
 'KVKIGFPHGTCSKCARE',
 'TWKKGFPHGTCSKCARE']