In [25]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pdb
import torch

In [27]:
def split_data(X, y, random_state=1, n_train=15):
    """Randomly split features/labels into a training sample and the remaining test rows.

    Rows are selected by index label. The labels for each split are looked up
    with the same labels as the features, which keeps them row-aligned even
    when ``X`` is not sorted by index.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix. Its index must be unique.
    y : pandas.Series
        Labels, indexed like ``X``.
    random_state : int, optional
        Seed passed to ``DataFrame.sample`` (default 1).
    n_train : int, optional
        Number of training rows to draw (default 15).

    Returns
    -------
    X_train : pandas.DataFrame
        Sampled rows (a copy), sorted by index.
    y_train : pandas.Series
        Labels for ``X_train``, in the same row order.
    X_test : pandas.DataFrame
        All remaining rows, in ``X``'s original order.
    y_test : pandas.Series
        Labels for ``X_test``, in the same row order.

    Raises
    ------
    ValueError
        If ``X`` has duplicate index labels, or if ``n_train`` exceeds
        ``len(X)`` (the latter raised by ``DataFrame.sample``).
    """
    if not X.index.is_unique:
        raise ValueError("split_data requires X to have a unique index")

    # Select by index label rather than by comparing values: value matching
    # (DataFrame.isin with a DataFrame argument) treats NaN as unequal and
    # would push NaN-containing training rows into the test set.
    train_idx = X.sample(n_train, random_state=random_state).index.sort_values()
    X_train = X.loc[train_idx].copy()
    y_train = y.loc[train_idx]

    X_test = X.loc[~X.index.isin(train_idx)]
    y_test = y.loc[X_test.index]

    return X_train, y_train, X_test, y_test

In [40]:
def plot_boundary(model, X, y, x1_min=-20, x1_max=20, xlim=(-5, 5), ylim=(-5, 5)):
    """Scatter a two-feature binary dataset and overlay the model's linear decision boundary.

    The boundary is the line ``w1*x1 + w2*x2 + b = 0`` built from
    ``model.weights[0]``, ``model.weights[1]`` and ``model.bias``.

    Parameters
    ----------
    model : object
        Must expose ``weights`` (indexable, at least two entries) and ``bias``.
    X : pandas.DataFrame or array-like
        A DataFrame with columns ``x1`` and ``x2``, or a 2-D array / CPU tensor
        whose first two columns are the features.
    y : array-like
        Class labels (0 or 1), one per row of ``X``. A Series is aligned to a
        DataFrame ``X`` by index.
    x1_min, x1_max : float
        x1 range over which the boundary segment is computed. It is wider than
        the default view so the line crosses the whole plot.
    xlim, ylim : tuple of float
        Axis view limits.
    """
    w1, w2 = float(model.weights[0]), float(model.weights[1])
    b = float(model.bias)

    if isinstance(X, pd.DataFrame):
        features = X[['x1', 'x2']].to_numpy()
        # Keep pandas' boolean-mask semantics: align a label Series to X by index.
        if isinstance(y, pd.Series):
            y = y.reindex(X.index)
    else:
        features = np.asarray(X)[:, :2]
    labels = np.asarray(y).ravel()

    fig, ax = plt.subplots()
    ax.scatter(features[labels == 0, 0], features[labels == 0, 1], c='blue', label='class 0')
    ax.scatter(features[labels == 1, 0], features[labels == 1, 1], c='red', label='class 1')

    if w2 != 0:
        x2_min = (-(w1 * x1_min) - b) / w2
        x2_max = (-(w1 * x1_max) - b) / w2
        ax.plot([x1_min, x1_max], [x2_min, x2_max], c='black', linestyle='--')
    elif w1 != 0:
        # w2 == 0: the boundary is the vertical line x1 = -b / w1.
        ax.axvline(-b / w1, c='black', linestyle='--')
    # If both weights are zero there is no boundary line to draw.

    ax.set(xlim=xlim, ylim=ylim, xlabel='x1', ylabel='x2', title='Perceptron decision boundary')
    ax.legend()
    plt.show()

In [26]:
# Tab-separated toy dataset from the Lightning-AI dl-fundamentals repository
# (unit01-ml-intro, section 1.6), fetched over the network.
url = (
    'https://raw.githubusercontent.com/Lightning-AI/dl-fundamentals/main/'
    'unit01-ml-intro/1.6-perceptron-in-python/perceptron_toydata-truncated.txt'
)
df = pd.read_csv(url, delimiter='\t')

In [37]:
# Split on the pandas side (columns 0-1 are the features, column 2 is the label),
# then convert every piece to a float32 tensor in a single step.
X_train, y_train, X_test, y_test = (
    torch.tensor(part.to_numpy(), dtype=torch.float32)
    for part in split_data(df.iloc[:, 0:2], df.iloc[:, 2])
)

# Sanity checks: train/test partition the data set and labels match features.
assert len(X_train) + len(X_test) == len(df)
assert len(y_train) == len(X_train) and len(y_test) == len(X_test)

In [43]:
class Perceptron:
    """Single-layer perceptron with a step activation, trained one example at a time.

    Parameters
    ----------
    num_features : int
        Length of the input feature vectors.
    random_state : int or None, optional
        Seed for the standard-normal weight initialisation. ``None`` (default)
        draws from NumPy's global RNG, so ``np.random.seed`` still controls it.
    """

    def __init__(self, num_features, random_state=None):
        self.num_features = num_features
        rng = np.random if random_state is None else np.random.RandomState(random_state)
        # NumPy draws float64. Cast to torch's default float dtype (the same
        # dtype as `bias`) because torch.dot refuses mismatched dtypes and the
        # data fed to forward() are float32.
        self.weights = torch.from_numpy(rng.standard_normal(num_features)).to(torch.get_default_dtype())
        self.bias = torch.tensor(0.)

    def forward(self, x):
        """Return 1. if ``x . weights + bias > 0`` else 0., as a 0-d tensor in the weights' dtype.

        ``x`` may be a tensor, array or sequence of length ``num_features``;
        it is converted to the weights' dtype first.
        """
        x = torch.as_tensor(x, dtype=self.weights.dtype)
        weighted_sum_z = torch.dot(x, self.weights) + self.bias
        # Both branches share one float dtype so downstream arithmetic is uniform.
        return torch.tensor(1. if weighted_sum_z > 0. else 0., dtype=self.weights.dtype)

    def update(self, x, true_y):
        """Apply one perceptron learning step on example ``x`` with target ``true_y``.

        Updates ``bias`` and ``weights`` in place by ``error`` and ``error * x``.

        Returns
        -------
        torch.Tensor
            ``true_y - prediction``: 0 when the example was classified
            correctly, otherwise +/-1 for 0/1 targets.
        """
        x = torch.as_tensor(x, dtype=self.weights.dtype)
        prediction = self.forward(x)
        error = torch.as_tensor(true_y, dtype=self.weights.dtype) - prediction

        self.bias += error
        self.weights += error * x

        return error