In [46]:
# Import Statments:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Adding Device Management:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available and set as device.")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available and set as device.")
else:
    device = torch.device("cpu")
    print("Using CPU device.")


MPS is available and set as device.


In [52]:
peptides = []
amps = []
non_amps = []

# Reading Generalized Peptide File from PeptideAtlas:
with open('APD_Hs_all.fasta', 'r', encoding='utf-8') as file:
    for line in file:
        if not line.startswith('>'):
            for c in line:
                    peptides.append(c)

# Tokenizing Amino Acids:
amino_acids = sorted(list(set(peptides)))
vocab_size = len(chars)

# Reading AMP File from dbAMP:
with open('dbAMP3.fasta', 'r', encoding='utf-8') as file:
    for line in file:
        if not line.startswith('>'):
            for c in line:
                if c in amino_acids:
                    amps.append(c)

# Reading non-AMP File from UniProt:
with open('uniprotkb_non_amp.fasta', 'r', encoding='utf-8') as file:
    for line in file:
        if not line.startswith('>'):
            for c in line:
                if c in amino_acids:
                    non_amps.append(c)

non_amps = non_amps[:len(amps)]

len(peptides), len(amps), len(non_amps)


(5284250, 1447687, 1447687)

In [76]:
itoaa = {aa:i for aa,i in enumerate(amino_acids)}
aatoi = {i:aa for aa,i in enumerate(amino_acids)}

encode = lambda l: [aatoi[aa] for aa in l]
decode = lambda l: ''.join([itoaa[i] for i in l])

# Creating Training/Validation Split:
n1 = int(0.9 * len(peptides))
n2 = int(0.9 * len(amps))

# Generalized Peptide Data:

peptide_data = torch.tensor(encode(peptides), dtype=torch.long).to(device)

peptide_train_data = peptide_data[:n1]
peptide_val_data = peptide_data[n1:]

# AMP Data:
amp_data = torch.tensor(encode(amps), dtype=torch.long).to(device)

amp_train_data = amp_data[:n2]
amp_val_data = amp_data[n2:]

# non-AMP Data:
non_amp_data = torch.tensor(encode(non_amps), dtype=torch.long).to(device)

non_amp_train_data = non_amp_data[:n2]
non_amp_val_data = non_amp_data[n2:]

len(peptide_data), len(peptide_val_data), len(amp_data), len(amp_val_data), len(non_amp_data), len(non_amp_val_data)


(5284250, 528425, 1447687, 144769, 1447687, 144769)

In [78]:
# Creating Hyperparameters:
n_embd = 128
head_size = 16
n_layer = 4
n_head = 4
batch_size = 32
block_size = 128
dropout = 0.2

# Single Head of Attention:
class Head(nn.Module):

    def __init__(self, head_size):
        super().__init__()

        # K,Q,V Matrices:
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # Buffer Matrix and Dropout Layer:
        self.register_buffer('tril', torch.tril(torch.ones([block_size, block_size])))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        
        k = self.key(x)
        q = self.query(x)

        # Determining Affinities with Weighted Sum:
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Adjusting Embedding With Value Matrix:
        v = self.value(x)
        out = wei @ v
        return out

# Parralelization of Attention Heads:
class MultiHeadedAttention(nn.Module):

    def __init__(self, head_size, n_head):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])

        # Projection and Dropout Layers:
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

# Multi-Layer Perceptron:
class FeedForward(nn.Module):

    def __init__(self, n_embd):
        super().__init__()

        # Linear Layers:
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4),
            nn.GELU(),
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# Self-Attention/MLP Block:
class Block(nn.Module):

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head

        # Self-Attention/MLP:
        self.sa = MultiHeadedAttention(head_size, n_head)
        self.ffwd = FeedForward(n_embd)

        # Layer Normalization:
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    # Residual Blocks:
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# AMP Transformer Model:
class AMPTransformer(nn.Module):

    def __init__(self):
        super().__init__()

        # Token and Positional Embedding Tables:
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # Block Layers:
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])

        # Layer Normalization and Unembedding:
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # Projection Head for Contrastive Learning:
        self.projection_head = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.GELU(),
            nn.Linear(n_embd, n_embd // 2),
        )

    def get_embeddings(self, idx):
        B,T = idx.shape

        # Embedding:
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Creating Logits after Forward Pass:
        x = self.blocks(x)
        x = self.ln_f(x)

        sequence_emb = torch.mean(x, dim=1)

        contrastive_emb = self.projection_head(sequence_emb)
        contrastive_emb = F.normalize(contrastive_emb, p=2, dim=1)
            
        return contrastive_emb

    def forward(self, idx, targets=None):
        B,T = idx.shape

        # Embedding:
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Creating Logits after Forward Pass:
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        # Determining Loss via Cross Entropy:
        if targets == None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx):

        # Generate New Data Until End Token:
        while True:
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_new = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_new], dim=1)
            if idx_new == 0:
                break

        return idx

def contrastive_loss(embeddings_amp, embeddings_non_amp, temperature=0.07):
        batch_size = embeddings_non_amp.size(0)
        
        # Concatenating AMP and non-AMP Embeddings:
        all_embeddings = torch.cat([embeddings_amp, embeddings_non_amp], dim=0)

        # Creating Contrastive Labels:
        labels = torch.cat([
            torch.zeros(batch_size, dtype=torch.long, device=embeddings_amp.device),
            torch.ones(batch_size, dtype=torch.long, device=embeddings_amp.device),
        ])

        # Similarity Matrix:
        similarity_matrix = all_embeddings @ all_embeddings.T / temperature

        mask = torch.eye(2 * batch_size, dtype=torch.bool, device=device)
        similarity_matrix = similarity_matrix.masked_fill(mask, float('-inf'))

        labels_expanded = labels.unsqueeze(0)
        positive_mask = (labels_expanded == labels_expanded.T) & ~mask
        negative_mask = (labels_expanded != labels_expanded.T) & ~mask
    
        # InfoNCE loss
        exp_sim = torch.exp(similarity_matrix)
    
        # Positive Pairs:
        positive_sim = exp_sim * positive_mask.float()
        positive_sum = torch.sum(positive_sim, dim=1)
    
        # All pairs:
        all_sum = torch.sum(exp_sim, dim=1)
    
        # Contrastive loss:
        loss = -torch.log(positive_sum / (all_sum + 1e-8))
        loss = torch.mean(loss[positive_sum > 0])  # Only compute loss where positive pairs exist
        
        return loss
        
# Initializing Model:
m = AMPTransformer()
m = m.to(device)
m = torch.compile(m)

# Creating Optimizer:
optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)

list(m.parameters())[0]


Parameter containing:
tensor([[ 0.2107,  2.0431,  0.3023,  ..., -0.2195,  0.4472,  0.3479],
        [ 0.1121,  0.1404, -0.3309,  ..., -0.9215,  0.4634,  0.7284],
        [-0.4974, -1.6729,  0.0614,  ...,  0.8258,  0.6698,  0.5455],
        ...,
        [-2.4245,  1.1061, -0.0743,  ...,  0.2002, -0.2998, -0.9666],
        [-1.0869, -0.5832,  0.5542,  ...,  1.1487, -1.4038,  0.6769],
        [-1.1748,  0.6589, -0.3503,  ...,  1.2515,  0.3509,  0.3113]],
       device='mps:0', requires_grad=True)

In [83]:
# Batching Training and Validation Data:
def get_peptide_batch(split):
    data = peptide_train_data if split == 'train' else peptide_val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    xb = torch.stack([data[i:i+block_size] for i in ix])
    yb = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return xb, yb

# Batching AMP and non-AMP Data:
def get_contrast_batch(split):
    # AMP Batch:
    data = amp_train_data if split == 'train' else amp_val_data
    ix = torch.randint(len(data) - block_size, (batch_size // 2,))
    amp_sequence = torch.stack([data[i:i+block_size] for i in ix])
    amp_sequence = torch.stack([data[i+1:i+block_size+1] for i in ix])

    # non-AMP Batch:
    data = amp_train_data if split == 'train' else amp_val_data
    ix = torch.randint(len(data) - block_size, (batch_size // 2,))
    non_amp_sequence = torch.stack([data[i:i+block_size] for i in ix])
    non_amp_sequence = torch.stack([data[i+1:i+block_size+1] for i in ix])
    
    return amp_sequence, non_amp_sequence

# Pre-Training with Generalized Peptides:
for steps in range(10000):

    xb, yb = get_amp_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % 1000 == 0:
        print(loss.item())


2.7978944778442383
2.5765628814697266
2.6531050205230713
2.5218868255615234
2.48783016204834
2.210139751434326
2.3059821128845215
2.2393100261688232
2.3452038764953613
2.3197412490844727


In [84]:
# Contrastive Learning with AMP and non-AMP's
for steps in range(10000):

    xb, yb = get_amp_batch('train')
    logits, loss = m(xb, yb)

    contrastive_loss_val = 0

    if steps % 5 == 0:

        # Creating AMP and non-AMP Batches:
        amp_sequence, non_amp_sequence = get_contrast_batch('train')

        # Creating Embeddings:
        amp_embeddings = m.get_embeddings(amp_sequence)
        non_amp_embeddings = m.get_embeddings(non_amp_sequence)

        # Finding Distinctions in Embeddings:
        contrastive_loss_val = contrastive_loss(amp_embeddings, non_amp_embeddings, 0.07)

    # Adding Contrastive Loss:
    loss = loss + contrastive_loss_val
        
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % 1000 == 0:
        print(loss.item())


2.792475938796997
2.8082025051116943
2.716932535171509
2.922257423400879
2.9218149185180664
2.716519355773926
2.7867369651794434
2.6095669269561768
2.7807822227478027
2.5426905155181885
