In [None]:
import os

# Restrict this kernel to GPU 0 and send Trainer's W&B logging (report_to='wandb')
# to a fixed project. Done in the first cell, before any CUDA work happens.
os.environ.update(
    CUDA_VISIBLE_DEVICES='0',
    WANDB_PROJECT='molecular-fingerprinting',
)

In [None]:
import evaluate
import numpy as np
import torch
import wandb
from datasets import DatasetDict, Value
from torch import nn
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

In [None]:
# Load the saved splits and adapt the columns to CustomModel.forward(inputs, labels):
#  - 'data' is renamed to 'inputs' to match the forward() argument name,
#  - the unused 'sex_label' column is dropped,
#  - 'label' is cast to float because BCEWithLogitsLoss needs floating targets.
ds = (
    DatasetDict.load_from_disk('data/dataset')
    .rename_column('data', 'inputs')
    .remove_columns('sex_label')
    .cast_column('label', Value('float64'))
)

In [None]:
# Look at one training row to check the renamed/cast columns.
first_train_example = ds['train'][0]
first_train_example

In [None]:
class CustomModel(nn.Module):
    """Feed-forward binary classifier that returns Trainer-compatible outputs.

    Each example is flattened (from dim 1 onward) and passed through fully
    connected layers, with ``activation`` after every hidden layer. The stack
    ends in a single output logit. No sigmoid is applied, because the loss
    (``BCEWithLogitsLoss``) works on raw logits.

    Args:
        in_features: number of elements in one flattened example. The default
            ``840 * 4`` is the value this notebook previously hard-coded.
        hidden_sizes: widths of the hidden layers, in order.
        activation: ``nn.Module`` class instantiated after each hidden layer.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, *args, in_features=840 * 4, hidden_sizes=(1024, 128),
                 activation=nn.ReLU, **kwargs):
        super().__init__(*args, **kwargs)

        # Builds Flatten, (Linear, activation) * len(hidden_sizes), Linear. With the
        # defaults this is the original stack, so state_dict keys (net.1, net.3,
        # net.5) and existing checkpoints stay compatible.
        layers = [nn.Flatten()]
        width = in_features
        for hidden in hidden_sizes:
            layers += [nn.Linear(width, hidden), activation()]
            width = hidden
        layers.append(nn.Linear(width, 1))
        self.net = nn.Sequential(*layers)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, inputs, labels=None):
        """Return ``{"logits": ...}``, plus ``"loss"`` when ``labels`` are given.

        Args:
            inputs: batch tensor. Everything after dim 0 is flattened and must
                total ``in_features`` elements per example.
            labels: optional per-example binary targets, of shape ``(batch,)``
                or ``(batch, 1)``, in any numeric dtype.

        Returns:
            dict with ``"logits"`` of shape ``(batch, 1)``. When ``labels`` is
            passed, it also holds the scalar ``"loss"`` (key order loss, logits,
            which Trainer relies on when it strips the loss from the outputs).
        """
        logits = self.net(inputs)
        if labels is None:
            return {"logits": logits}
        # BCEWithLogitsLoss needs targets with the same shape as the logits and a
        # floating dtype. Integer labels would fail without this cast.
        targets = labels.reshape(logits.shape).to(logits.dtype)
        return {"loss": self.loss_fn(logits, targets), "logits": logits}


# Seed before building the model so weight initialisation is reproducible.
# Trainer only seeds itself (TrainingArguments.seed, default 42) when it is
# constructed, which happens after this cell.
torch.manual_seed(42)
model = CustomModel()

In [None]:
training_args = TrainingArguments(
    output_dir='models',
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=1e-3,
    num_train_epochs=100,
    weight_decay=.01,
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    # Same as the implicit default when load_best_model_at_end=True. Stated
    # explicitly because EarlyStoppingCallback also tracks this metric.
    metric_for_best_model='loss',
    # Up to 100 per-epoch checkpoints would otherwise pile up on disk. Trainer
    # still keeps the best checkpoint for load_best_model_at_end.
    save_total_limit=2,
    bf16=True,
    report_to='wandb',
)

# evaluate.combine returns a CombinedEvaluations object. It has .compute() but
# is not callable, so it cannot be handed to Trainer as `compute_metrics`
# directly. The wrapper below keeps the name `metrics` used by the Trainer cell.
clf_metrics = evaluate.combine(['accuracy', 'precision', 'recall', 'f1'])


def metrics(eval_pred):
    """Compute accuracy/precision/recall/F1 from a Trainer ``EvalPrediction``.

    Args:
        eval_pred: object with ``predictions`` (the model's raw logits, one per
            example) and ``label_ids`` (the float 0/1 labels).

    Returns:
        dict mapping metric name to value.
    """
    logits = np.asarray(eval_pred.predictions).reshape(-1)
    labels = np.asarray(eval_pred.label_ids).reshape(-1)
    # The model outputs raw logits (no sigmoid), so logit > 0 is the same as p > 0.5.
    predictions = (logits > 0).astype(int)
    return clf_metrics.compute(predictions=predictions, references=labels.astype(int))

In [None]:
# Train on 'train' and evaluate on 'dev' once per epoch (eval_strategy='epoch').
# Training stops after 10 consecutive evaluations without improvement in the
# tracked metric.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds['train'],
    eval_dataset=ds['dev'],
    compute_metrics=metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)

In [None]:
# Run training. The returned TrainOutput is shown as this cell's result.
train_result = trainer.train()
train_result

In [None]:
# Final held-out evaluation on the test split. Without a prefix these metrics
# would be logged as `eval_*` and mixed into the dev-set curves in W&B.
# 'test' logs them as `test_*` instead.
trainer.evaluate(eval_dataset=ds['test'], metric_key_prefix='test')

In [None]:
# Close the W&B run started by Trainer's report_to='wandb' logging so the
# notebook's run is finalized before the kernel moves on.
wandb.finish()