In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForMaskedLM
from tqdm.notebook import tqdm

from bio_if.data.utils import FastaDataset

In [2]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
tokenizer_fn = lambda x: tokenizer(x, return_tensors="pt")['input_ids']

In [3]:
train = FastaDataset('../src/bio_if/data/FLIP/gb1/sampled.fasta', split='train', tokenizer_fn=tokenizer_fn)
val = FastaDataset('../src/bio_if/data/FLIP/gb1/sampled.fasta', split='val', tokenizer_fn=tokenizer_fn)
test = FastaDataset('../src/bio_if/data/FLIP/gb1/sampled.fasta', split='test', tokenizer_fn=tokenizer_fn)

In [4]:
# define an MLP
def init_regressor():
    return nn.Sequential(
        nn.Embedding(tokenizer.vocab_size, 4),
        nn.Flatten(),
        nn.Linear((len(train.seqs[0]) + 2) * 4, 256),
        nn.GELU(),
        nn.Linear(256, 1)
    )

In [5]:
train_dataloader = train.get_dataloader(batch_size=32, shuffle=True, drop_last=True)
val_dataloader = val.get_dataloader(batch_size=32, shuffle=False, drop_last=False)
test_dataloader = test.get_dataloader(batch_size=32, shuffle=False, drop_last=False)

In [6]:
def get_val_loss(model, dataloader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids, labels = batch
            outputs = model(input_ids).squeeze()
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    return val_loss / len(dataloader)

In [8]:
EPOCHS = 50
N_VAL_ATTEMPTS = 5
LR = 1e-3

loss_fn = nn.MSELoss()
best_val_loss = float('inf')

val_failures = 0
model = init_regressor()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    for batch in tqdm(train_dataloader):
        input_ids, labels = batch
        outputs = model(input_ids).squeeze()
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
    print(f'Epoch {epoch} train loss: {train_loss / len(train_dataloader)}')

    val_loss = get_val_loss(model, val_dataloader, loss_fn)
    print(f'Epoch {epoch} val loss: {val_loss}')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        val_failures = 0
    else:
        val_failures += 1
        if val_failures >= N_VAL_ATTEMPTS:
            break

  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 0 train loss: 1.4745827712574784
Epoch 0 val loss: 1.2371189919385044


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1 train loss: 1.2322396167687006
Epoch 1 val loss: 1.1211298500949687


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 2 train loss: 1.1257912601743425
Epoch 2 val loss: 1.2040292051705448


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 3 train loss: 1.045514323121431
Epoch 3 val loss: 0.9208378832448613


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 4 train loss: 0.9679551156503814
Epoch 4 val loss: 0.841795802116394


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 5 train loss: 0.9602655204279082
Epoch 5 val loss: 0.8688441772352565


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 6 train loss: 0.9064691107035899
Epoch 6 val loss: 0.8191796283830296


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 7 train loss: 0.842752393335104
Epoch 7 val loss: 0.7374807928096164


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 8 train loss: 0.8358279484875348
Epoch 8 val loss: 0.7130411829460751


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 9 train loss: 0.8280331495465064
Epoch 9 val loss: 0.7234124453230337


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 10 train loss: 0.7837638836734149
Epoch 10 val loss: 0.6935737918723713


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 11 train loss: 0.7749865956756533
Epoch 11 val loss: 0.6822097558866848


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 12 train loss: 0.7502836960614944
Epoch 12 val loss: 0.6627548933029175


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 13 train loss: 0.7276112239579765
Epoch 13 val loss: 0.6493181599812075


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 14 train loss: 0.7179111508386475
Epoch 14 val loss: 0.6228292394768108


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 15 train loss: 0.7111637469153015
Epoch 15 val loss: 0.6371993950822137


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 16 train loss: 0.7387100702949932
Epoch 16 val loss: 0.6582648605108261


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 17 train loss: 0.6959110617029424
Epoch 17 val loss: 0.6684293516657569


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 18 train loss: 0.6831985376015002
Epoch 18 val loss: 0.6137132082473148


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 19 train loss: 0.6866612692876738
Epoch 19 val loss: 0.5957975428212773


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 20 train loss: 0.6605402234257484
Epoch 20 val loss: 0.599531730467623


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 21 train loss: 0.6476936491624433
Epoch 21 val loss: 0.6790915111249144


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 22 train loss: 0.6421550871158133
Epoch 22 val loss: 0.5864144617860968


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 23 train loss: 0.6531702666258326
Epoch 23 val loss: 0.6038193384354765


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 24 train loss: 0.6664134633480286
Epoch 24 val loss: 0.5806613483212211


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 25 train loss: 0.6450635116471319
Epoch 25 val loss: 0.5966368981383063


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 26 train loss: 0.6318945593523736
Epoch 26 val loss: 0.5802516761151227


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 27 train loss: 0.6534937212661821
Epoch 27 val loss: 0.5832917222922499


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 28 train loss: 0.6272897415851452
Epoch 28 val loss: 0.5616788234223019


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 29 train loss: 0.6240709004353504
Epoch 29 val loss: 0.6174886694008653


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 30 train loss: 0.6007734585021224
Epoch 30 val loss: 0.5632519471374425


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 31 train loss: 0.6269027928308565
Epoch 31 val loss: 0.5575886775146831


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 32 train loss: 0.6118111587604698
Epoch 32 val loss: 0.5539098598740317


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 33 train loss: 0.5911510712182035
Epoch 33 val loss: 0.6213912706483494


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 34 train loss: 0.5851847586431065
Epoch 34 val loss: 0.5662805549800396


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 35 train loss: 0.6083347632416657
Epoch 35 val loss: 0.5573306483301249


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 36 train loss: 0.596706937694428
Epoch 36 val loss: 0.5817333622412248


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 37 train loss: 0.587603896628229
Epoch 37 val loss: 0.6193950569087808
