In [None]:
import os
from pathlib import Path
import cv2
import torch
from torch import nn
from torch import optim
from sklearn.datasets import make_moons, make_circles, make_blobs, load_iris
import numpy as np
import matplotlib.pyplot as plt

In [None]:
device = 'cpu'
torch.set_default_device(device)

# Seed every RNG the notebook draws from (torch weight initialisation, NumPy's
# global RNG) so that "Restart Kernel -> Run All" reproduces the same models.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def plot_decision_boundary(model: torch.nn.Module,
                           X: torch.Tensor,
                           y: torch.Tensor,
                           return_device: str = "cuda"):
    """Plots decision boundaries of model predicting on X in comparison to y.

    Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)

    Draws onto the current Matplotlib figure; the caller is responsible for creating,
    saving and closing it.

    :param model: classifier mapping an ``(N, 2)`` float tensor to logits. The binary
        path expects exactly one logit per row; the multi-class path expects one logit
        per class along dim 1.
    :param X: features; only columns 0 and 1 are used for the grid and the scatter.
    :param y: labels. More than two distinct values selects the multi-class
        (softmax + argmax) path, otherwise the binary (sigmoid + round) path is used.
    :param return_device: kept for backward compatibility and ignored. The model is put
        back on whatever device its parameters were on before the call, and the caller's
        ``X``/``y`` are never modified (``Tensor.to`` returns a new tensor).
    """
    # Remember where the model lives and which mode it is in, so that plotting
    # (which happens between training steps) leaves both exactly as it found them.
    first_param = next(model.parameters(), None)
    original_device = first_param.device if first_param is not None else torch.device("cpu")
    was_training = model.training

    # Put everything on CPU (works better with NumPy + Matplotlib). X and y are rebound
    # to local CPU copies; only the model is moved in place and is restored below.
    model.to("cpu")
    X, y = X.to("cpu"), y.to("cpu")

    try:
        # Prediction grid covering the data with a 0.1 margin on every side.
        x_min, x_max = X[:, 0].min().item() - 0.1, X[:, 0].max().item() + 0.1
        y_min, y_max = X[:, 1].min().item() - 0.1, X[:, 1].max().item() + 0.1
        xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101), np.linspace(y_min, y_max, 101))

        # One (x, y) feature row per grid point.
        X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()

        model.eval()
        with torch.inference_mode():
            y_logits = model(X_to_pred_on)

        # Turn logits into class labels: multi-class vs binary is decided from y.
        if len(torch.unique(y)) > 2:
            y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)  # multi-class
        else:
            y_pred = torch.round(torch.sigmoid(y_logits))  # binary

        # Reshape preds back onto the grid and plot regions + data points.
        y_pred = y_pred.reshape(xx.shape).detach().numpy()
        plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
        plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
        plt.xlim(xx.min(), xx.max())
        plt.ylim(yy.min(), yy.max())
    finally:
        # Undo the temporary device move and the eval() switch, even if plotting failed,
        # so a training loop calling this every epoch keeps training normally.
        model.to(original_device)
        model.train(was_training)

In [None]:
def show_plot(X, y, save_folder: Path = None):
    """Scatter-plot a 2-D dataset coloured by label, then save it or show it.

    :param X: features; columns 0 and 1 are plotted (NumPy array or CPU tensor).
    :param y: labels, used as the point colours.
    :param save_folder: if given, the figure is written to ``save_folder / "plot.png"``
        and closed; otherwise it is displayed.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    # Same styling as the data overlay in plot_decision_boundary.
    ax.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
    if save_folder:
        fig.savefig(Path(save_folder) / "plot.png")
        plt.close(fig)
    else:
        plt.show()

In [None]:
def train_model(model: nn.Module,
                X: torch.Tensor,
                y: torch.Tensor,
                epochs: int = 100,
                lr: float = 0.01,
                optimizer_class: "type[optim.Optimizer]" = optim.Adam,
                loss_class: "type[nn.Module]" = nn.BCEWithLogitsLoss,
                save_folder: Path = None,
                convert_to_video: bool = True,
                fps: int = 60) -> None:
    """Train ``model`` full-batch on ``(X, y)``, optionally recording a decision-boundary timelapse.

    Each epoch is one forward pass over all of ``X``, one backward pass and one
    optimizer step. The loss is printed every 10 epochs.

    :param model: your model
    :param X: your train data
    :param y: your labels (desired output); must be what ``loss_class`` expects,
        e.g. float ``(N, 1)`` targets for ``BCEWithLogitsLoss`` or integer class
        indices for ``CrossEntropyLoss``
    :param epochs: number of epochs (= optimizer steps, since training is full-batch)
    :param lr: learning rate (step size) passed to the optimizer
    :param optimizer_class: optimizer class, built as ``optimizer_class(model.parameters(), lr=lr)``
    :param loss_class: loss class, instantiated with no arguments
    :param save_folder: if given, a decision-boundary frame ``<epoch>.jpg`` is written here
        after every epoch. Only the frames written by this call are turned into the video
        and deleted; other files in the folder are left untouched.
    :param convert_to_video: stitch this run's frames into ``save_folder/model_timelapse.mp4``
        and delete the frames afterwards (only used when ``save_folder`` is given)
    :param fps: frame rate of the generated video
    :return: None
    """
    if save_folder:
        save_folder = Path(save_folder)  # also accept plain string paths

    loss_fn = loss_class()
    optimizer = optimizer_class(model.parameters(), lr=lr)
    frame_paths = []  # frames written by *this* run, in epoch order

    for epoch in range(epochs):
        # The per-epoch plotting calls model.eval(); make sure every optimization
        # step runs in training mode (matters for dropout / batch-norm layers).
        model.train()
        model_prediction = model(X)
        loss = loss_fn(model_prediction, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if save_folder:
            frame_path = save_folder / f"{epoch}.jpg"
            plt.figure(figsize=(10, 10))
            plot_decision_boundary(model, X, y)
            plt.savefig(frame_path)
            plt.close()
            frame_paths.append(frame_path)

        if epoch % 10 == 0:
            print(f"Epoch [{epoch}/{epochs}] | Loss: {loss.item():.4f}")

    # Nothing to encode when no frames were produced (e.g. epochs == 0).
    if convert_to_video and save_folder and frame_paths:
        _frames_to_video(frame_paths, save_folder / 'model_timelapse.mp4', fps)
        print("Video saved!")


def _frames_to_video(frame_paths, video_path, fps):
    """Encode ``frame_paths`` in order into an mp4v video at ``video_path``, deleting each frame once written.

    :raises IOError: if a frame cannot be read or the video writer cannot be opened.
    """
    first_frame = cv2.imread(str(frame_paths[0]))
    if first_frame is None:
        raise IOError(f"Could not read frame {frame_paths[0]}")
    height, width, _ = first_frame.shape

    video_codec = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(str(video_path), video_codec, fps, (width, height))
    if not video_writer.isOpened():
        raise IOError(f"Could not open video writer for {video_path}")
    try:
        for frame_path in frame_paths:
            frame = cv2.imread(str(frame_path))
            if frame is None:
                raise IOError(f"Could not read frame {frame_path}")
            video_writer.write(frame)
            os.remove(frame_path)
    finally:
        # Always finalise the container, even if a frame failed to load.
        video_writer.release()

In [None]:
# Three toy 2-D datasets of 1000 points each, all with a fixed seed.
X_moons, y_moons = make_moons(
    n_samples=1000, noise=0.03, random_state=42
)
X_circles, y_circles = make_circles(
    n_samples=1000, noise=0.03, random_state=42
)
X_blobs, y_blobs = make_blobs(
    n_samples=1000, n_features=2, centers=8, cluster_std=0.5, random_state=42
)

In [None]:
# Three interleaved 2-D spiral arms: a classic non-linearly-separable multi-class problem.
n_samples = 1000  # points per class
dimensions = 2
classes = 3
# Dedicated seeded generator so the spirals are reproducible, like the sklearn
# datasets above (random_state=42), instead of depending on global RNG state.
spirals_rng = np.random.default_rng(42)
X_spirals = np.zeros((n_samples * classes, dimensions))  # data matrix (each row = single example)
y_spirals = np.zeros(n_samples * classes, dtype='uint8')  # class labels
radius = np.linspace(0.0, 1, n_samples)  # identical radii for every arm (loop-invariant)
for j in range(classes):
    ix = range(n_samples * j, n_samples * (j + 1))
    theta = np.linspace(j * 4, (j + 1) * 4, n_samples) + spirals_rng.standard_normal(n_samples) * 0.2
    X_spirals[ix] = np.c_[radius * np.sin(theta), radius * np.cos(theta)]
    y_spirals[ix] = j

In [None]:
# Convert every NumPy array to a float32 tensor on `device`
# (multi-class labels are cast to integer indices in a later cell).
(X_moons, X_circles, X_blobs, X_spirals,
 y_moons, y_circles, y_blobs, y_spirals) = (
    torch.tensor(array, dtype=torch.float32, device=device)
    for array in (X_moons, X_circles, X_blobs, X_spirals,
                  y_moons, y_circles, y_blobs, y_spirals)
)

In [None]:
# Binary targets become (N, 1) columns, matching the single-logit output of the
# binary models defined below.
y_moons = y_moons.unsqueeze(dim=1)
y_circles = y_circles.unsqueeze(dim=1)

In [None]:
# CrossEntropyLoss expects integer (int64) class indices rather than floats.
y_blobs = y_blobs.to(torch.int64)
y_spirals = y_spirals.to(torch.int64)

##### Binary classification (two classes: moons, circles)

In [None]:
# Binary classifier: 2 features -> two hidden ReLU layers of width 8 -> 1 logit.
moons_model = nn.Sequential(
    nn.Linear(in_features=2, out_features=8), nn.ReLU(),
    nn.Linear(in_features=8, out_features=8), nn.ReLU(),
    nn.Linear(in_features=8, out_features=1),
)

In [None]:
# Binary classifier with the same architecture as moons_model: 2 -> 8 -> 8 -> 1 logit.
circles_model = nn.Sequential(
    nn.Linear(in_features=2, out_features=8), nn.ReLU(),
    nn.Linear(in_features=8, out_features=8), nn.ReLU(),
    nn.Linear(in_features=8, out_features=1),
)

##### Multi-class classification (more than two classes: blobs, spirals)

In [None]:
# Multi-class classifier: one output logit per blob centre (make_blobs uses centers=8).
blobs_model = nn.Sequential(
    nn.Linear(in_features=2, out_features=8), nn.ReLU(),
    nn.Linear(in_features=8, out_features=8), nn.ReLU(),
    nn.Linear(in_features=8, out_features=8),
)

In [None]:
# Multi-class classifier: one output logit per spiral arm (classes = 3).
spirals_model = nn.Sequential(
    nn.Linear(in_features=2, out_features=8), nn.ReLU(),
    nn.Linear(in_features=8, out_features=8), nn.ReLU(),
    nn.Linear(in_features=8, out_features=3),
)

In [None]:
# Output folders (relative to the working directory) for each model's frames / timelapse video.
save_moons, save_circles, save_blobs, save_spirals = (
    Path(folder_name) for folder_name in ("moons", "circles", "blobs", "spirals")
)

In [None]:
# Create each output folder; already-existing folders are fine.
for save_dir in (save_moons, save_circles, save_blobs, save_spirals):
    save_dir.mkdir(exist_ok=True)

In [None]:
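# Load previously saved model weights (uncomment to use).
# The paths must match the ones written by the save cell at the end of the notebook.
# try:
#     moons_model.load_state_dict(torch.load("weights/moons_model.pth"))
#     circles_model.load_state_dict(torch.load("weights/circles_model.pth"))
#     blobs_model.load_state_dict(torch.load("weights/blobs_model.pth"))
#     spirals_model.load_state_dict(torch.load("weights/spirals_model.pth"))
# except FileNotFoundError:
#     print("Unable to load models!")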

In [None]:
# Training hyper-parameters shared by all four models.
epochs, lr = 100, 0.01  # lr = learning rate (optimizer step size)

In [None]:
print("---------------------Training moon model---------------------")
# Binary task -> BCEWithLogitsLoss on the single output logit.
train_model(
    moons_model,
    X_moons,
    y_moons,
    lr=lr,
    epochs=epochs,
    save_folder=save_moons,
    loss_class=nn.BCEWithLogitsLoss,
    convert_to_video=True,
)

In [None]:
print("---------------------Training circle model---------------------")
# Binary task -> BCEWithLogitsLoss on the single output logit.
train_model(
    circles_model,
    X_circles,
    y_circles,
    lr=lr,
    epochs=epochs,
    save_folder=save_circles,
    loss_class=nn.BCEWithLogitsLoss,
    convert_to_video=True,
)

In [None]:
print("---------------------Training blob model---------------------")
# Multi-class task -> CrossEntropyLoss over the per-class logits.
train_model(
    blobs_model,
    X_blobs,
    y_blobs,
    lr=lr,
    epochs=epochs,
    save_folder=save_blobs,
    loss_class=nn.CrossEntropyLoss,
    convert_to_video=True,
)

In [None]:
print("---------------------Training spiral model---------------------")
# Multi-class task -> CrossEntropyLoss over the per-class logits.
train_model(
    spirals_model,
    X_spirals,
    y_spirals,
    lr=lr,
    epochs=epochs,
    save_folder=save_spirals,
    loss_class=nn.CrossEntropyLoss,
    convert_to_video=True,
)

In [None]:
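# Save model weights for later reuse (uncomment to use).
# Writes into the weights/ folder, which is where the load cell near the top reads from.
# Path("weights").mkdir(exist_ok=True)
# torch.save(moons_model.state_dict(), "weights/moons_model.pth")
# torch.save(circles_model.state_dict(), "weights/circles_model.pth")
# torch.save(blobs_model.state_dict(), "weights/blobs_model.pth")
# torch.save(spirals_model.state_dict(), "weights/spirals_model.pth")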