<a href="https://colab.research.google.com/github/YoniRomm/NLP-final/blob/main/NLP_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers
!pip install accelerate
!pip install numpy
!pip install pandas
!pip install datasets
!pip install torch



In [30]:
import pandas as pd

from jinja2 import Template

from datasets import Dataset, load_dataset


# template_string = """
# {{serialization}}
# Does this patient have diabetes? Yes or no?
# Answer:
# {{ answer_choices }}
# """
#
# template = Template(template_string)
# answer_choices = "No ||| Yes"


# filled_template = template.render(serialization=row_string, answer_choices=answer_choices)

def load_data_set(tokenizer):
    csv_file_path = 'diabetes.csv'
    data_frame = pd.read_csv(csv_file_path)
    training_data = data_frame.sample(frac=0.6, random_state=25)
    other_data = data_frame.drop(training_data.index)
    eval_data = other_data.sample(frac=0.5, random_state=25)
    test_data = other_data.drop(eval_data.index)

    return get_tokenized_data(tokenizer, training_data), get_tokenized_data(tokenizer, eval_data), get_string_data(test_data)


def get_tokenized_data(tokenizer, data_frame):
    data_set = get_string_data(data_frame)

    small_tokenized_dataset = Dataset.from_dict(data_set)

    # You can tokenize the dataset using the tokenizer
    small_tokenized_dataset = small_tokenized_dataset.map(
        lambda examples: tokenizer(examples['text'], padding='max_length', truncation=True),
        batched=True
    )

    return small_tokenized_dataset


def get_string_data(data_frame):
    texts = []
    labels = []
    for index, row in data_frame.iterrows():
        # Construct the formatted string for the current row
        row_string = ', '.join([f'{column}: {value}' for column, value in row.items() if column != "Outcome"])

        texts.append(row_string)
        labels.append(int(row["Outcome"]))

        if index == 9:
            break
    data_set = {
        'text': texts,
        'labels': labels,
    }

    return data_set


In [52]:
from transformers import TrainingArguments, AutoTokenizer, Trainer, DistilBertForSequenceClassification
from transformers import TrainerCallback, EarlyStoppingCallback
import numpy as np
import json

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased', num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

train_dataset, eval_dataset, test_dataset = load_data_set(tokenizer)

arguments = TrainingArguments(
    output_dir="sample_hf_trainer",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    evaluation_strategy="epoch",  # run validation at the end of each epoch
    save_strategy="epoch",
    learning_rate=2e-5,
    load_best_model_at_end=True,
    seed=224,
)


def compute_metrics(eval_pred):
    """Called at the end of validation. Gives accuracy"""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # calculates the accuracy
    return {"accuracy": np.mean(predictions == labels)}


trainer = Trainer(
    model=model,
    args=arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # change to test when you do your final evaluation!
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)


class LoggingCallback(TrainerCallback):
    def __init__(self, log_path):
        self.log_path = log_path

    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
        if state.is_local_process_zero:
            with open(self.log_path, "a") as f:
                f.write(json.dumps(logs) + "\n")


trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=1, early_stopping_threshold=0.0))
trainer.add_callback(LoggingCallback("sample_hf_trainer/log.jsonl"))

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/161 [00:00<?, ? examples/s]

Map:   0%|          | 0/154 [00:00<?, ? examples/s]

In [4]:
# from transformers import TrainerCallback, EarlyStoppingCallback

# class LoggingCallback(TrainerCallback):
#     def __init__(self, log_path):
#         self.log_path = log_path

#     def on_log(self, args, state, control, logs=None, **kwargs):
#         _ = logs.pop("total_flos", None)
#         if state.is_local_process_zero:
#             with open(self.log_path, "a") as f:
#                 f.write(json.dumps(logs) + "\n")


# trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=1, early_stopping_threshold=0.0))
# trainer.add_callback(LoggingCallback("sample_hf_trainer/log.jsonl"))

In [53]:
trainer.train()

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.647594,0.649351
2,No log,0.647947,0.649351


TrainOutput(global_step=22, training_loss=0.658111182126132, metrics={'train_runtime': 24.3385, 'train_samples_per_second': 19.845, 'train_steps_per_second': 1.356, 'train_loss': 0.658111182126132, 'epoch': 2.0})

In [33]:
import torch

finetuned_model = DistilBertForSequenceClassification.from_pretrained("sample_hf_trainer/checkpoint-33")

In [47]:
data = test_dataset['text']
labels = test_dataset['labels']

correct = 0
for i in range(len(data)):
  model_inputs = tokenizer(data[i], return_tensors="pt")
  prediction = torch.argmax(finetuned_model(**model_inputs).logits)
  correct += int(prediction.item()) == int(labels[i])

print(f"predicted {correct} out of {len(data)}. percentage {((correct / len(data)) * 100)}")

predicted 95 out of 153. percentage 62.091503267973856
