In [None]:
# --- Imports: standard library first, then third-party -----------------------
import random

import torch
from torch import cuda
from transformers import AutoTokenizer, BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput

# --- Reproducibility ---------------------------------------------------------
# Seed Python's and PyTorch's random number generators before any model or
# data work happens.
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# --- Runtime configuration ---------------------------------------------------
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if cuda.is_available() else "cpu"

# Checkpoint name used for both the tokenizer and the classifier.
# NOTE(review): this is a DistilBERT checkpoint, but the classifier defined
# further down subclasses BertForSequenceClassification and calls `self.bert`.
# Confirm the checkpoint and the model class are meant to be paired.
model_ckpt = "distilbert-base-uncased"

In [None]:
from datasets import load_from_disk

# Load the dataset from a directory on disk. The tokenized version of it is
# later indexed with the keys "train_dataset" and "val_dataset".
# NOTE(review): the path is relative, so it depends on the kernel's working
# directory.
ds = load_from_disk("bld/python/TrainTest/TrainTest_data/")

In [None]:
# NOTE(review): `main_function` and the classifier class it instantiates are
# defined in later cells, so a fresh top-to-bottom run raises NameError here.
# Despite its name, `model` receives the (ds_encoded, model) tuple that
# `main_function` returns.
model = main_function(ds)

In [None]:
# Display the value returned by `main_function` (a (ds_encoded, model) tuple).
model

In [None]:
def main_function(ds, num_labels=3):
    """Tokenize ``ds`` and load the multi-label classifier from ``model_ckpt``.

    Args:
        ds: Dataset object, passed unchanged to ``_tokenize_dataset``.
        num_labels: Number of output labels for the classification head.
            Defaults to 3, the value that used to be hard-coded here.

    Returns:
        tuple: ``(ds_encoded, model)`` - the tokenized dataset and the
        classifier after it has been moved to the notebook-level ``device``.
    """
    # The tokenizer and the model are both loaded from the same checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    ds_encoded = _tokenize_dataset(ds, tokenizer)

    model = BertForMultilabelSequenceClassification.from_pretrained(
        model_ckpt,
        num_labels=num_labels,
    ).to(device)

    return ds_encoded, model


# Function to load the dataset
# def _load_custom_dataset():


# Function to tokenize the dataset
def _tokenize_dataset(ds, tokenizer):
    def tokenize(batch):
        return tokenizer(batch["sequence"], padding=True, truncation=True)

    ds_encoded = ds.map(tokenize, batched=True, batch_size=None)
    ds_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    ds_encoded.set_format("torch")
    return ds_encoded

In [None]:
class BertForMultilabelSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier whose training loss is ``BCEWithLogitsLoss``.

    Construction is inherited unchanged from ``BertForSequenceClassification``;
    only ``forward`` is overridden so the loss is computed per label over
    ``self.num_labels`` logits.
    """

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the inputs, classify them and optionally compute the loss.

        All encoder-related arguments are forwarded to ``self.bert`` as given.

        Args:
            labels: Optional targets. When provided they are cast to float and
                viewed as ``(-1, self.num_labels)`` before the loss is taken.
            return_dict: When ``None``, ``self.config.use_return_dict`` decides.

        Returns:
            A ``SequenceClassifierOutput`` when ``return_dict`` is truthy.
            Otherwise a tuple of ``logits`` followed by the encoder outputs
            from index 2 onward, prefixed with ``loss`` when labels were given.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 1 of the encoder output is used as the pooled representation.
        logits = self.classifier(self.dropout(encoder_outputs[1]))

        if labels is None:
            loss = None
        else:
            criterion = torch.nn.BCEWithLogitsLoss()
            width = self.num_labels
            loss = criterion(
                logits.view(-1, width),
                labels.float().view(-1, width),
            )

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        extras = (logits,) + encoder_outputs[2:]
        return extras if loss is None else (loss, *extras)

In [None]:
import pickle

# `model` is the value returned by `main_function` above, i.e. the
# (ds_encoded, model) tuple; it is unpacked again after being reloaded.
with open("data_model.pkl", "wb") as f:
    pickle.dump(model, f)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load pickles that this notebook wrote itself.
with open("data_model.pkl", "rb") as f:
    loaded_data_model = pickle.load(f)

In [None]:
# Display the object reloaded from data_model.pkl.
loaded_data_model

In [None]:
def accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=True):
    """Element-wise accuracy of thresholded predictions against binary targets.

    Args:
        y_pred: Prediction scores (treated as logits when ``sigmoid`` is true).
            NumPy arrays, torch tensors and nested sequences are accepted.
        y_true: Targets with a shape that compares element-wise with
            ``y_pred``; any non-zero value counts as positive.
        thresh: A prediction is positive when its score is strictly greater
            than this value.
        sigmoid: Apply a sigmoid to ``y_pred`` before thresholding.

    Returns:
        float: Fraction of elements whose thresholded prediction equals the
        boolean target.
    """
    # as_tensor handles ndarrays like from_numpy did, and additionally accepts
    # tensors and plain sequences instead of raising TypeError.
    y_pred = torch.as_tensor(y_pred)
    y_true = torch.as_tensor(y_true)
    if sigmoid:
        y_pred = y_pred.sigmoid()
    return ((y_pred > thresh) == y_true.bool()).float().mean().item()


def compute_metrics(eval_pred):
    """Turn a ``(predictions, labels)`` pair into a metrics dictionary.

    Args:
        eval_pred: Two-item iterable unpacked as ``(predictions, labels)``.

    Returns:
        dict: ``{"accuracy_thresh": ...}`` holding the value returned by
        ``accuracy_thresh`` called with its default threshold settings.
    """
    logits, targets = eval_pred
    score = accuracy_thresh(logits, targets)
    return {"accuracy_thresh": score}

In [None]:
from transformers import TrainingArguments, Trainer

def create_and_train_model(
    train_dataset,
    eval_dataset,
    batch_size=8,
    num_train_epochs=1,
    output_dir="trainer_output",
):
    """Build a ``Trainer`` for the notebook-level ``model``.

    Despite the name, this function only constructs the trainer; nothing is
    trained until the caller invokes ``.train()`` on the returned object.

    The notebook globals ``model``, ``tokenizer`` and ``compute_metrics`` are
    read when the function is called, so they must be defined by then.

    Args:
        train_dataset: Dataset handed to the trainer for training.
        eval_dataset: Dataset handed to the trainer for evaluation.
        batch_size: Per-device batch size used for both training and evaluation.
        num_train_epochs: Number of training epochs.
        output_dir: Directory passed to ``TrainingArguments``. It now has a
            default: a parameter without one after defaulted parameters is a
            SyntaxError, and the existing call site passes only the datasets.

    Returns:
        The configured ``Trainer`` instance.
    """
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # (and the Trainer's `tokenizer` argument); check against the installed
    # version before upgrading.
    args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_train_epochs,
        weight_decay=0.01,
    )

    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
    )

    return trainer

In [None]:
# Notebook-level tokenizer loaded from `model_ckpt`; `create_and_train_model`
# reads this global when it constructs the Trainer.
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

In [None]:
# Unpack the reloaded (ds_encoded, model) tuple. This rebinds the global
# `model` from the tuple to the model object, which `create_and_train_model`
# reads when it builds the Trainer.
dataset = loaded_data_model[0]
model = loaded_data_model[1]

In [None]:
# Build the Trainer from the "train_dataset" and "val_dataset" entries. Only
# the two datasets are passed; every other parameter is left to the function
# definition. Training itself starts in the next cell.
trainer = create_and_train_model(dataset["train_dataset"], dataset["val_dataset"])

In [None]:
# Run training; the return value is kept so it can be pickled at the end.
trained = trainer.train()

In [None]:
# Save the model through the Trainer; no explicit destination is passed here.
trainer.save_model()

In [None]:
# Evaluate through the Trainer; the result is stored for pickling below.
evaluated = trainer.evaluate()

In [None]:
# Bundle the return values of trainer.train() and trainer.evaluate() under
# fixed keys so they can be serialized together.
# NOTE(review): despite the name, this dict holds results, not model objects.
models_to_save = {
    "trained": trained,
    "eval": evaluated,
}

In [None]:
# Persist the results bundle. `pickle` is imported in an earlier cell, so that
# cell must have been run first.
with open("models_to_save.pkl", "wb") as f:
    pickle.dump(models_to_save, f)