In [184]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np

batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 4096
max_iters = 2001
eval_interval = 200
learning_rate = 1e-4
eval_iters = 50
mask_prob = .15

hidden_sizes = [64, 128]
latent_size = 256  # Compressed vector size
input_size = 22  # Length of the input sequence

# Check if CUDA is available and set the device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda:0


In [185]:
df = pd.read_csv("uniprotkb_length_TO_5000_AND_reviewed_t_2024_09_20 (1).tsv")
seqs = [str(v[0]).split('\t')[1] for v in df.values[1:]]

print(len(seqs), "#seqs")
print(sum([len(s) for s in seqs]), "tokens")
print(max([len(s) for s in seqs]), "max len")

20398 #seqs
11151556 tokens
4981 max len


In [186]:
text = "".join(seqs)

# here are all the unique characters that occur in this text
chars = sorted(list(set(text))) + ['<MASK>']
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

def encode(sequence):
    """One-hot encode a protein sequence."""
    one_hot = np.zeros((len(sequence), vocab_size))
    for i, aa in enumerate(sequence):
        one_hot[i, stoi[aa]] = 1
    return one_hot

def decode(one_hot_seq):
    """Decode a one-hot encoded protein sequence."""
    return "".join([itos[idx.item()] for idx in np.argmax(one_hot_seq, axis=1)])

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.float)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    return torch.stack([data[i:i+block_size] for i in ix]).to(device)

In [209]:
def mask(input_batch):
    batch_size, block_size, input_dim = input_batch.shape
    
    # Create a mask with the same shape as the batch
    mask = (torch.rand(batch_size, block_size) < mask_prob).to(device)

    mask_token = torch.zeros(input_dim).to(device)  # You can customize this to a unique token
    mask_token[-1] = 1

    # Apply the mask to the batch (replace masked vectors with the mask token)
    masked_batch = input_batch.clone().to(device)
    masked_batch[mask] = mask_token  
    
    return masked_batch

In [210]:
# Define the Autoencoder model
class SequenceAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_sizes, latent_size):
        super(SequenceAutoencoder, self).__init__()

        # Encoder
        encoder_layers = []
        prev_size = input_size
        for h in hidden_sizes:
            encoder_layers.append(nn.Linear(prev_size, h))
            encoder_layers.append(nn.ReLU())
            prev_size = h
        encoder_layers.append(nn.Linear(prev_size, latent_size))  # Final layer to latent space
        self.encoder = nn.Sequential(*encoder_layers)
        
        # Decoder
        decoder_layers = []
        prev_size = latent_size
        for h in reversed(hidden_sizes):
            decoder_layers.append(nn.Linear(prev_size, h))
            decoder_layers.append(nn.ReLU())
            prev_size = h
        decoder_layers.append(nn.Linear(prev_size, input_size))  # Final layer to reconstruct input
        decoder_layers.append(nn.Softmax(dim=1))  # For one-hot encoded data
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        # Compress the sequence
        encoded = self.encoder(x)
        # Reconstruct the sequence
        decoded = self.decoder(encoded)
        return decoded

model = SequenceAutoencoder(input_size, hidden_sizes, latent_size).to(device)
model

SequenceAutoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=22, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=256, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=22, bias=True)
    (5): Softmax(dim=1)
  )
)

In [211]:
criterion = nn.MSELoss()  # Mean squared error loss
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

@torch.no_grad()
def estimate_loss():
    loss_out = {}
    acc_out = {}
    
    model.eval()
    for split in ['train', 'val']:
        total_correct = 0
        total_elements = 0
        
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X = get_batch(split)
            new_X = model(X)
            loss = criterion(new_X, X)
            
            losses[k] = loss.item()
            
            for j in range(len(X)):
                pred_seq = new_X[j].argmax(dim=-1)  # Predicted sequence using argmax
                actual_seq = X[j].argmax(dim=-1)  # Ground truth sequence using argmax (if needed)
                correct = (pred_seq == actual_seq).sum().item()  # Count the number of correct predictions
                total_correct += correct
                total_elements += len(X[j])
            
        loss_out[split] = losses.mean()
        acc_out[split] = total_correct / total_elements
        
    model.train()
    return loss_out, acc_out

In [212]:
for iter in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses, accs = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, train acc {accs['train']:.4f}, val acc {accs['val']:.4f}")

    # sample a batch of data
    x = get_batch('train')
    masked_x = mask(x)
    # evaluate the loss
    new_x = model(masked_x)
    loss = criterion(new_x, x)
    
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 0.0497, val loss 0.0497, train acc 0.0138, val acc 0.0157
step 200: train loss 0.0410, val loss 0.0409, train acc 0.5967, val acc 0.6030
step 400: train loss 0.0259, val loss 0.0265, train acc 0.9258, val acc 0.9190
step 600: train loss 0.0216, val loss 0.0217, train acc 0.9657, val acc 0.9644
step 800: train loss 0.0198, val loss 0.0199, train acc 0.9891, val acc 0.9885
step 1000: train loss 0.0182, val loss 0.0185, train acc 0.9988, val acc 0.9991
step 1200: train loss 0.0177, val loss 0.0179, train acc 1.0000, val acc 0.9996
step 1400: train loss 0.0177, val loss 0.0175, train acc 1.0000, val acc 0.9998
step 1600: train loss 0.0173, val loss 0.0175, train acc 0.9990, val acc 0.9994
step 1800: train loss 0.0173, val loss 0.0174, train acc 0.9999, val acc 1.0000
step 2000: train loss 0.0175, val loss 0.0174, train acc 0.9999, val acc 1.0000


In [213]:
@torch.no_grad()
def show_construction():
    for split in ['train', 'val']:
        print('split:', split)
        X = get_batch(split)
        new_X = model(X)
        for j in range(len(X)):
            print(j, decode(new_X[j].cpu()))
            print(j, decode(X[j].cpu()))
            print('===========')
show_construction()

split: train
0 SDPQRNFK
0 SDPQRNFK
1 APASPQPP
1 APASPQPP
2 ERDGDRRL
2 ERDGDRRL
3 TASTRGFY
3 TASTRGFY
4 ICPIPKEV
4 ICPIPKEV
5 QFSSYVGR
5 QFSSYVGR
6 PEGPRGAA
6 PEGPRGAA
7 VCPLSWFG
7 VCPLSWFG
8 VYYVGDTF
8 VYYVGDTF
9 LIKARALN
9 LIKARALN
10 KFQLLVQQ
10 KFQLLVQQ
11 EVSSQGRE
11 EVSSQGRE
12 RFMPEPNL
12 RFMPEPNL
13 PAVAESAV
13 PAVAESAV
14 QPTEDNIH
14 QPTEDNIH
15 QHLRIHLG
15 QHLRIHLG
16 MVRMKSMF
16 MVRMKSMF
17 TEALSMAH
17 TEALSMAH
18 GSTVQSVD
18 GSTVQSVD
19 SPGKQASS
19 SPGKQASS
20 DGQRVKLK
20 DGQRVKLK
21 TDLSFDSQ
21 TDLSFDSQ
22 GPGTVCES
22 GPGTVCES
23 DLAEYFRL
23 DLAEYFRL
24 DYSPPLHK
24 DYSPPLHK
25 PVQEELSV
25 PVQEELSV
26 QQEHFVLS
26 QQEHFVLS
27 TEDIRERV
27 TEDIRERV
28 QISGGFLV
28 QISGGFLV
29 RIIKHQKM
29 RIIKHQKM
30 IEPDGAEL
30 IEPDGAEL
31 VNAYSHKF
31 VNAYSHKF
split: val
0 GALCAGLG
0 GALCAGLG
1 IPVEQLML
1 IPVEQLML
2 AEKLRKME
2 AEKLRKME
3 RDSGTNAQ
3 RDSGTNAQ
4 PLPAHYRS
4 PLPAHYRS
5 SGAGRATA
5 SGAGRATA
6 EAVAFVVP
6 EAVAFVVP
7 LGDRIITP
7 LGDRIITP
8 PGLAPPQK
8 PGLAPPQK
9 EPGRIQHP
9 EPGRIQHP
10 FKWSS