In [1]:
# HYPERPARAMETERS
NUM_CLASSES=1_000
BATCH_SIZE = 2**6
MAX_SEQ_LENGTH = 256
LEARNING_RATE = 3e-5
EPOCHS = 10

In [2]:
!pip install torchtext pyarrow transformers structlog

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3[0m[39;49m -> [0m[32;49m22.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [3]:
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
tqdm.pandas()

import torch
torch.manual_seed(42)
import torch.nn as nn
import torchtext
import gc
import structlog

GPU = True
device = torch.device("cuda" if GPU else "cpu")
data_path = '/home/ubuntu/'
logger = structlog.getLogger()
logger.info(f"Getting started with {GPU=} {device=} {data_path=}")

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [4]:
!tree /home/ubuntu/

[01;34m/home/ubuntu/[00m
├── [01;31mparquet_data.zip[00m
├── [01;34msnap[00m
│   └── [01;34mnvtop[00m
│       ├── [01;34m66[00m
│       ├── [01;34mcommon[00m
│       └── [01;36mcurrent[00m -> [01;34m66[00m
├── test_df.parquet
├── train_df.parquet
├── try1.ipynb
└── val_df.parquet

5 directories, 5 files


In [5]:
train_df = pd.read_parquet(data_path + 'train_df.parquet')
test_df = pd.read_parquet(data_path + 'test_df.parquet')
val_df = pd.read_parquet(data_path + 'val_df.parquet')
train_df.shape, test_df.shape, val_df.shape

((1086741, 5), (126171, 5), (126171, 5))

In [6]:
torch.__version__, torchtext.__version__

('1.13.1+cu117', '0.14.1')

In [7]:
# Load protbert model
from transformers import BertForMaskedLM, BertTokenizer, pipeline
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
protbert_model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
if GPU:
    protbert_model.cuda()
import torch
import re


x_aminos = re.compile("[UZOB]")

def prepare_input(seq, **tokenizer_args):
    seq = x_aminos.sub("X", ' '.join(seq))
    input_ids = tokenizer.encode(seq, add_special_tokens=True, **tokenizer_args)
    return input_ids

def get_embeddings(seq, **tokenizer_args):
    input_ids = prepare_input(seq, **tokenizer_args)
    input_ids = torch.tensor([input_ids], device='cuda' if GPU else 'cpu')
    with torch.no_grad():
        return protbert_model(input_ids)

eg_input = get_embeddings("RRWWRRRRW")
eg_input

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


MaskedLMOutput(loss=None, logits=tensor([[[-1.8532e+01, -1.9899e+01, -1.9202e+01, -1.9233e+01, -2.1818e+01,
          -3.3183e-01, -4.3617e-01, -1.6150e-01, -2.5575e-02, -1.2855e+00,
           4.6539e-01, -1.3698e+00,  2.5481e-02,  1.3961e+00, -2.3760e+00,
          -2.5161e-01, -6.2316e-01, -9.2688e-01, -8.8103e-01, -1.4678e+00,
          -1.6217e+00,  3.2145e+00, -2.0604e+00, -4.5068e-01,  4.5483e-01,
          -3.9251e+00, -1.8108e+01, -1.8154e+01, -1.8443e+01, -1.8931e+01],
         [-2.1589e+01, -2.1636e+01, -2.1418e+01, -2.1295e+01, -2.2876e+01,
          -8.4695e-01, -8.1240e-01,  7.8902e-01, -1.7455e+00, -1.4653e+00,
           5.6490e-01, -1.1674e+00,  1.2356e+00,  4.1438e+00, -2.4756e+00,
           2.0495e+00, -1.3056e+00,  7.8864e-01,  1.2412e+00, -2.7056e+00,
          -1.4608e+00, -3.3685e-01, -2.7666e+00,  2.4256e-01,  3.9161e+00,
          -6.9252e+00, -2.1043e+01, -2.1170e+01, -2.0820e+01, -2.1201e+01],
         [-2.1355e+01, -2.0936e+01, -2.0803e+01, -1.9996e+01, -2.

In [8]:
# Create dataloaders
from torch.utils.data import Dataset, DataLoader

class ProteinSequenceDataset(Dataset):
    def __init__(self, df, sequence_col='sequence', label_col='family_id', max_len=100):
        self.df = df.reset_index(drop=True)
        self.sequence_col = sequence_col
        self.label_col = label_col
        self.max_len = max_len
        self._label_translator = {l: torch.tensor(i, device='cuda' if GPU else 'cpu') for i, l in enumerate(sorted(df[label_col].unique()))}

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        seq, label = row[self.sequence_col], row[self.label_col]
        # Convert to tensor
        seq = prepare_input(seq[:self.max_len-10], padding='max_length', max_length=self.max_len)
        seq = torch.tensor(seq, device='cuda' if GPU else 'cpu')
        label = self._label_translator[label]
        return seq, label


In [9]:
# TODO: Convert train and test to same top n families
top_families = train_df['family_id'].value_counts()[:NUM_CLASSES]
# Convert to numbers
fam2id = {fam: i for i, fam in enumerate(top_families.index)}
list(fam2id.items())[:5]

[('Methyltransf_25', 0),
 ('LRR_1', 1),
 ('Acetyltransf_7', 2),
 ('His_kinase', 3),
 ('Bac_transf', 4)]

In [10]:
def add_and_filter_family_id(df):
    df['family_code'] = df['family_id'].apply(lambda x: fam2id.get(x, np.nan))
    logger.info(f'Removing {df["family_code"].isna().sum():,}/{len(df):,} = {df["family_code"].isna().mean()*100:,.6f}% of rows due to nan famid num.')
    return df.dropna(subset='family_code').reset_index(drop=True)

train_df = add_and_filter_family_id(train_df)
test_df = add_and_filter_family_id(test_df)
val_df = add_and_filter_family_id(val_df)

Removing 647,248/1,086,741 = 59.558625% of rows due to nan famid num.
Removing 71,793/126,171 = 56.901348% of rows due to nan famid num.
Removing 71,793/126,171 = 56.901348% of rows due to nan famid num.


In [11]:
import gc; gc.collect()
train_df.shape, test_df.shape, val_df.shape

((439493, 6), (54378, 6), (54378, 6))

In [12]:
train_dataset = ProteinSequenceDataset(train_df, label_col='family_code', max_len=MAX_SEQ_LENGTH)
test_dataset  = ProteinSequenceDataset(test_df, label_col='family_code', max_len=MAX_SEQ_LENGTH)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

train_features, train_labels = next(iter(train_dataloader))
train_features, train_labels

(tensor([[ 2, 14, 16,  ...,  0,  0,  0],
         [ 2,  6,  5,  ...,  0,  0,  0],
         [ 2,  8, 15,  ...,  0,  0,  0],
         ...,
         [ 2, 14, 12,  ...,  0,  0,  0],
         [ 2,  5, 19,  ...,  0,  0,  0],
         [ 2, 16, 10,  ...,  0,  0,  0]], device='cuda:0'),
 tensor([433, 289, 117, 749, 215,  85,  59,  31,  63, 185, 494,  46, 986, 298,
         383, 107, 865, 314, 164, 441, 337, 199, 413, 348, 320, 367, 688, 223,
         618, 392, 541, 906,  65, 454, 831, 423, 229, 259, 465, 468,  14,  43,
          11, 178,   8,  81, 417,  33,  55, 489, 721, 962, 875, 601, 174, 350,
         747, 161, 257, 367,  35, 277, 228,  64], device='cuda:0'))

In [13]:
# Transformer models from tutorial https://n8henrie.com/2021/08/writing-a-transformer-classifier-in-pytorch/

class PositionalEncoding(nn.Module):
    """
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(self, d_model, vocab_size=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)


class Net(nn.Module):
    """
    Text classifier based on a pytorch TransformerEncoder.
    """

    def __init__(
        self,
        embeddings,
        vocab_size=30,
        embedding_size=1024,
        nhead=8,
        dim_feedforward=2048,
        num_layers=6,
        num_labels=2,
        dropout=0.1,
        activation="relu",
        classifier_dropout=0.1,
    ):

        super().__init__()

        d_model = embedding_size
        assert d_model % nhead == 0, "nheads must divide evenly into d_model"

        self.emb = embeddings

        self.pos_encoder = PositionalEncoding(
            d_model=d_model,
            dropout=dropout,
            vocab_size=vocab_size,
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        self.classifier = nn.Linear(d_model, num_labels)
        self.d_model = d_model
        self._agg_type = 1

    def forward(self, x):
        with torch.no_grad():
            embeds = self.emb(x)
            # x = embeds[:,-10:,:] # Only need last ten
            x = embeds
        # x = self.emb(x) * math.sqrt(self.d_model)
        # x = self.pos_encoder(x)
        x = self.transformer_encoder(x)
        if self._agg_type == 0:
            x = x[:, -1, :]
        else:
            x = x.mean(1)
        x = self.classifier(x)
        return x

In [14]:
tf_model = Net(
    protbert_model.bert.embeddings,
    vocab_size=tokenizer.vocab_size,
    nhead=8,  # the number of heads in the multiheadattention models
    dim_feedforward=50,  # the dimension of the feedforward network model in nn.TransformerEncoder
    num_layers=6,
    num_labels=NUM_CLASSES,
    dropout=0.2,
    classifier_dropout=0.2,
).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    tf_model.parameters(), lr=LEARNING_RATE
)

In [27]:
optimizer.zero_grad()
torch.cuda.empty_cache()
gc.collect()

8

In [25]:
def train_one_epoch(epoch, log_n=100):
    epoch_loss = 0
    epoch_correct = 0
    epoch_count = 0
    for idx, (inputs, labels) in enumerate(train_dataloader):
        optimizer.zero_grad()
        predictions = tf_model(inputs)
        loss = loss_fn(predictions, labels)

        correct = predictions.argmax(axis=1) == labels
        acc = correct.sum().item() / correct.size(0)
        epoch_correct += correct.sum().item()
        epoch_count += correct.size(0)
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()

        if idx % log_n == log_n-1:
            logger.info(f'{epoch=} {idx:,}/{len(train_dataloader):,} {epoch_loss=:,.6f}/{epoch_count=:,} = {epoch_loss/epoch_count:,.6f} {epoch_correct:,}/{epoch_count:,} = {100*epoch_correct/epoch_count:,.6f}%')

    logger.info(f"{epoch=} {epoch_loss=}")
    logger.info(f"{epoch=} accuracy: {epoch_correct / epoch_count}")

In [29]:
def test_one_epoch(epoch, log_n=100):
    with torch.no_grad():
        test_epoch_loss = 0
        test_epoch_correct = 0
        test_epoch_count = 0

        for idx, (inputs, labels) in enumerate(test_dataloader):
            predictions = tf_model(inputs)
            test_loss = loss_fn(predictions, labels)

            correct = predictions.argmax(axis=1) == labels
            acc = correct.sum().item() / correct.size(0)

            test_epoch_correct += correct.sum().item()
            test_epoch_count += correct.size(0)
            test_epoch_loss += loss.item()
            
            if idx % log_n == log_n-1:
                logger.info(f'{epoch=} {idx:,}/{len(test_dataloader):,} {test_epoch_loss=:,.6f}/{test_epoch_count=:,} = {test_epoch_loss/test_epoch_count:,.6f} {test_epoch_correct:,}/{test_epoch_count:,} = {100*test_epoch_correct/test_epoch_count:,.6f}%')

    logger.info(f"{epoch=} {test_epoch_loss=}")
    logger.info(f"test {epoch=} accuracy: {test_epoch_correct=:,}/{test_epoch_count:,} = {100 * test_epoch_correct / test_epoch_count:,.6f}%")

In [None]:
# Train
# with torch.autocast(device_type='cuda' if GPU else 'cpu'):
for epoch in range(EPOCHS):
    logger.info(f"{epoch=}")
    train_one_epoch(epoch)
    test_one_epoch(epoch)

2023-01-21 18:04:53 [info     ] epoch=0
2023-01-21 18:05:26 [info     ] epoch=0 99/6,868 epoch_loss=565.347676/epoch_count=6,400 = 0.088336 156/6,400 = 2.437500%
2023-01-21 18:05:59 [info     ] epoch=0 199/6,868 epoch_loss=1,128.954273/epoch_count=12,800 = 0.088200 316/12,800 = 2.468750%
2023-01-21 18:06:33 [info     ] epoch=0 299/6,868 epoch_loss=1,692.231605/epoch_count=19,200 = 0.088137 475/19,200 = 2.473958%
2023-01-21 18:07:06 [info     ] epoch=0 399/6,868 epoch_loss=2,253.782454/epoch_count=25,600 = 0.088038 624/25,600 = 2.437500%
2023-01-21 18:07:40 [info     ] epoch=0 499/6,868 epoch_loss=2,813.347023/epoch_count=32,000 = 0.087917 785/32,000 = 2.453125%
2023-01-21 18:08:14 [info     ] epoch=0 599/6,868 epoch_loss=3,373.556302/epoch_count=38,400 = 0.087853 945/38,400 = 2.460938%
2023-01-21 18:08:47 [info     ] epoch=0 699/6,868 epoch_loss=3,932.364497/epoch_count=44,800 = 0.087776 1,097/44,800 = 2.448661%
2023-01-21 18:09:21 [info     ] epoch=0 799/6,868 epoch_loss=4,488.368099/

In [28]:
test_one_epoch(epoch)

2023-01-21 18:02:38 [info     ] epoch=0 29/850 test_epoch_loss=173.171411/test_epoch_count=1,920 = 0.090193 50/1,920 = 2.604167%
2023-01-21 18:02:43 [info     ] epoch=0 59/850 test_epoch_loss=346.342821/test_epoch_count=3,840 = 0.090193 103/3,840 = 2.682292%
2023-01-21 18:02:47 [info     ] epoch=0 89/850 test_epoch_loss=519.514232/test_epoch_count=5,760 = 0.090193 157/5,760 = 2.725694%
2023-01-21 18:02:52 [info     ] epoch=0 119/850 test_epoch_loss=692.685642/test_epoch_count=7,680 = 0.090193 207/7,680 = 2.695312%
2023-01-21 18:02:57 [info     ] epoch=0 149/850 test_epoch_loss=865.857053/test_epoch_count=9,600 = 0.090193 256/9,600 = 2.666667%
2023-01-21 18:03:01 [info     ] epoch=0 179/850 test_epoch_loss=1,039.028463/test_epoch_count=11,520 = 0.090193 304/11,520 = 2.638889%
2023-01-21 18:03:06 [info     ] epoch=0 209/850 test_epoch_loss=1,212.199874/test_epoch_count=13,440 = 0.090193 353/13,440 = 2.626488%
2023-01-21 18:03:10 [info     ] epoch=0 239/850 test_epoch_loss=1,385.371284/te

In [24]:
# Verify its not just giving the same label or getting lucky
with torch.no_grad():
    test_epoch_loss = 0
    test_epoch_correct = 0
    test_epoch_count = 0
    for idx, (inputs, labels) in enumerate(test_dataloader):
        predictions = tf_model(inputs)
        test_loss = loss_fn(predictions, labels)
        print(inputs)
        print(predictions)
        print(test_loss)
        print(predictions.argmax(axis=1))
        print(labels)
        if idx * BATCH_SIZE > 5:
            break

tensor([[ 2, 19,  9,  ...,  0,  0,  0],
        [ 2,  6,  6,  ...,  0,  0,  0],
        [ 2, 16,  8,  ...,  0,  0,  0],
        ...,
        [ 2, 20, 12,  ...,  0,  0,  0],
        [ 2, 11, 21,  ...,  0,  0,  0],
        [ 2,  6, 23,  ...,  0,  0,  0]], device='cuda:0')
tensor([[-0.0440, -5.3381, -0.9999,  ...,  1.2580,  1.1968, -0.1832],
        [-0.5925, -5.6488, -1.5915,  ...,  1.4442,  1.2875,  0.3802],
        [ 3.4140,  0.9251,  3.1872,  ..., -1.3907, -0.8055, -2.7464],
        ...,
        [ 2.4615,  2.9081,  2.8831,  ..., -1.6114, -1.2116, -2.2700],
        [ 3.4396,  0.9174,  3.2128,  ..., -1.4100, -0.7952, -2.7663],
        [-3.4412, -1.6198, -3.4318,  ...,  1.0870,  0.4103,  2.9897]],
       device='cuda:0')
tensor(5.6799, device='cuda:0')
tensor([ 4,  4,  0, 47, 47,  0, 47,  7,  0,  3,  4, 12,  0,  2, 16, 47,  1,  0,
         0,  0,  0, 47,  0,  3,  4,  0,  4,  0,  3,  0,  0,  0,  2, 12,  4,  0,
         0,  7,  0,  0, 47,  0,  0, 47, 47,  2,  0,  2, 47, 47,  0, 47,  0,  0,