# Implementation of SVM for Classification

## From ML Algorithms to GenAI & LLMs by Aman Kharwal

Here is an implementation of Support Vector Machines (SVM) in Python for solving a classification problem on the popular Iris dataset.

In [7]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
import plotly.graph_objects as go

# Load Iris and keep only the first two feature columns (sepal length and
# sepal width) so the fitted model can be drawn on a 2-D plane later.
iris = load_iris()
X, y = iris.data[:, :2], iris.target

# Hold out 20% of the samples for evaluation; the fixed seed keeps the split
# (and therefore the reported accuracy) reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Fit a linear-kernel SVM classifier on the training split.
# fit() returns the estimator itself, so construction and training chain.
svm = SVC(kernel='linear').fit(X_train, y_train)

# Mean accuracy on the held-out test split.
accuracy = svm.score(X_test, y_test)
print("Accuracy:", accuracy)

Accuracy: 0.9


Now here is how to visualize the decision boundary of the SVM model we trained above.

In [None]:
import plotly.io as pio
# NOTE(review): the "browser" renderer opens the figure in a separate browser
# tab instead of embedding it in the notebook output, so the plot is not saved
# with the notebook. Remove this line to render the figure inline.
pio.renderers.default = "browser"

# Build a grid that covers the data with a margin of 1 on every side.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
h = 0.02  # Step size for the meshgrid

# Keep the 1-D grid axes: the meshgrid is built from them and the contour
# trace reuses them, so both are guaranteed to have matching lengths.
x_grid = np.arange(x_min, x_max, h)
y_grid = np.arange(y_min, y_max, h)
xx, yy = np.meshgrid(x_grid, y_grid)

# Predict the class label of every grid point; areas of constant label are
# the classifier's decision regions.
Z = svm.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# Use one colorscale and one fixed value range (first to last class index)
# for both traces, so a given class gets the same colour as a region and as
# a point.
colorscale = 'RdYlBu'
class_min, class_max = 0, len(iris.target_names) - 1

fig = go.Figure()
# Add the decision regions FIRST so the data points are drawn on top of them;
# added afterwards, the 0.8-opacity contour would cover the points.
# coloring='heatmap' colours each grid cell by its exact class value, so the
# region colours line up with the marker colours.
fig.add_trace(go.Contour(x=x_grid, y=y_grid, z=Z,
                         colorscale=colorscale, zmin=class_min, zmax=class_max,
                         contours=dict(coloring='heatmap'),
                         showscale=False, opacity=0.8))
# Outline the markers: with a shared colorscale, a point would otherwise blend
# into the region of its own class.
fig.add_trace(go.Scatter(x=X[:, 0], y=X[:, 1], mode='markers',
                         marker=dict(color=y, colorscale=colorscale,
                                     cmin=class_min, cmax=class_max, size=8,
                                     line=dict(width=1, color='black')),
                         text=[iris.target_names[i] for i in y]))
fig.update_layout(title='SVM Decision Boundary (Iris Dataset)',
                  xaxis_title='Sepal Length (cm)',
                  yaxis_title='Sepal Width (cm)', showlegend=False)
fig.show()
