In [18]:
# Import Statments:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Adding Device Management:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available and set as device.")
else:
    print("MPS is not available on this system.")

MPS is available and set as device.


In [19]:
# Reading Curated AMP.txt file from https://aps.unmc.edu/:
with open('AMP.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Tokenizing Amino Acids:
chars = sorted(list(set(text)))
vocab_size = len(chars)

itos = {ch:i for ch,i in enumerate(chars)}
stoi = {i:ch for ch,i in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[n] for n in l])

data = torch.tensor(encode(text), dtype=torch.long).to(device)

# Creating Training/Validation Split:
n = int(0.9 * len(data))

train_data = data[:n]
val_data = data[n:]

# Creating Training Batches:
block_size = 8
batch_size = 4

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    xb = torch.stack([data[i:i+block_size] for i in ix])
    yb = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return xb, yb

xb, yb = get_batch('train')
xb, yb

(tensor([[ 8, 14, 16,  6,  3, 17,  2, 12],
         [ 6, 20, 17, 15, 13,  8, 13, 15],
         [10,  2,  5, 10, 13, 10,  1,  1],
         [ 9,  1, 18,  1,  8,  2, 10, 15]], device='mps:0'),
 tensor([[14, 16,  6,  3, 17,  2, 12,  2],
         [20, 17, 15, 13,  8, 13, 15, 13],
         [ 2,  5, 10, 13, 10,  1,  1, 10],
         [ 1, 18,  1,  8,  2, 10, 15,  4]], device='mps:0'))

In [20]:
# Creating Hyperparameters:
n_embd = 256
head_size = 16
n_layer = 4
n_head = 4
batch_size = 32
block_size = 128
dropout = 0.2

# Single Head of Attention:
class Head(nn.Module):

    def __init__(self, head_size):
        super().__init__()

        # K,Q,V Matrices:
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # Buffer Matrix and Dropout Layer:
        self.register_buffer('tril', torch.tril(torch.ones([block_size, block_size])))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        
        k = self.key(x)
        q = self.query(x)

        # Determining Affinities with Weighted Sum:
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Adjusting Embedding With Value Matrix:
        v = self.value(x)
        out = wei @ v
        return out

# Parralelization of Attention Heads:
class MultiHeadedAttention(nn.Module):

    def __init__(self, head_size, n_head):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])

        # Projection and Dropout Layers:
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

# Multi-Layer Perceptron:
class FeedForward(nn.Module):

    def __init__(self, n_embd):
        super().__init__()

        # Linear Layers:
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4),
            nn.GELU(),
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# Self-Attention/MLP Block:
class Block(nn.Module):

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head

        # Self-Attention/MLP:
        self.sa = MultiHeadedAttention(head_size, n_head)
        self.ffwd = FeedForward(n_embd)

        # Layer Normalization:
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    # Residual Blocks:
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# AMP Transformer Model:
class AMPTransformer(nn.Module):

    def __init__(self):
        super().__init__()

        # Token and Positional Embedding Tables:
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # Block Layers:
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])

        # Layer Normalization and Unembedding:
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B,T = idx.shape

        # Embedding:
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Creating Logits after Forward Pass:
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        # Determining Loss via Cross Entropy
        if targets == None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx):

        # Generate New Data Until End Token:
        while True:
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_new = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_new], dim=1)
            if idx_new == 0:
                break

        return idx
        
# Initializing Model
m = AMPTransformer()
m = m.to(device)
m = torch.compile(m)

# Creating Optimizer:
optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)

list(m.parameters())[0]

Parameter containing:
tensor([[ 0.1357,  0.8714,  1.3087,  ...,  1.6417, -2.6251,  1.4841],
        [-1.2811, -0.1951,  1.6697,  ..., -0.4239, -0.1126,  0.8062],
        [ 0.7110, -0.4583,  0.6583,  ...,  0.8311,  1.2196,  0.8259],
        ...,
        [ 0.3012, -2.1751,  0.9615,  ..., -0.3972, -0.8386, -1.4481],
        [ 1.4901,  0.8245, -0.1473,  ...,  1.0899,  0.1246, -0.3761],
        [-0.3681, -0.2876, -0.1821,  ..., -0.6508,  1.6225, -0.4149]],
       device='mps:0', requires_grad=True)

In [21]:
# Creating Training Loop:
steps = 10000

for step in range(steps):

    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f'step: {step} loss: {loss:.4f}')

print(f'total loss: {loss:.4f}')

step: 0 loss: 3.2003
step: 100 loss: 2.8432
step: 200 loss: 2.8349
step: 300 loss: 2.7396
step: 400 loss: 2.7642
step: 500 loss: 2.7089
step: 600 loss: 2.7032
step: 700 loss: 2.6772
step: 800 loss: 2.6437
step: 900 loss: 2.6416
step: 1000 loss: 2.6051
step: 1100 loss: 2.5307
step: 1200 loss: 2.4992
step: 1300 loss: 2.5500
step: 1400 loss: 2.5299
step: 1500 loss: 2.4315
step: 1600 loss: 2.2681
step: 1700 loss: 2.4602
step: 1800 loss: 2.2237
step: 1900 loss: 2.2500
step: 2000 loss: 2.2130
step: 2100 loss: 2.2151
step: 2200 loss: 2.0892
step: 2300 loss: 2.1088
step: 2400 loss: 2.2408
step: 2500 loss: 2.0344
step: 2600 loss: 2.1426
step: 2700 loss: 1.8475
step: 2800 loss: 1.8415
step: 2900 loss: 1.8736
step: 3000 loss: 1.8024
step: 3100 loss: 1.8581
step: 3200 loss: 1.7964
step: 3300 loss: 1.7662
step: 3400 loss: 1.7608
step: 3500 loss: 1.7706
step: 3600 loss: 1.6799
step: 3700 loss: 1.6009
step: 3800 loss: 1.5606
step: 3900 loss: 1.6539
step: 4000 loss: 1.5692
step: 4100 loss: 1.4567
step

In [23]:
# Generating Novel AMP Sequences:
sequences = []

def generate(n, min_length):
    for _ in range(n):
        new_acid_size = 0
        while new_acid_size < min_length:
            idx = torch.zeros([1, 1], dtype=torch.long).to(device)
            new_acid = decode(m.generate(idx)[0].tolist())
            new_acid_size = len(new_acid)
        print(new_acid)
        sequences.append(new_acid)

generate(10, 10)


GCVKVNGNVGGSLNGKAKTAISAGVAAGTVEWGFVSKTYYKGPNFEIPKGKIVCYTVSWGYAGNNTYNIASVWDLLCLTSPGWGTIIVGATAVGNMTFASGGIKH


AIKYDSKKLDPSQVKQKKKVQKK


RRLHQGVRNGKRPQHMYGKFYDAKMHLPYPCRQKVVNWLLLTIQTVVPLKQ


NLKTYPKPTPQKFPTPYEHPIILPNGPNFPSQELGGAPKCALNCVTESDPPLIAGCKACCLDPHTCEPTHHICKLLCKDLS


SAVILDTLKAAGKGALQGLLSTASCKLKNMASGC


YGSEDVCFKPKCPDGQLICGKPFKCECFDSHSCKCPLNKVCLDPI


CTCPDLSLKSKFVNDAKCKTITQELCAKSEKNGSKKNCWDKRRSELLDRPPR


TSLLEPDDKKLIQMGPTVSPKILNEKSKIAYGFTNISNIKEWQSTSCNDLKWHSPWNPTACELLNTYSCNCEKFLHDDICAKKVDGRDVRDAVIVVVLDSGIGGGVSPDFGNNLFGHNTSGSEYSSSSLSYSVTYKSSGSLSS


FLGKMKVNFGPAIMAIAKHFAKKHL


GFFTAYCDVVSKKCAAAHMNKRRCKLTGCKPKDYS

