In [None]:
import sys
import random
from mysvm import SVM
import numpy as np
from sklearn.linear_model import SGDClassifier as SGD
from pyspark import SparkContext,SparkConf
from sklearn.base import copy

# Spark with SGDClassifier

In [None]:
def fit_in(x, svm):
    """Train `svm` incrementally on one partition and yield it once.

    Args:
        x: iterator over the rows of a partition. Each row is indexed as a
            1-D array whose last element is the label and whose leading
            elements are the features.
        svm: estimator exposing `partial_fit(X, y, classes=...)`.

    Yields:
        The same `svm` object, after it has seen every row of `x`
        (or untouched if `x` was empty).
    """
    for row in x:
        n_features = row.shape[0] - 1
        # One sample per call: features as a (1, n_features) matrix, label as a 1-element vector.
        features = row[:n_features].reshape(1, n_features)
        label = np.array([row[-1]])
        # The full label set is passed on every call because a single row
        # only ever shows one of the two classes.
        svm.partial_fit(features, label, classes=[-1, 1])
    yield svm
def update_in(m1, m2):
    """Return a new model whose `coef_` and `intercept_` are the sums of those of `m1` and `m2`.

    `m1` is deep-copied first, so neither input is modified.
    """
    merged = copy.deepcopy(m1)
    for attr in ("coef_", "intercept_"):
        # Spelled-out equivalent of `merged.<attr> += m2.<attr>`.
        total = getattr(merged, attr)
        total += getattr(m2, attr)
        setattr(merged, attr, total)
    return merged

def avg_coefs_in(svm, numpart):
    """Divide `svm.coef_` and `svm.intercept_` by `numpart` in place and return `svm`.

    Used to turn a sum of per-partition models into their average.
    """
    coef, intercept = svm.coef_, svm.intercept_
    coef /= numpart
    svm.coef_ = coef
    intercept /= numpart
    svm.intercept_ = intercept
    return svm

In [None]:
# Spark configuration: application name plus a local master with 4 worker threads.
conf = (
    SparkConf()
    .setAppName("SVM-SGD")
    .setMaster('local[4]')
)

In [None]:
# Use getOrCreate instead of the bare constructor so this cell is safe to
# re-run: constructing a second SparkContext while one is active raises an
# error. If a context already exists it is reused and `conf` is not applied.
sc = SparkContext.getOrCreate(conf=conf)

In [None]:
# Seed NumPy's global RNG so the synthetic data below (and every later
# np.random.shuffle call that draws from the same global state) is
# reproducible across Restart & Run All.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)

# Class -1: 20 points drawn uniformly from [-3, 2) x [-3, 2).
X1 = 5 * np.random.random((20, 2)) - 3
y1 = -1 * np.ones(X1.shape[0])

# Class +1: 20 points drawn uniformly from [7, 13) x [7, 13).
X2 = 6 * np.random.random((20, 2)) + 7
y2 = np.ones(X2.shape[0])

# Stack both classes, then append the label as the last column so that each
# row carries its own target: X ends up with shape (40, 3) = [x0, x1, label].
y = np.hstack((y1, y2))
X = np.vstack((X1, X2))
y = y.reshape(-1, 1)
X = np.hstack((X, y))

In [None]:
# scikit-learn SGDClassifier with a constant step size (eta0) and
# regularization strength alpha; all other settings are left at the library
# defaults (per the scikit-learn docs the default loss is hinge, i.e. a
# linear SVM — confirm for the installed version).
svm = SGD(
    alpha=1,
    learning_rate='constant',
    eta0=0.01,
)

In [None]:
# Parameter-averaging SGD: in every round each partition trains its own copy
# of the current model, the per-partition models are summed with update_in,
# and the sum is divided by the number of partitions.
for i in range(300):
    dat = sc.parallelize(X)
    svm = avg_coefs_in(
        dat.mapPartitions(lambda part: fit_in(part, svm)).reduce(update_in),
        dat.getNumPartitions(),
    )
    # Reorder the rows in place so the next round partitions them differently.
    np.random.shuffle(X)
print(svm.coef_)
print(svm.intercept_)

## Spark with my implementation of SVM using SGD

In [None]:
def fit(x, svm):
    """Train `svm` incrementally on one partition and yield it once.

    Args:
        x: iterator over the rows of a partition. Each row is indexed as a
            1-D array whose last element is the label and whose leading
            elements are the features.
        svm: model exposing `partial_fit(X, y)`.

    Yields:
        The same `svm` object, after it has seen every row of `x`
        (or untouched if `x` was empty).
    """
    for sample in x:
        width = sample.shape[0] - 1
        # One sample per call: a (1, width) feature matrix and a 1-element label vector.
        svm.partial_fit(
            sample[:-1].reshape(1, width),
            np.array([sample[-1]]),
        )
    yield svm
def update(m1, m2):
    """Return a new model whose `weights` and `intercept` are the sums of those of `m1` and `m2`.

    `m1` is deep-copied first, so neither input is modified.
    """
    merged = copy.deepcopy(m1)
    weights, intercept = merged.weights, merged.intercept
    weights += m2.weights
    merged.weights = weights
    intercept += m2.intercept
    merged.intercept = intercept
    return merged

def avg_coefs(svm, numpart):
    """Divide `svm.weights` and `svm.intercept` by `numpart` and return `svm`.

    Used to turn a sum of per-partition models into their average.
    """
    for field in ("weights", "intercept"):
        # Spelled-out equivalent of `svm.<field> /= numpart`.
        value = getattr(svm, field)
        value /= numpart
        setattr(svm, field, value)
    return svm

In [None]:
# Hyperparameters for the hand-written SVM imported from `mysvm`.
# NOTE(review): the meaning of each positional argument is defined by
# mysvm.SVM, which is not visible here — presumably step size, iteration
# cap and regularization strength; confirm against that module.
alpha = 0.01
max_iter = 1000
l = 1
sv = SVM(alpha, max_iter, l)

In [None]:
# Same parameter-averaging scheme as for the SGDClassifier run, applied to the
# custom SVM: train one copy per partition, sum the copies with update, then
# divide by the number of partitions.
for i in range(100):
    dat = sc.parallelize(X)
    sv = avg_coefs(
        dat.mapPartitions(lambda part: fit(part, sv)).reduce(update),
        dat.getNumPartitions(),
    )
    # Reorder the rows in place so the next round partitions them differently.
    np.random.shuffle(X)
print(sv.weights)
print(sv.intercept)