In [8]:
import matplotlib.pyplot as plt
import torch
from torch import nn
import numpy as np
from putting_it_all_together import BlobClassifier, train
from sklearn.datasets import make_blobs
from sklearn.model_selection import  train_test_split

In [9]:
def plot_decision_boundary(model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor,
                           resolution: int = 101, margin: float = 0.1):
    """Plot the decision regions of ``model`` over the 2-D feature space of ``X``.

    A ``resolution x resolution`` grid spanning the extent of ``X`` (padded by
    ``margin`` on every side) is classified by ``model`` and drawn as filled
    contours; the points of ``X`` are scattered on top, coloured by ``y``.

    Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)

    Args:
        model: classifier mapping an ``(N, 2)`` float tensor to logits. Outputs with
            a single logit per row (shape ``(N,)`` or ``(N, 1)``) are treated as
            binary (sigmoid, then rounded at 0.5); outputs with more than one logit
            column are treated as multi-class (argmax over the last dimension).
        X: feature tensor indexed as ``X[:, 0]`` / ``X[:, 1]``; the grid fed to the
            model has exactly two columns, so the model must accept 2 features.
        y: labels for ``X``; used only to colour the scatter points.
        resolution: number of grid points along each axis (default 101).
        margin: padding added around the data extent on every side (default 0.1).

    Side effects:
        Moves ``model`` to the CPU, leaves it in eval mode, and draws on the
        current matplotlib axes.
    """
    # Put everything on the CPU (NumPy + Matplotlib cannot read accelerator tensors)
    model.to("cpu")
    X, y = X.to("cpu"), y.to("cpu")

    # Grid bounds as plain Python floats so NumPy never has to coerce 0-d tensors
    x_min, x_max = X[:, 0].min().item() - margin, X[:, 0].max().item() + margin
    y_min, y_max = X[:, 1].min().item() - margin, X[:, 1].max().item() + margin
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),
                         np.linspace(y_min, y_max, resolution))

    # One row per grid point, (x, y) coordinates, float32 to match typical model weights
    X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()

    # Make predictions without tracking gradients
    model.eval()
    with torch.inference_mode():
        y_logits = model(X_to_pred_on)

    # Decide binary vs multi-class from the model's output width, NOT from the labels
    # in ``y``: a subset of a multi-class dataset may contain only two classes, and a
    # two-logit CrossEntropy model is multi-class in form. Using the label count here
    # would send (N, C) logits through sigmoid and break the reshape below.
    if y_logits.ndim > 1 and y_logits.shape[-1] > 1:
        # argmax(softmax(z)) == argmax(z), so the softmax is unnecessary
        y_pred = y_logits.argmax(dim=-1)  # multi-class
    else:
        y_pred = torch.round(torch.sigmoid(y_logits))  # binary

    # Reshape preds back onto the grid and plot
    y_pred = y_pred.reshape(xx.shape).detach().numpy()
    plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
    plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())

In [10]:
# Init for testing Decision Boundary: synthetic blob data, device placement, model,
# loss and optimizer. Every size/seed below comes from the constants so that data
# and model cannot drift apart when one of them is edited.

# Dataset / reproducibility constants
NUM_CLASSES = 4
NUM_FEATURES = 2
RANDOM_SEED = 42

# Seed torch (model weight init) with the same seed used for the data
torch.manual_seed(RANDOM_SEED)

# 1. Create multi-class data: NUM_CLASSES Gaussian blobs in NUM_FEATURES dimensions
X, y = make_blobs(n_samples=1000,
                  n_features=NUM_FEATURES,
                  centers=NUM_CLASSES,
                  cluster_std=1.5,
                  random_state=RANDOM_SEED)

# Features as float32 for the model; labels as int64 class indices for CrossEntropyLoss
X = torch.from_numpy(X).float()
y = torch.from_numpy(y).long()

# 2. 80/20 train/test split
X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    test_size=0.2,
                                                    random_state=RANDOM_SEED)

# 3. Device: Apple-silicon GPU (MPS) when available, otherwise CPU
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

X_train, X_test = X_train.to(device), X_test.to(device)
y_train, y_test = y_train.to(device), y_test.to(device)

# 4. Model sized from the dataset constants instead of hard-coded 2 / 4
model = BlobClassifier(input_features=NUM_FEATURES,
                       output_features=NUM_CLASSES,
                       hidden_units=8).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),
                            lr=1e-1)

In [11]:
# Training loop. `train` comes from putting_it_all_together; from its use here it
# presumably runs one optimisation pass over (X_train, y_train) and returns
# (loss, accuracy) — the accuracy is printed as a percentage (TODO confirm in train()).
epochs = 100
for epoch in range(epochs):
    train_loss, train_acc = train(
        model=model,
        X_train=X_train,
        y_train=y_train,
        loss_fn=loss_fn,
        optimizer=optimizer,
    )

    # Report only every 10th epoch to keep the cell output short
    if epoch % 10 != 0:
        continue
    print(f'Epoch: {epoch}\t|\tTrain Loss: {train_loss:.4f}\tTrain Acc: {train_acc:.2f}%')

Epoch: 0	|	Train Loss: 1.1588	Train Acc: 40.38%
Epoch: 10	|	Train Loss: 0.6448	Train Acc: 96.75%
Epoch: 20	|	Train Loss: 0.4254	Train Acc: 98.50%
Epoch: 30	|	Train Loss: 0.2529	Train Acc: 99.12%
Epoch: 40	|	Train Loss: 0.1123	Train Acc: 99.25%
Epoch: 50	|	Train Loss: 0.0663	Train Acc: 99.25%
Epoch: 60	|	Train Loss: 0.0507	Train Acc: 99.25%
Epoch: 70	|	Train Loss: 0.0430	Train Acc: 99.25%
Epoch: 80	|	Train Loss: 0.0384	Train Acc: 99.25%
Epoch: 90	|	Train Loss: 0.0352	Train Acc: 99.25%


In [12]:
model.to("cpu")
X_train, y_train = X.to("cpu"), y.to("cpu")

In [47]:
# Step through plot_decision_boundary's grid construction by hand:
# pad the data extent by 0.1 on each side, then build a 101 x 101 mesh.
x_min = X[:, 0].min() - 0.1
x_max = X[:, 0].max() + 0.1
y_min = X[:, 1].min() - 0.1
y_max = X[:, 1].max() + 0.1
xx, yy = np.meshgrid(
    np.linspace(x_min, x_max, 101),
    np.linspace(y_min, y_max, 101),
)

In [48]:
# Padded bounds (0-d tensors) and the shapes of the two meshgrid coordinate arrays
(x_min, y_min, x_max, y_max, xx.shape, yy.shape)

(tensor(-12.9932),
 tensor(-11.3245),
 tensor(8.5999),
 tensor(14.8934),
 (101, 101),
 (101, 101))

In [49]:
# Pair the meshgrid arrays element-wise into one (x, y) row per grid point;
# stacking on a new last axis and flattening gives the same row order as ravel().
X_to_pred_on = torch.from_numpy(np.stack((xx, yy), axis=-1).reshape(-1, 2)).float()
X_to_pred_on.shape

torch.Size([10201, 2])

In [53]:
# Small 10 x 10 toy grid (values 1..10 on both axes) to illustrate how
# np.meshgrid, ravel and column_stack enumerate grid points.
xxx, yyy = np.meshgrid(np.arange(1.0, 11.0), np.arange(1.0, 11.0))

In [55]:
# Flatten both coordinate arrays row by row: x cycles 1..10, y repeats each value 10 times
xxx.flatten(), yyy.flatten()

(array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.,  1.,  2.,  3.,
         4.,  5.,  6.,  7.,  8.,  9., 10.,  1.,  2.,  3.,  4.,  5.,  6.,
         7.,  8.,  9., 10.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,
        10.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.,  1.,  2.,
         3.,  4.,  5.,  6.,  7.,  8.,  9., 10.,  1.,  2.,  3.,  4.,  5.,
         6.,  7.,  8.,  9., 10.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,
         9., 10.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.,  1.,
         2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]),
 array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  2.,  2.,  2.,
         2.,  2.,  2.,  2.,  2.,  2.,  2.,  3.,  3.,  3.,  3.,  3.,  3.,
         3.,  3.,  3.,  3.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.,
         4.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  6.,  6.,
         6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  7.,  7.,  7.,  7.,  7.,
         7.,  7.,  7.,  7.,  7.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  

In [57]:
np.stack((xxx.ravel(), yyy.ravel()), axis=1)  # every (x, y) coordinate of the grid, one per row

array([[ 1.,  1.],
       [ 2.,  1.],
       [ 3.,  1.],
       [ 4.,  1.],
       [ 5.,  1.],
       [ 6.,  1.],
       [ 7.,  1.],
       [ 8.,  1.],
       [ 9.,  1.],
       [10.,  1.],
       [ 1.,  2.],
       [ 2.,  2.],
       [ 3.,  2.],
       [ 4.,  2.],
       [ 5.,  2.],
       [ 6.,  2.],
       [ 7.,  2.],
       [ 8.,  2.],
       [ 9.,  2.],
       [10.,  2.],
       [ 1.,  3.],
       [ 2.,  3.],
       [ 3.,  3.],
       [ 4.,  3.],
       [ 5.,  3.],
       [ 6.,  3.],
       [ 7.,  3.],
       [ 8.,  3.],
       [ 9.,  3.],
       [10.,  3.],
       [ 1.,  4.],
       [ 2.,  4.],
       [ 3.,  4.],
       [ 4.,  4.],
       [ 5.,  4.],
       [ 6.,  4.],
       [ 7.,  4.],
       [ 8.,  4.],
       [ 9.,  4.],
       [10.,  4.],
       [ 1.,  5.],
       [ 2.,  5.],
       [ 3.,  5.],
       [ 4.,  5.],
       [ 5.,  5.],
       [ 6.,  5.],
       [ 7.,  5.],
       [ 8.,  5.],
       [ 9.,  5.],
       [10.,  5.],
       [ 1.,  6.],
       [ 2.,  6.],
       [ 3.,