In [3]:
import os

from pandas import read_csv
from sklearn.model_selection import train_test_split
from torch import load, manual_seed, save
from torch.optim import AdamW

from transformers import BertTokenizerFast, BertForSequenceClassification

from src.dataset import load_dataset, Species, Modification, SeqBunch, load_benchmark_dataset
from src.utils.transformers import encode_seq_bunch, make_dataloader, train_epoch, calculate_acc_dataset

In [4]:
DEVICE = 'cpu'
MODEL = 'bert-base-uncased'

EXPERIMENT_NAME = 'bert-simple'

# Seed torch's global RNG before the model is built so the randomly initialised
# classification head (and dropout / any shuffling done by make_dataloader —
# presumably torch-based, TODO confirm) is repeatable under Restart & Run All.
SEED = 42
manual_seed(SEED)

In [5]:
# Fast (Rust-backed) tokenizer matching the pretrained checkpoint named in MODEL.
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path=MODEL)

In [8]:
# Benchmark data for human pseudouridine (psi) sites.
dataset = load_benchmark_dataset(Species.human, Modification.psi)
# NOTE(review): the third positional argument presumably selects the held-out
# test split — confirm against src.dataset.load_benchmark_dataset and pass it by keyword.
test_dataset = load_benchmark_dataset(Species.human, Modification.psi, True)

In [9]:
# Tokenise the training and test bunches with the same settings; each call
# yields an (encoded sequences, labels) pair, processed in order train → test.
(sequences, labels), (sequences_test, labels_test) = (
    encode_seq_bunch(bunch, tokenizer, True) for bunch in (dataset, test_dataset)
)

In [13]:
# Hold out 20% of the training data for validation / checkpoint selection.
# A fixed random_state makes the split — and therefore every reported
# validation accuracy — reproducible across kernel restarts.
SPLIT_SEED = 42
x_train, x_val, y_train, y_val = train_test_split(
    sequences, labels, test_size=0.2, random_state=SPLIT_SEED
)


In [14]:
# One loader per split; built in the order train, test, validation.
train_dataloader, test_dataloader, val_dataloader = (
    make_dataloader(x_train, y_train),
    make_dataloader(sequences_test, labels_test),
    make_dataloader(x_val, y_val),
)

In [15]:
# Two output labels for the sequence-classification head.
NUM_LABELS = 2
# nn.Module.to() returns the module itself, so chaining keeps `model` bound to
# the same object; ending the cell with an assignment avoids echoing the model repr.
model = BertForSequenceClassification.from_pretrained(MODEL, num_labels=NUM_LABELS).to(DEVICE)

In [16]:
# Small learning rate, as is usual when fine-tuning a pretrained transformer.
LEARNING_RATE = 1e-5
optimizer = AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [17]:
# Best-so-far bookkeeping read and updated by the training loop. It lives in its
# own cell so re-running the training cell continues from the previous best.
old_val_acc, old_train_acc, old_model_name = 0, 0, ''

In [18]:
# Global epoch counter, incremented by the training loop. Kept in its own cell so
# the loop can be re-run to keep training without restarting the numbering.
TOTAL_EPOCHS: int = 0

In [21]:
NUM_EPOCHS = 10

# Train for NUM_EPOCHS more epochs (TOTAL_EPOCHS keeps counting across re-runs)
# and keep a single on-disk checkpoint: the epoch with the best validation accuracy.
# Selection uses validation accuracy only. Also requiring training accuracy to
# improve would stop saving once train accuracy saturates, even if the model
# still generalises better. Train accuracy is recorded for the file name.
for _ in range(NUM_EPOCHS):
    TOTAL_EPOCHS += 1

    train_acc, val_acc = train_epoch(TOTAL_EPOCHS, DEVICE, model, optimizer, train_dataloader, val_dataloader)
    if val_acc > old_val_acc:
        new_model_name = f'{EXPERIMENT_NAME}_ep-{TOTAL_EPOCHS}_tacc-{train_acc:.2}_vacc-{val_acc:.2}.pt'
        # Write the new checkpoint BEFORE removing the old one, so a failed save
        # never leaves us without any checkpoint.
        # NOTE(review): this pickles the whole module; saving model.state_dict()
        # is the more portable PyTorch convention.
        save(model, new_model_name)
        # Skip the delete if the old file is already gone, or if it has the same
        # name as the one just written (possible after resetting TOTAL_EPOCHS).
        if old_model_name and old_model_name != new_model_name and os.path.exists(old_model_name):
            os.unlink(old_model_name)
        old_val_acc = val_acc
        old_train_acc = train_acc
        old_model_name = new_model_name

In [22]:
# Score the test set with the checkpoint selected and saved by the training loop,
# not with whatever weights the most recent epoch left in `model`. Fall back to
# the in-memory model if no checkpoint has been saved yet.
# weights_only=False is required because the loop pickles the whole module
# rather than a state_dict; only load checkpoints this notebook itself wrote.
best_model = load(old_model_name, map_location=DEVICE, weights_only=False) if old_model_name else model
calculate_acc_dataset(DEVICE, best_model, test_dataloader)