<a href="https://colab.research.google.com/github/HBShim03/HBShim03.github.io/blob/main/MulticlassClassificationModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multiclass Classification

In [None]:
# Importing Plotting Helper Functions

import requests
from pathlib import Path
# Fetch helper_functions.py once and reuse the local copy on later runs.
helper_path = Path("helper_functions.py")
if helper_path.is_file():
  print("helper_functions.py already exists, skipping download")
else:
  print("Downloading helper_functions.py")
  response = requests.get(
      "https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py",
      timeout=30,  # do not hang indefinitely if the host is unreachable
  )
  # Fail loudly on an HTTP error instead of saving the error page as a Python
  # module (which would then be "already present" on every later run).
  response.raise_for_status()
  helper_path.write_bytes(response.content)
from helper_functions import plot_predictions, plot_decision_boundary


# Defining Accuracy Function
def accuracy_fn(y_true, y_pred):
  """Return the percentage of predictions that equal their true label.

  Args:
    y_true: Tensor of ground-truth labels.
    y_pred: Tensor of predicted labels, compared element-wise with `y_true`.

  Returns:
    Accuracy as a float percentage: matching elements divided by
    `len(y_pred)`, times 100.
  """
  matches = torch.eq(y_true, y_pred)   # element-wise boolean mask
  n_correct = matches.sum().item()     # number of matching positions as a Python int
  return (n_correct / len(y_pred)) * 100


# Create the dataset and split it into training and test sets

import torch
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from torch import nn

# Dataset configuration, used by make_blobs and the train/test split below.
NUM_CLASSES = 4     # number of blob centers (one per class)
NUM_FEATURES = 2    # number of features per sample
RANDOM_SEED = 42    # shared seed for data generation and the split

# Generate 1000 labelled points clustered around NUM_CLASSES centers.
X_blob, y_blob = make_blobs(n_samples = 1000,
                            n_features = NUM_FEATURES,
                            centers = NUM_CLASSES,
                            cluster_std = 1.5,
                            random_state = RANDOM_SEED)

# Convert the NumPy arrays to tensors: float features and LongTensor labels.
# Both conversions are wrapped in one parenthesised tuple so the assignment
# unpacks two values; splitting the statement after a trailing comma would
# produce a 1-tuple and raise ValueError on unpacking.
X_blob, y_blob = (
    torch.from_numpy(X_blob).type(torch.float),
    torch.from_numpy(y_blob).type(torch.LongTensor),
)

# NOTE(review): test_size = 0.8 holds out 80% of the samples for testing and
# trains on only 20%. Confirm this is intended; 0.2 is the more common choice.
X_blob_train, X_blob_test, y_blob_train, y_blob_test = train_test_split(X_blob,
                                                                        y_blob,
                                                                        test_size = 0.8,
                                                                        random_state=RANDOM_SEED)

# Scatter plot of the full dataset: first feature on x, second on y, coloured by label.
plt.figure(figsize=(10,7))
plt.scatter(X_blob[:, 0], X_blob[:,1], c=y_blob, cmap = plt.cm.RdYlBu)


# Model Creation
class MulticlassModel(nn.Module):
  """Classifier built from three stacked fully connected layers.

  No activation function is applied between the layers, so the stack as a
  whole computes an affine function of its input.

  Args:
    in_feature: Number of input features per sample.
    out_feature: Number of output values per sample (one logit per class).
    hidden_units: Width of the two intermediate layers. Defaults to 8.
  """

  def __init__(self, in_feature, out_feature, hidden_units = 8):
    super().__init__()
    # Layers are created in input -> output order and wrapped in a single
    # Sequential container exposed as `linear_layer_stack`.
    input_layer = nn.Linear(in_feature, hidden_units)
    hidden_layer = nn.Linear(hidden_units, hidden_units)
    output_layer = nn.Linear(hidden_units, out_feature)
    self.linear_layer_stack = nn.Sequential(input_layer, hidden_layer, output_layer)

  def forward(self, x):
    """Pass `x` through the layer stack and return the raw outputs (logits)."""
    return self.linear_layer_stack(x)

# Size the model from the dataset constants so its input and output widths
# stay in sync with the generated data if those constants change.
MCModel = MulticlassModel(in_feature = NUM_FEATURES, out_feature = NUM_CLASSES)
# Bare expression: in a notebook cell this displays the model's layer summary.
MCModel

# Result before training
# Evaluation mode plus inference_mode: forward pass on the test split
# without tracking gradients.
MCModel.eval()
with torch.inference_mode():
  y_untrained_logits = MCModel(X_blob_test)
  # softmax over the class dimension, then pick the index of the largest value
  y_untrained_preds = torch.argmax(y_untrained_logits.softmax(dim=1), dim=1)

# Plot the untrained model against the test split using the helper imported
# from helper_functions.py.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Before Training")
plot_decision_boundary(MCModel, X_blob_test, y_blob_test)


# Training
# Cross-entropy loss on the raw logits, optimised with SGD.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(MCModel.parameters(), lr=0.1)
epochs = 1000

for epoch in range(epochs):
  # --- Training step on the whole training split ---
  MCModel.train()
  y_logits = MCModel(X_blob_train)
  loss = loss_fn(y_logits, y_blob_train)

  # Predicted class = index of the largest softmax probability per sample.
  y_pred = y_logits.softmax(dim=1).argmax(dim=1)
  train_acc = accuracy_fn(y_blob_train, y_pred)

  # Clear old gradients, backpropagate, then update the parameters.
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  # --- Evaluation on the test split (no gradient tracking) ---
  MCModel.eval()
  with torch.inference_mode():
    test_logits = MCModel(X_blob_test)
    test_loss = loss_fn(test_logits, y_blob_test)
    test_pred = test_logits.softmax(dim=1).argmax(dim=1)
    test_acc = accuracy_fn(y_blob_test, test_pred)

  # Report progress every 100 epochs.
  if epoch % 100 == 0:
    print(f"Epoch: {epoch} | Loss: {loss:.4f} | Training Accuracy: {train_acc:.2f}% | Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.2f}%")

# Result After Training
# Evaluation mode plus inference_mode: forward pass of the trained model on
# the test split without tracking gradients.
MCModel.eval()
with torch.inference_mode():
  y_logits = MCModel(X_blob_test)
  # softmax over the class dimension, then pick the index of the largest value
  y_preds = torch.argmax(y_logits.softmax(dim=1), dim=1)

# Plot the trained model against the test split using the helper imported
# from helper_functions.py.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("After Training")
plot_decision_boundary(MCModel, X_blob_test, y_blob_test)