In [1]:
import os

import torch
from sklearn.model_selection import train_test_split
from torch import save
from torch.optim import AdamW
from transformers import BertTokenizerFast, BertForSequenceClassification

from src.dataset import load_dataset, Species, Modification
from src.utils.transformers import encode_seq_bunch, make_dataloader, train_epoch, calculate_acc_dataset

In [2]:
# Compute device. Apple-silicon MPS stays the first choice (the notebook's
# original target); CUDA and then CPU are fallbacks so the notebook also runs
# on machines where MPS is not available.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Hugging Face checkpoint name; used to load both the tokenizer and the classifier.
MODEL = 'bert-base-uncased'

# Prefix of the checkpoint file names written by the training cell.
EXPERIMENT_NAME = 'bert-simple-m6a'

In [3]:
# Load the tokenizer that belongs to the checkpoint named by MODEL.
tokenizer = BertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path=MODEL,
)

In [4]:
# Load the dataset selected by Species.yeast and Modification.m6a through the
# project loader. Later cells read `dataset.samples['sequence']` and
# `dataset.targets` from the returned object.
dataset = load_dataset(Species.yeast, Modification.m6a)

In [15]:
def encode_sequence(sequence: str, tokenizer, split_chars=True):
    """Tokenize one sequence, capping the result at ``len(sequence)`` tokens.

    Args:
        sequence: Raw sequence string.
        tokenizer: Callable tokenizer (the notebook passes a Hugging Face
            ``BertTokenizerFast``). It is called with the prepared text and the
            ``max_length``, ``truncation`` and ``add_special_tokens`` keywords.
        split_chars: When true, a single space is inserted between every pair
            of adjacent characters before tokenizing, so each character is
            presented to the tokenizer as its own whitespace-separated word.

    Returns:
        Whatever ``tokenizer`` returns for the prepared text.
    """
    # Measured before any spaces are inserted: the cap is the number of
    # characters in the original sequence.
    max_length = len(sequence)

    if split_chars:
        # A str is already an iterable of its characters.
        sequence = ' '.join(sequence)

    # `truncation=True` states the intent explicitly. Hugging Face tokenizers
    # given `max_length` without `truncation` fall back to the same truncation
    # strategy but log a warning on first use.
    return tokenizer(
        sequence,
        max_length=max_length,
        truncation=True,
        add_special_tokens=False,
    )

def encode_seq_bunch(
        bunch,
        tokenizer,
        split_chars=True
) -> tuple[list, list[int]]:
    """Encode every sequence held by ``bunch`` and pair the result with its targets.

    NOTE(review): this definition shadows the ``encode_seq_bunch`` imported
    from ``src.utils.transformers`` at the top of the notebook.

    Args:
        bunch: Object exposing ``samples['sequence'].values`` (the raw
            sequences) and ``targets.values`` (the labels).
        tokenizer: Forwarded unchanged to ``encode_sequence``.
        split_chars: Forwarded unchanged to ``encode_sequence``.

    Returns:
        ``(encodings, targets)``: a list holding one ``encode_sequence`` result
        per sequence, in the original order, and ``bunch.targets.values``
        returned as-is.
    """
    raw_sequences = bunch.samples['sequence'].values
    encodings = [
        encode_sequence(raw, tokenizer, split_chars)
        for raw in raw_sequences
    ]
    return encodings, bunch.targets.values

In [17]:
# Tokenize the whole dataset character by character; `labels` are the dataset's targets.
sequences, labels = encode_seq_bunch(
    dataset,
    tokenizer,
    split_chars=True,
)

In [18]:
# Fixed seed so the train / validation / test partition, and therefore every
# accuracy reported below, is reproducible across kernel restarts.
SPLIT_SEED = 42

# First hold out 20% of all samples as the final test set ...
x_train, x_test, y_train, y_test = train_test_split(
    sequences, labels, test_size=0.2, random_state=SPLIT_SEED
)
# ... then carve 20% of the remainder off as the validation set
# (roughly 64% train / 16% validation / 20% test overall).
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=0.2, random_state=SPLIT_SEED
)

In [19]:
# Build one loader per split; the calls run in train -> test -> val order.
train_dataloader, test_dataloader, val_dataloader = (
    make_dataloader(encodings, targets)
    for encodings, targets in (
        (x_train, y_train),
        (x_test, y_test),
        (x_val, y_val),
    )
)

In [20]:
# Load the pretrained checkpoint with a 2-label classification head and move it
# to DEVICE. `.to()` returns the module itself, and ending the cell on an
# assignment keeps the (long) module repr out of the cell output.
model = BertForSequenceClassification.from_pretrained(MODEL, num_labels=2).to(DEVICE)

In [21]:
# AdamW over every model parameter, with a learning rate of 2e-5.
optimizer = AdamW(
    params=model.parameters(),
    lr=2e-5,
)

In [22]:
# Best-so-far trackers read and updated by the training cell:
# the accuracies of the last saved checkpoint ...
old_val_acc = old_train_acc = 0
# ... and its file name ('' means nothing has been saved yet).
old_model_name = ''

In [23]:
# Running count of epochs trained so far. The training cell increments it once
# per epoch, passes it to `train_epoch`, and embeds it in checkpoint file names.
# Presumably kept in its own cell so that re-running the training cell continues
# the count instead of resetting it.
TOTAL_EPOCHS = 0

In [24]:
# Number of epochs trained by one execution of this cell.
EPOCHS_PER_RUN = 5

for epoch in range(1, EPOCHS_PER_RUN + 1):
    # TOTAL_EPOCHS, not `epoch`, is the counter that matters: it keeps growing
    # when this cell is executed again to train for further epochs.
    TOTAL_EPOCHS += 1

    train_acc, val_acc = train_epoch(TOTAL_EPOCHS, DEVICE, model, optimizer, train_dataloader, val_dataloader)

    # A checkpoint is kept only when BOTH accuracies beat the last saved one.
    if train_acc > old_train_acc and val_acc > old_val_acc:
        # Fixed-point `.2f`: a bare `.2` means "two significant digits" and can
        # yield names such as `8.7e+01` for values of 10 or more.
        new_model_name = f'{EXPERIMENT_NAME}_ep-{TOTAL_EPOCHS}_tacc-{train_acc:.2f}_vacc-{val_acc:.2f}.pt'

        # Write the new checkpoint BEFORE removing the previous one, so a failed
        # save never leaves the disk without the best model found so far.
        # NOTE: this pickles the whole model object, not only its state_dict.
        save(model, new_model_name)

        if old_model_name != '' and old_model_name != new_model_name:
            try:
                os.unlink(old_model_name)
            except FileNotFoundError:
                # The superseded file was already removed (e.g. by hand between
                # runs); there is nothing left to clean up.
                pass

        old_val_acc = val_acc
        old_train_acc = train_acc
        old_model_name = new_model_name

In [25]:
# Evaluate on the held-out test split. As the last expression of the cell, the
# call's return value is shown as the cell output.
# NOTE(review): `model` here carries the weights left by the final training
# epoch; the best checkpoint saved to disk is never reloaded — confirm that
# this is the model intended for the reported test accuracy.
calculate_acc_dataset(DEVICE, model, test_dataloader)