In [117]:
# Import Statments:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Adding Device Management:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available and set as device.")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available and set as device.")
else:
    device = torch.device("cpu")
    print("Using CPU device.")


MPS is available and set as device.


In [77]:
amps = []
non_amps = []
peptides = []

# Curating General Peptide Sequences from PeptideAtlas:
with open('APD_Hs_all.fasta', 'r', encoding='utf-8') as file:
    for line in file:
        if not line.startswith('>'):
            for aa in line:
                if aa != '\n':
                    peptides.append(aa)
            peptides.append('<s>')   

# Tokenizing Amino Acids:
amino_acids = sorted(list(set(peptides)))
vocab_size = len(amino_acids)

amino_acids


['<s>',
 'A',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'K',
 'L',
 'M',
 'N',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'V',
 'W',
 'Y']

In [78]:
# Curating Antimicrobial Peptide Sequences from dbAMP:
with open('dbAMP3.fasta', 'r', encoding='utf-8') as file:
    for line in file:
        if not line.startswith('>'):
            for c in line:
                if c in amino_acids:
                    amps.append(c)
                peptides.append('<s>')  

# Curating Non-Antimicrobial Peptide Sequences from UniProt:
with open('uniprotkb_non_amp.fasta', 'r', encoding='utf-8') as file:
    for line in file:
        if not line.startswith('>'):
            for c in line:
                if c in amino_acids:
                    non_amps.append(c)
            peptides.append('<s>')  

non_amps = non_amps[:len(amps)]

len(peptides), len(amps), len(non_amps)


(6849106, 1412087, 1412087)

In [79]:
# Creating Vocabularies:
itoaa = {aa:i for aa,i in enumerate(amino_acids)}
aatoi = {i:aa for aa,i in enumerate(amino_acids)}

encode = lambda l: [aatoi[aa] for aa in l]
decode = lambda l: ''.join([itoaa[i] for i in l])

# Creating Training/Validation Split:
n = int(0.9 * len(peptides))
m = int(0.9 * len(amps))

# General Peptide Tensors:
peptide_data = torch.tensor(encode(peptides), dtype=torch.long).to(device)

peptide_train_data = peptide_data[:n]
peptide_val_data = peptide_data[n:]

# AMP Tensors:
amp_data = torch.tensor(encode(amps), dtype=torch.long).to(device)

amp_train_data = amp_data[:m]
amp_val_data = amp_data[m:]

# non-AMP Tensors:
non_amp_data = torch.tensor(encode(non_amps), dtype=torch.long).to(device)

non_amp_train_data = non_amp_data[:m]
non_amp_val_data = non_amp_data[m:]

len(peptide_train_data), len(peptide_val_data), len(amp_train_data), len(amp_val_data), len(non_amp_train_data), len(non_amp_val_data)


(6164195, 684911, 1270878, 141209, 1270878, 141209)

In [80]:
# Creating Hyperparameters:
n_embd = 128
head_size = 16
n_layer = 4
n_head = 4
batch_size = 32
block_size = 128
dropout = 0.2

# Single Head of Attention:
class Head(nn.Module):

    def __init__(self, head_size):
        super().__init__()

        # K,Q,V Matrices:
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # Buffer Matrix and Dropout Layer:
        self.register_buffer('tril', torch.tril(torch.ones([block_size, block_size])))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        
        k = self.key(x)
        q = self.query(x)

        # Determining Affinities with Weighted Sum:
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Adjusting Embeddings With Value Matrix:
        v = self.value(x)
        out = wei @ v
        return out

# Parralelization of Attention Heads:
class MultiHeadedAttention(nn.Module):

    def __init__(self, head_size, n_head):
        super().__init__()

        # List of Attention Heads:
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])

        # Projection and Dropout Layers:
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

# Multi-Layer Perceptron:
class FeedForward(nn.Module):

    def __init__(self, n_embd):
        super().__init__()

        # Linear Layers, GELU and Dropout:
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4),
            nn.GELU(),
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# Self-Attention/MLP Block:
class Block(nn.Module):

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head

        # Self-Attention/MLP:
        self.sa = MultiHeadedAttention(head_size, n_head)
        self.ffwd = FeedForward(n_embd)

        # Layer Normalization:
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    # Residual Blocks and Self-Attention/MLP:
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# AMP Transformer Model:
class AMPTransformer(nn.Module):

    def __init__(self):
        super().__init__()

        # Token and Positional Embedding Tables:
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # Block Layers:
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])

        # Layer Normalization and Unembedding:
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # Projection Head for Contrastive Learning:
        self.projection_head = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.GELU(),
            nn.Linear(n_embd, n_embd // 2),
        )

    # Create Embeddings for Contrastive Learning:
    def get_embeddings(self, idx):
        B,T = idx.shape

        # Embedding:
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Creating Logits after Forward Pass:
        x = self.blocks(x)
        x = self.ln_f(x)

        sequence_emb = torch.mean(x, dim=1)

        contrastive_emb = self.projection_head(sequence_emb)
        contrastive_emb = F.normalize(contrastive_emb, p=2, dim=1)
            
        return contrastive_emb

    def forward(self, idx, targets=None):
        B,T = idx.shape

        # Embedding:
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Creating Logits after Forward Pass:
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        # Determining Loss via Cross Entropy:
        if targets == None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx):

        # Generate New Data Until the Delimiter Token is Encountered:
        while True:
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_new = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_new], dim=1)
            if idx_new == 0:
                break

        return idx

def contrastive_loss(embeddings_amp, embeddings_non_amp, temperature=0.07):
    batch_size = embeddings_non_amp.size(0)
    
    all_embeddings = torch.cat([embeddings_amp, embeddings_non_amp], dim=0)
    
    # Create Similarity Matrix:
    similarity_matrix = all_embeddings @ all_embeddings.T / temperature
    
    # Remove Similarities:
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=embeddings_amp.device)
    similarity_matrix = similarity_matrix.masked_fill(mask, float('-inf'))
    
    # Create Target Tensor:
    targets = torch.cat([torch.arange(batch_size, 2*batch_size), 
                        torch.arange(0, batch_size)], dim=0).to(embeddings_amp.device)
    
    # Return Contrastive Loss
    loss = F.cross_entropy(similarity_matrix, targets)
    
    return loss
        
# Initializing Model:
model = AMPTransformer()
model = model.to(device)
model = torch.compile(model)

# Creating Optimizer:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Untrained Model Generation:
idx = torch.zeros([1, 1], dtype=torch.long).to(device)
print(decode(model.generate(idx)[0].tolist()))


<s>NMSWGMPCDMHFWNQHSSLT<s>


In [81]:
# Batching General Peptides:
def get_general_batch(split):
    data = peptide_train_data if split == 'train' else peptide_val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    xb = torch.stack([data[i:i+block_size] for i in ix])
    yb = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return xb, yb

# Batching AMP and non-AMP Data:
def get_contrast_batch(split):
    
    # AMP Batch:
    data = amp_train_data if split == 'train' else amp_val_data
    ix = torch.randint(len(data) - block_size, (batch_size // 2,))
    amp_sequence = torch.stack([data[i:i+block_size] for i in ix])
    amp_sequence = torch.stack([data[i+1:i+block_size+1] for i in ix])

    # non-AMP Batch:
    data = amp_train_data if split == 'train' else amp_val_data
    ix = torch.randint(len(data) - block_size, (batch_size // 2,))
    non_amp_sequence = torch.stack([data[i:i+block_size] for i in ix])
    non_amp_sequence = torch.stack([data[i+1:i+block_size+1] for i in ix])
    
    return amp_sequence, non_amp_sequence

# Pre-Training with General Peptides:
for steps in range(10):

    xb, yb = get_general_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % 1 == 0:
        print(loss.item())


3.267407178878784
3.089454412460327
2.9115753173828125
2.981278657913208
2.939939498901367
2.684147357940674
2.9535725116729736
2.5156350135803223
2.6766574382781982
2.3052587509155273


In [82]:
# Contrastive Learning with AMP and non-AMP Data:
for steps in range(10):

    # Creating AMP and non-AMP Batches:
    amp_sequence, non_amp_sequence = get_contrast_batch('train')

    # Creating Embeddings:
    amp_embeddings = model.get_embeddings(amp_sequence)
    non_amp_embeddings = model.get_embeddings(non_amp_sequence)

    # Calculate Contrastive Loss:
    loss = contrastive_loss(amp_embeddings, non_amp_embeddings, 0.07)
        
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % 1 == 0:
        print(loss.item())


3.4765772819519043
3.5255374908447266
3.3934414386749268
3.524482250213623
3.4421675205230713
3.4216322898864746
3.440786838531494
3.4349803924560547
3.4344992637634277
3.4298102855682373


In [129]:
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

# Curating MIC Data from GRAMPA:
url = "https://raw.githubusercontent.com/zswitten/Antimicrobial-Peptides/refs/heads/master/data/grampa.csv"
df = pd.read_csv(url)

# Curating Sequences:
MIC_sequences = df['sequence'].values

# Curating Values:
MIC_values = df['value'].values
MIC_values = np.power(10, MIC_values)
MIC_values = torch.tensor(np.log(MIC_values), dtype=torch.float32).to(device)

# Curating Sequence/Value Pairs:
MIC_data = []

for index, value in enumerate(MIC_values[:2]):
    tokenized_sequence = torch.tensor(encode(MIC_sequences[index]), dtype=torch.long).to(device)
    MIC_data.append([tokenized_sequence, value])

# Creating MIC Data Loader:
MIC_loader = DataLoader(MIC_data, batch_size=batch_size, shuffle=True) 

MIC_data[0]


[tensor([ 6, 10, 13, 15,  9,  8, 10,  2,  1,  8,  1,  9,  9,  9,  6,  9,  2,  9,
          6, 13, 10,  9, 10, 18,  2,  9,  2], device='mps:0'),
 tensor(-0.9163, device='mps:0')]

In [130]:
class AMPTransformerMIC(nn.Module):
    def __init__(self, transformer_model, transformer_output_dim):
        super(AMPTransformerMIC, self).__init__()
        
        # Existing Transformer Mode:
        self.backbone = transformer_model
        
        # Regression Head for MIC Prediction:
        self.mic_regression_head = nn.Linear(transformer_output_dim, 1)

    def forward(self, x):
        transformer_output, _ = self.backbone(x)
        pooled_output = transformer_output.mean(dim=1)
        mic_prediction = self.mic_regression_head(pooled_output)
        return mic_prediction

MIC_model = AMPTransformerMIC(model, vocab_size).to(device)

# Freezing Original Transformer Parameters:
for param in MIC_model.backbone.parameters():
    param.requires_grad = False

# Creating Optimizer:
optimizer = torch.optim.Adam(MIC_model.mic_regression_head.parameters(), lr=3e-4)

# Creating Criterion:
criterion = nn.MSELoss()

# Training Loop:
steps = 1

for step in range(steps):

    for batch, (x, y) in enumerate(MIC_loader):

        MIC_pred = MIC_model(x)

        loss = criterion(MIC_pred, y)

print(loss.item())

RuntimeError: stack expects each tensor to be equal size, but got [42] at entry 0 and [27] at entry 1