In [None]:
import warnings

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

warnings.filterwarnings("ignore")

In [None]:
import sys

sys.path.append("../src")

import constants
import model
import utils

In [None]:
sequences = utils.read_dataset_file(constants.DATASET_PATH_SANDBOX)
len(sequences)

In [None]:
# For test
sequences = sequences[: int(len(sequences) / 20)]

In [None]:
# utils.generate_dataset returns three values, unpacked as X_left / X_right / y.
# Presumably left-side inputs, right-side inputs and targets (inferred from the
# names only — shapes/dtypes are defined in utils; TODO confirm).
generated = utils.generate_dataset(sequences)
X_left, X_right, y = generated
del generated

In [None]:
# Fail fast with a readable message if the generated collections are not aligned,
# instead of letting train_test_split below raise a less specific error.
assert len(X_left) == len(X_right) == len(y), (
    f"misaligned dataset: X_left={len(X_left)}, X_right={len(X_right)}, y={len(y)}"
)
len(X_left)

In [None]:
# Hold out 20% of the samples for validation; random_state=42 makes the split
# reproducible. train_test_split returns a (train, val) pair per input, in order.
(
    X_left_train,
    X_left_val,
    X_right_train,
    X_right_val,
    y_train,
    y_val,
) = train_test_split(X_left, X_right, y, random_state=42, test_size=0.2)

In [None]:
# Wrap each split in the project's Dataset class so the DataLoaders below can batch it.
train_dataset, val_dataset = (
    model.DNASequenceDataset(left, right, target)
    for left, right, target in (
        (X_left_train, X_right_train, y_train),
        (X_left_val, X_right_val, y_val),
    )
)

In [None]:
# Training batches are reshuffled every epoch; validation order stays fixed.
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
SEED = 42
MAX_EPOCHS = 10

# Seed Python, NumPy and torch RNGs so weight initialisation and the shuffled
# training DataLoader are reproducible across runs.
seed_everything(SEED)

# Named `dna_model` rather than `model` so the imported `model` module is not
# shadowed; shadowing it made re-running this cell (or the Dataset cell) fail.
dna_model = model.DNASequenceModel()

# Keep only the checkpoint with the lowest `val_loss` (these are ModelCheckpoint's
# defaults, spelled out). Assumes DNASequenceModel logs "val_loss" — TODO confirm.
checkpoint_callback = ModelCheckpoint(
    dirpath="weights",
    filename="{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    devices="auto",
    accelerator="auto",
    callbacks=[checkpoint_callback],
    fast_dev_run=False,  # set True for a one-batch smoke test
)
trainer.fit(dna_model, train_loader, val_loader)