<div align=center>

# Binary classification workflow 📊

</div>

### Import required packages 📦

In [None]:
from safetensors.torch import save_file
from safetensors.torch import safe_open
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import torch

### Configuration for optimized computation and plotting

In [None]:
# Make float16 (half precision) the default dtype for newly created floating-point tensors.
# NOTE(review): confirm every op used later in the notebook supports float16 on the target device.
torch.set_default_dtype(torch.float16)

# Figure size handed to plt.figure(figsize=...) by the plotting code
plot_figsize = 12, 8

# Colors used for the training data, the validation data and the predictions
train_data_color, val_data_color, pred_color = "green", "orange", "red"

### 1. Get data ready

#### Define plotting functions

In [None]:
# Define data plotting function
def plot_data(train_data, val_data, train_labels, val_labels, title="Data plot"):
    """Plot the training and validation samples against their labels.

    Each split is drawn as a scatter plot plus a line connecting its points,
    colored with the notebook-level ``train_data_color`` / ``val_data_color``
    and sized with ``plot_figsize``. The y ticks 0 and 1 are relabeled
    'Not odd' and 'Odd'.

    Calling this function switches matplotlib to the 'dark_background' style.

    Args:
        train_data: x values of the training split.
        val_data: x values of the validation split.
        train_labels: y values (labels) of the training split.
        val_labels: y values (labels) of the validation split.
        title: Title shown above the plot.

    Returns:
        None. The figure is created through the pyplot interface.
    """
    # Use JetBrains Mono when the font file is available. The path is relative
    # to the working directory, so fall back to matplotlib's default font
    # instead of failing at render time when the file cannot be found.
    font_path = Path("../fonts/JetBrainsMono-Regular.ttf")
    font_kwargs = {"font": font_path} if font_path.is_file() else {}

    # Set the dark mode
    plt.style.use('dark_background')

    # Create the plot: points plus a connecting line for each split
    plt.figure(figsize=plot_figsize)
    plt.scatter(train_data, train_labels, c=train_data_color)
    plt.plot(train_data, train_labels, c=train_data_color)

    plt.scatter(val_data, val_labels, c=val_data_color)
    plt.plot(val_data, val_labels, c=val_data_color)

    # Title and axis labels (custom font only when it was found)
    plt.title(title, fontsize=16, **font_kwargs)
    plt.xlabel('X', fontsize=14, **font_kwargs)
    plt.ylabel('Y', fontsize=14, **font_kwargs)

    # Show class names instead of the raw 0 / 1 label values
    plt.yticks([0, 1], ['Not odd', 'Odd'])

#### Generate example data

In [None]:
# Create the dictionary
# Map every number in 0..50 to its label: True when the number is odd.
# The plot's y axis names label 1 'Odd' and label 0 'Not odd', so the
# positive class has to be the odd numbers (the label was previously True
# for even numbers, which put every point under the wrong class name).
data_dict = {i: (i % 2 == 1) for i in range(51)}

# Convert the dictionary to a list of (number, label) tuples
data_items = list(data_dict.items())

# Set the size of the training set: the first 80% of the items
train_size = int(len(data_items) * 0.8)

# Split into training and validation data (ordered split, no shuffling)
train_data = data_items[:train_size]
val_data = data_items[train_size:]

# Create X and y for the training set
X_train = [number for number, _label in train_data]  # Numbers
y_train = [label for _number, label in train_data]   # Labels

# Create X and y for the validation set
X_val = [number for number, _label in val_data]      # Numbers
y_val = [label for _number, label in val_data]       # Labels