# LSTM generator in PyTorch using label encoding

In this notebook we demonstrate the application of `peptidy` in generating antimicrobial peptides (AMPs) using a Long Short-Term Memory (LSTM) network set up in PyTorch. Before the amino acid sequences are fed into the model, they are label-encoded using `peptidy`.



In [1]:
!pip install peptidy

Collecting peptidy
  Downloading peptidy-0.0.1-py3-none-any.whl.metadata (5.1 kB)
Downloading peptidy-0.0.1-py3-none-any.whl (21 kB)
Installing collected packages: peptidy
Successfully installed peptidy-0.0.1


In [2]:
from typing import Dict, List, Optional, Tuple

import pandas as pd

import torch
from torch import nn
from torch.nn import functional as F

import peptidy


### Load a dataframe with peptides

In [3]:
# Read the peptide table straight from the CSV hosted on GitHub.
csv_url = 'https://raw.githubusercontent.com/AryaVenkatesh2010/AryaAIProject/refs/heads/main/subsample_AMP.csv'
subsample_AMP = pd.read_csv(csv_url)

# X holds every column except the `active` label; y is the `active` column itself.
X = subsample_AMP.drop(columns='active')
y = subsample_AMP['active']

### Split the data into training and validation sets

In [4]:
# Keep only the sequences whose label is 1 (active); the generator is trained on these.
active_peptides = X.loc[y == 1, "sequence"].tolist()

# Longest active sequence, plus two positions for the start and end tokens.
pad_len = max(len(peptide) for peptide in active_peptides) + 2

# The first `n_training` active peptides form the training set, the remainder the validation set.
n_training = 20
training_peptides, val_peptides = active_peptides[:n_training], active_peptides[n_training:]


In [5]:
# Display all sequences labelled active (the pool that is split into training/validation).
active_peptides

['GVLDILKNAAKNILAHAAEQI',
 'KKCKFFCKVKKKIKSIGFQIPIVSIPFK',
 'SWLSKTAKKLENSAKKRISEGIAIAIQGGPR',
 'GIWSSIKNLASKAWNSDIGQSLRNKAAGAINKFVADKIGVTPSQAAS',
 'WSCPTLSGVCRKVCLPTEMFFGPLGCGKEFQCCVSHFF',
 'IWSFLIKAATKLLPSLFGGGKKDS',
 'GFFALIPKIISSPLFKTLLSAVGSALSSSGEQE',
 'ALKAALLAILKIVRVIKK',
 'KRGLWESLKRKATKLGDDIRNTLRNFKIKFPVPRQG',
 'KWKVFKKIEKMGRNIRNGIVKAGPAIAVLGEAKAILS',
 'GWGSFFKKAAHVGKHVGKAALTHYL',
 'QQCGRQAGNRRCANNLCCSQYGYCGRTNEYCCTSQGCQSQCRRCG',
 'FIPGLRRLFATVVPTVVCAINKLPPG',
 'AKKVFKRLEKLFSKIQNDK',
 'GLPVCGETCVGGTCNTPGCTCSWPVCTRN',
 'IDWLKLGKMVMDVL',
 'MNFLKNGIAKWMTGAELQAYKKKYGCLPWEKISC',
 'LRDLVCYCRTRGCKRRERMNGTCRKGHLMYTLCCR',
 'VGRKHSILNCIPYLKKKKIMRL',
 'ASHLGHHALDHLLK',
 'GIFSKLGRKKIKNLLISGLKNVGKEVGMDVVRTGIDIAGCKIKGEC',
 'FLPAALAGIGGILGKLF',
 'GPDSCNHDRGLCRVGNCNPGEYLAKYCFEPVILCCKPLSPTPTKT',
 'CEWYNISCQLGNKGQWCTLTKECQRSCK',
 'HHHLFGHVGHEVERSLHKVGHKLEHACHEVHKTAKKVQK']

In [6]:
# Display the training split: the first `n_training` active peptides.
training_peptides

['GVLDILKNAAKNILAHAAEQI',
 'KKCKFFCKVKKKIKSIGFQIPIVSIPFK',
 'SWLSKTAKKLENSAKKRISEGIAIAIQGGPR',
 'GIWSSIKNLASKAWNSDIGQSLRNKAAGAINKFVADKIGVTPSQAAS',
 'WSCPTLSGVCRKVCLPTEMFFGPLGCGKEFQCCVSHFF',
 'IWSFLIKAATKLLPSLFGGGKKDS',
 'GFFALIPKIISSPLFKTLLSAVGSALSSSGEQE',
 'ALKAALLAILKIVRVIKK',
 'KRGLWESLKRKATKLGDDIRNTLRNFKIKFPVPRQG',
 'KWKVFKKIEKMGRNIRNGIVKAGPAIAVLGEAKAILS',
 'GWGSFFKKAAHVGKHVGKAALTHYL',
 'QQCGRQAGNRRCANNLCCSQYGYCGRTNEYCCTSQGCQSQCRRCG',
 'FIPGLRRLFATVVPTVVCAINKLPPG',
 'AKKVFKRLEKLFSKIQNDK',
 'GLPVCGETCVGGTCNTPGCTCSWPVCTRN',
 'IDWLKLGKMVMDVL',
 'MNFLKNGIAKWMTGAELQAYKKKYGCLPWEKISC',
 'LRDLVCYCRTRGCKRRERMNGTCRKGHLMYTLCCR',
 'VGRKHSILNCIPYLKKKKIMRL',
 'ASHLGHHALDHLLK']

In [7]:
# Display the validation split: the active peptides after the first `n_training`.
val_peptides

['GIFSKLGRKKIKNLLISGLKNVGKEVGMDVVRTGIDIAGCKIKGEC',
 'FLPAALAGIGGILGKLF',
 'GPDSCNHDRGLCRVGNCNPGEYLAKYCFEPVILCCKPLSPTPTKT',
 'CEWYNISCQLGNKGQWCTLTKECQRSCK',
 'HHHLFGHVGHEVERSLHKVGHKLEHACHEVHKTAKKVQK']

### Define the dataloader, using `peptidy` to tokenize the sequences

In [10]:
class PeptideLoader(torch.utils.data.Dataset):
    """Dataset that turns each label-encoded peptide into a next-token pair.

    Every row of the stored tensor is one encoded peptide. Indexing yields the
    row without its last position as the input and the row without its first
    position as the target, i.e. the target is the input shifted one step left.
    """

    def __init__(
        self,
        label_encoded_peptides: torch.LongTensor,
    ):
        # Rows are accessed as [idx, :] below, so a 2-D tensor is expected.
        self.label_encoded_peptides = label_encoded_peptides

    def __len__(self):
        """Return the number of rows (peptides) held by the dataset."""
        return len(self.label_encoded_peptides)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair for the peptide at row ``idx``."""
        encoded = self.label_encoded_peptides[idx, :]
        return encoded[:-1], encoded[1:]


def tokenize_peptides(
    peptides: List[str],
    padding_length: int,
) -> Tuple[torch.LongTensor, Dict[str, int]]:
    """Label-encode peptides with ``peptidy`` and collect them in one tensor.

    Args:
        peptides: Amino acid sequences to encode.
        padding_length: Forwarded to ``peptidy.encoding.label_encoding`` as
            ``padding_len``.

    Returns:
        A 2-tuple ``(encoded, token_to_label)``: ``encoded`` is a
        ``torch.LongTensor`` built from the per-peptide encodings, and
        ``token_to_label`` is ``peptidy.biology.token_to_label``.
    """
    token_to_label = peptidy.biology.token_to_label
    # add_generative_tokens=True presumably makes peptidy wrap each sequence in
    # begin/end tokens -- TODO confirm against the peptidy documentation.
    tokenized_peptides = [
        peptidy.encoding.label_encoding(
            peptide,
            padding_len=padding_length,
            add_generative_tokens=True,
        )
        for peptide in peptides
    ]

    return torch.LongTensor(tokenized_peptides), token_to_label

def get_dataloader(
    peptides: List[str],
    padding_length: int,
    batch_size: int,
    shuffle: bool = True,
):
    """Build a ``DataLoader`` of next-token pairs from raw peptide strings.

    Args:
        peptides: Amino acid sequences to encode.
        padding_length: Padding length handed to ``tokenize_peptides``.
        batch_size: Number of peptides per batch.
        shuffle: Whether the loader shuffles the peptides (default ``True``).

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping a ``PeptideLoader``.
    """
    # tokenize_peptides also returns the vocabulary; only the tensor is needed here.
    encoded_peptides = tokenize_peptides(peptides, padding_length)[0]
    dataset = PeptideLoader(encoded_peptides)

    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

### Define the model

In [11]:
class LSTM(nn.Module):
    """Autoregressive LSTM language model over label-encoded peptide tokens.

    The network is an embedding layer, a stacked LSTM and a linear head that
    produces one logit per vocabulary entry at every position. The training
    hyperparameters are stored on the instance and consumed by :meth:`fit`.

    Args:
        model_dim: Size of the token embeddings and of the LSTM hidden state.
        n_layers: Number of stacked LSTM layers.
        vocab_size: Number of embedding rows and of output logits.
        sequence_length: Padding length used when encoding peptides in
            :meth:`fit`; :meth:`design_peptides` samples
            ``sequence_length - 1`` tokens per design.
        learning_rate: Learning rate handed to the Adam optimizer.
        n_epochs: Number of passes over the training data.
        batch_size: Batch size of the training and validation loaders.
        device: Device string the module and its inputs are moved to.
    """

    def __init__(
        self,
        model_dim: int,
        n_layers: int,
        vocab_size: int,
        sequence_length: int,
        learning_rate: float,
        n_epochs: int,
        batch_size: int,
        device: str,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.model_dim = model_dim
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.learning_rate = learning_rate
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.device = device

        self.architecture = self.build_architecture()

    def build_architecture(self):
        """Create the embedding, LSTM and output-head layers.

        Returns:
            An ``nn.ModuleDict`` with the keys ``embedding``, ``lstm`` and
            ``lm_head``.
        """
        return nn.ModuleDict(
            dict(
                # Label 0 is the padding index; its embedding row is not trained.
                embedding = nn.Embedding(self.vocab_size, self.model_dim, padding_idx=0),
                lstm = nn.LSTM(self.model_dim, self.model_dim, self.n_layers, batch_first=True),
                lm_head = nn.Linear(self.model_dim, self.vocab_size),
            )
        )

    def forward(
        self,
        x: torch.LongTensor,
        hidden_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        training: bool = True,
    ):
        """Compute next-token logits for a batch of label sequences.

        Args:
            x: Labels of shape ``(batch_size, seq_len)``, or ``(batch_size,)``
                for a single decoding step.
            hidden_states: Optional ``(h, c)`` pair passed to the LSTM; when
                ``None`` the LSTM uses its own default initial state.
            training: Only selects the return value; it does not switch the
                module between train and eval mode.

        Returns:
            The logits of shape ``(batch_size, seq_len, vocab_size)`` when
            ``training`` is true, otherwise ``(logits, hidden_states)``.
        """
        if x.dim() == 1:
            x = x.unsqueeze(1)                                      # (batch_size, 1)
        x = self.architecture.embedding(x)                          # (batch_size, seq_len, model_dim)
        x, hidden_states = self.architecture.lstm(x, hidden_states) # (batch_size, seq_len, model_dim)
        x = self.architecture.lm_head(x)                            # (batch_size, seq_len, vocab_size)
        if training:
            return x

        return x, hidden_states

    def __compute_loss(
        self,
        inputs: torch.LongTensor,
        targets: torch.LongTensor,
    ) -> torch.Tensor:
        """Return the mean cross-entropy between predicted logits and targets."""
        logits = self.forward(inputs, training=True)
        # F.cross_entropy expects the class dimension second: (batch, vocab, seq).
        logits = logits.permute(0, 2, 1)
        # NOTE(review): every target position counts towards the loss, including
        # any whose label is the padding index 0 -- confirm whether
        # ignore_index=0 is wanted here.
        return F.cross_entropy(logits, targets.long())

    def fit(
        self,
        training_peptides: List[str],
        val_peptides: List[str],
    ) -> Dict[str, List[float]]:
        """Train the model and record per-epoch training/validation losses.

        Args:
            training_peptides: Sequences used for gradient updates.
            val_peptides: Sequences used only for evaluation; may be empty, in
                which case the validation loss is reported as NaN.

        Returns:
            A dict with the keys ``"train_loss"`` and ``"val_loss"``, each a
            list holding one sample-weighted mean loss per epoch.

        Raises:
            ValueError: If ``training_peptides`` is empty.
        """
        if len(training_peptides) == 0:
            raise ValueError("training_peptides must contain at least one peptide")

        self.to(self.device)

        train_dataloader = get_dataloader(training_peptides, self.sequence_length, self.batch_size)
        # The sample-weighted mean below does not depend on batch order, so the
        # validation loader is not shuffled.
        val_dataloader = get_dataloader(
            val_peptides, self.sequence_length, self.batch_size, shuffle=False
        )

        optimizer = torch.optim.Adam(self.parameters(), self.learning_rate)
        history = {"train_loss": list(), "val_loss": list()}

        for epoch_ix in range(self.n_epochs):
            self.train()

            n_train_samples, epoch_train_loss = 0, 0
            for X_train, y_train in train_dataloader:
                X_train, y_train = X_train.to(self.device), y_train.to(self.device)
                n_train_samples += X_train.shape[0]

                optimizer.zero_grad()
                batch_train_loss = self.__compute_loss(X_train, y_train)
                batch_train_loss.backward()
                optimizer.step()

                # Weight by batch size so a smaller final batch is not over-counted.
                epoch_train_loss += batch_train_loss.item() * X_train.shape[0]

            epoch_train_loss = epoch_train_loss / n_train_samples
            history["train_loss"].append(epoch_train_loss)

            self.eval()
            n_val_samples, epoch_val_loss = 0, 0
            # No gradients are needed for evaluation; skip building the graph.
            with torch.no_grad():
                for X_val, y_val in val_dataloader:
                    X_val, y_val = X_val.to(self.device), y_val.to(self.device)
                    n_val_samples += X_val.shape[0]

                    batch_val_loss = self.__compute_loss(X_val, y_val)
                    epoch_val_loss += batch_val_loss.item() * X_val.shape[0]

            # An empty validation set yields NaN instead of a ZeroDivisionError.
            epoch_val_loss = epoch_val_loss / n_val_samples if n_val_samples else float("nan")
            history["val_loss"].append(epoch_val_loss)
            print(f"Epoch {epoch_ix} | Train loss: {epoch_train_loss} | Val loss: {epoch_val_loss}")

        return history

    def initialize_hidden_states(self, batch_size: int):
        """Return zero-filled ``(h, c)`` LSTM states for ``batch_size`` sequences.

        Each tensor has shape ``(n_layers, batch_size, model_dim)``, dtype
        float32, and lives on ``self.device``.
        """
        state_shape = (self.n_layers, batch_size, self.model_dim)
        return (
            torch.zeros(state_shape, dtype=torch.float, device=self.device),
            torch.zeros(state_shape, dtype=torch.float, device=self.device),
        )

    def design_peptides(
        self,
        n_batches: int,
        batch_size: int,
        temperature: float,
        token_to_label: Dict[str, int],
        begin_token: str,
        end_token: str,
    ):
        """Sample peptide strings token by token from the model.

        Args:
            n_batches: Number of sampling rounds.
            batch_size: Number of designs drawn per round.
            temperature: Divisor applied to the logits before the softmax.
            token_to_label: Vocabulary mapping tokens to integer labels; it
                must contain ``begin_token`` and ``end_token``.
            begin_token: Token every design is started from.
            end_token: Token that terminates a design.

        Returns:
            A list of ``n_batches * batch_size`` strings. Each design is the
            sampled tokens before the first end token; a design in which no
            end token was sampled is returned as ``""``. The list is empty
            when ``n_batches`` is not positive.
        """
        self.to(self.device)
        self.eval()
        label_to_token = {label: token for token, label in token_to_label.items()}
        begin_label = token_to_label[begin_token]
        end_label = token_to_label[end_token]

        if n_batches <= 0:
            return []

        sampled_batches = list()
        # Sampling needs no gradients; without no_grad the carried hidden state
        # would chain an autograd graph across every decoding step.
        with torch.no_grad():
            for _ in range(n_batches):
                hidden_states = self.initialize_hidden_states(batch_size)
                current_token = torch.full(
                    (batch_size,), begin_label, dtype=torch.long, device=self.device
                )

                sampled_steps = list()
                for _step in range(self.sequence_length - 1):
                    logits, hidden_states = self.forward(current_token, hidden_states, training=False)
                    probs = F.softmax(logits.squeeze(1) / temperature, dim=-1)
                    # The sampled token becomes the input of the next step.
                    current_token = torch.multinomial(probs, num_samples=1).squeeze(1)
                    sampled_steps.append(current_token)

                # (batch_size, sequence_length - 1)
                sampled_batches.append(torch.stack(sampled_steps, dim=1))

        final_designs = list()
        for labels in torch.cat(sampled_batches, dim=0).cpu().tolist():
            if end_label not in labels:
                # No end token within the length budget: report an empty design.
                final_designs.append("")
                continue
            kept_labels = labels[: labels.index(end_label)]
            final_designs.append("".join(label_to_token[label] for label in kept_labels))

        return final_designs

### Train the model

In [12]:
# Seed PyTorch's RNG so weight initialisation and batch shuffling -- and hence
# the losses printed below -- can be repeated on the same setup.
torch.manual_seed(0)

# Retrieve token_to_label dictionary from peptidy
token_to_label = peptidy.biology.token_to_label

# Define model parameters
model_dim = 156
n_layers = 2
vocab_size = len(token_to_label) + 3                # +3 for start, end, and padding tokens
learning_rate = 0.01
n_epochs = 20
batch_size = 32

# Create model (keyword arguments guard against mixing up the eight positional parameters)
lstm = LSTM(
    model_dim=model_dim,
    n_layers=n_layers,
    vocab_size=vocab_size,
    sequence_length=pad_len,
    learning_rate=learning_rate,
    n_epochs=n_epochs,
    batch_size=batch_size,
    device="cpu",
)

# Fit model
history = lstm.fit(training_peptides, val_peptides)

Epoch 0 | Train loss: 3.4500434398651123 | Val loss: 3.1861982345581055
Epoch 1 | Train loss: 3.1032631397247314 | Val loss: 3.142003059387207
Epoch 2 | Train loss: 2.605785608291626 | Val loss: 2.6979241371154785
Epoch 3 | Train loss: 2.396714448928833 | Val loss: 2.482426643371582
Epoch 4 | Train loss: 2.155555009841919 | Val loss: 2.3532662391662598
Epoch 5 | Train loss: 1.9754966497421265 | Val loss: 2.2878787517547607
Epoch 6 | Train loss: 1.8772571086883545 | Val loss: 2.252659797668457
Epoch 7 | Train loss: 1.8247278928756714 | Val loss: 2.232822895050049
Epoch 8 | Train loss: 1.7894967794418335 | Val loss: 2.2209079265594482
Epoch 9 | Train loss: 1.7641886472702026 | Val loss: 2.2141366004943848
Epoch 10 | Train loss: 1.741996169090271 | Val loss: 2.2081873416900635
Epoch 11 | Train loss: 1.7219895124435425 | Val loss: 2.203568696975708
Epoch 12 | Train loss: 1.7032519578933716 | Val loss: 2.1977310180664062
Epoch 13 | Train loss: 1.6843616962432861 | Val loss: 2.19364643096923

### Generate peptides

In [13]:
# Sampling settings. Note that `batch_size` rebinds the name used for training above.
n_batches, batch_size, temperature = 4, 4, 0.8

# Work on a copy of peptidy's vocabulary: padding is label 0, and the begin/end
# tokens each take the next free label in turn.
token_to_label = peptidy.biology.token_to_label.copy()
token_to_label["<PAD>"] = 0
for special_token in ("<BEG>", "<END>"):
    token_to_label[special_token] = len(token_to_label)

designs = lstm.design_peptides(
    n_batches=n_batches,
    batch_size=batch_size,
    temperature=temperature,
    token_to_label=token_to_label,
    begin_token="<BEG>",
    end_token="<END>",
)
print(designs)

['', 'GLDKFGLL', '', '', 'GVCGMRCQTCGYCGRRNCTT', '', 'WWSRWMVRWNKKKIISILKNFAAAIVMVF', 'ACVKFKS', 'FW', '', '', 'SWSWDAHPKKAKKHLLIAHLKSAKKIVGKPLALLTGSVEPRALDIV', 'LCSRRRDRCRRRR', '', 'KWFSGANGQSVTQCGGCIKAKNLLAII', 'VLGVGLVGCKVWESALENIFKTAIVKSLKHPLELG']
