In [123]:
import numpy as np
import multiprocessing as mp

In [124]:
#load data
class DataSet:
    def __init__(self, fn):
        file = open(fn)
        data = [[float(mem) for mem in line[:-2].split(' ')] for line in file]
        self.data = np.array(data)
        self.DataCnt = self.data.shape[0]
        self.FeatCnt = self.data.shape[1]-1
        
    def GetItem(self, idx):
        assert(idx < self.DataCnt)
        label = self.data[idx][-1]
        feat = self.data[idx][:-1]
        return feat, label
    
    def GetFeat(self, idx): #ith feat
        assert(idx < self.FeatCnt)
        FeatList = self.data[:,idx]
        LabelList = self.data[:,-1]
        return FeatList, LabelList
    
    def GetFeatFull(self):
        return [self.GetFeat(idx) for idx in range(self.FeatCnt)]
    

In [125]:
#load data
train, test = DataSet('test.dat'), DataSet('train.dat')

In [126]:
#question 14
def sign(x):
    return 2*(0 < x)-1

def h(x, s, theta):
    return s*sign(x-theta)

class DecisionStumpModel:
    def __init__(self, FeatList, LabelList):
        self.FeatList = list(FeatList)
        self.LabelList = list(LabelList)
    
    def Ein(self, s, theta):
        ErrorList = [label != h(feat, s, theta) for feat, label in zip(self.FeatList, self.LabelList)]
        return np.average(ErrorList)
    
    def GetOptParams(self):
        xSet = sorted(list(FeatList))
        ThetaDomain = [-1] + [0.5*(x1+x2) for x1, x2 in zip(xSet, xSet[1:]) if x1 != x2]
        EinParamList = [(self.Ein(s, theta), s, theta) for s in [-1,1] for theta in ThetaDomain]
        print(sorted(EinParamList))
        return min(EinParamList)

In [121]:
#question 14
EinMin = 1
for idx, (FeatList, LabelList) in enumerate(train.GetFeatFull()):
    model = DecisionStumpModel(FeatList, LabelList)
    Ein, _, _ = model.GetOptParams()
    EinMin = min(EinMin, Ein)
    print('Progess: {}/{}'.format(idx+1, train.FeatCnt))
    
print('EinMin =', EinMin)

[(0.441, 1, 1.49701187655), (0.442, 1, 1.41858520225), (0.442, 1, 1.4419027785999998), (0.442, 1, 1.46725718035), (0.442, 1, 1.4920951917499998), (0.442, 1, 1.4978745339500001), (0.443, 1, 1.40453548205), (0.443, 1, 1.42476540915), (0.443, 1, 1.45973620765), (0.443, 1, 1.4790599312500001), (0.443, 1, 1.4993320997000001), (0.443, 1, 1.5215643833499999), (0.443, 1, 1.54653934945), (0.443, 1, 1.6011950334999998), (0.443, 1, 1.6302244706), (0.444, 1, 1.39288779435), (0.444, 1, 1.50575762295), (0.444, 1, 1.53894870835), (0.444, 1, 1.56058375805), (0.444, 1, 1.580320211), (0.444, 1, 1.59311136495), (0.444, 1, 1.6121072856), (0.444, 1, 1.64333383335), (0.444, 1, 1.6714539955), (0.444, 1, 1.68831421905), (0.445, -1, -1.6901236159000002), (0.445, 1, 1.3391693677), (0.445, 1, 1.37232337445), (0.445, 1, 1.38475886655), (0.445, 1, 1.5764053147000001), (0.445, 1, 1.5840238437999998), (0.445, 1, 1.65859452265), (0.445, 1, 1.68089035475), (0.445, 1, 1.6916285594), (0.446, -1, -1.69282603565), (0.446,

[(0.433, 1, -0.3336769323), (0.433, 1, -0.3202455764), (0.433, 1, -0.30465945409999995), (0.434, 1, -0.33596360999999997), (0.434, 1, -0.328767385), (0.434, 1, -0.3243964867), (0.434, 1, -0.3143829054), (0.434, 1, -0.31111313549999997), (0.434, 1, -0.29785377324999995), (0.434, 1, -0.29290491265), (0.435, 1, -0.34017299255), (0.435, 1, -0.3253236377), (0.435, 1, -0.3120240335), (0.435, 1, -0.29605260904999997), (0.435, 1, -0.29003198025), (0.435, 1, -0.27991073735), (0.436, 1, -1.46285733795), (0.436, 1, -1.30554384645), (0.436, 1, -0.51282899635), (0.436, 1, -0.34577194675), (0.436, 1, -0.28602170039999997), (0.436, 1, -0.27367317150000003), (0.437, 1, -1.50231477225), (0.437, 1, -1.4634173581), (0.437, 1, -1.45897581945), (0.437, 1, -1.31274118435), (0.437, 1, -1.29948465445), (0.437, 1, -1.0098257718), (0.437, 1, -1.0043968622000001), (0.437, 1, -0.5360496739), (0.437, 1, -0.51691033035), (0.437, 1, -0.5139765998), (0.437, 1, -0.5116592151), (0.437, 1, -0.3860353269), (0.437, 1, -0.

[(0.425, 1, -1.1309492323499999), (0.426, 1, -1.13220100175), (0.426, 1, -1.1197061967500002), (0.426, 1, -1.07694317085), (0.426, 1, -1.05727932605), (0.427, 1, -1.16315992095), (0.427, 1, -1.1343631415), (0.427, 1, -1.1022014722), (0.427, 1, -1.08051728825), (0.427, 1, -1.0679737005), (0.427, 1, -1.049474067), (0.428, 1, -1.16330020595), (0.428, 1, -1.1580294907), (0.428, 1, -1.1457926041), (0.428, 1, -1.1360654456), (0.428, 1, -1.0935016544999998), (0.428, 1, -1.0852700122), (0.428, 1, -1.0438552663), (0.429, 1, -1.30836361805), (0.429, 1, -1.2102000181500001), (0.429, 1, -1.19258056475), (0.429, 1, -1.17184615725), (0.429, 1, -1.1659937141499999), (0.429, 1, -1.15116235815), (0.429, 1, -1.1393214642), (0.429, 1, -1.0904926603999998), (0.429, 1, -1.03844339485), (0.43, -1, 1.1533797297000001), (0.43, -1, 1.1637727787499998), (0.43, -1, 1.2650044657), (0.43, 1, -1.314292271), (0.43, 1, -1.3032286154500001), (0.43, 1, -1.3018764002), (0.43, 1, -1.21130550535), (0.43, 1, -1.20339431555

[(0.429, -1, 0.9250121416), (0.43, -1, 0.92417657025), (0.43, -1, 0.9270889845500001), (0.431, -1, 0.9031084865), (0.431, -1, 0.9217879383500001), (0.431, -1, 0.9296147073000001), (0.431, -1, 0.960491412), (0.431, -1, 0.9620571842000001), (0.432, -1, 0.87718745485), (0.432, -1, 0.9022722957), (0.432, -1, 0.9049255077), (0.432, -1, 0.91423857275), (0.432, -1, 0.9381256126499999), (0.432, -1, 0.95531568015), (0.432, -1, 0.96074199435), (0.432, -1, 0.96447295845), (0.432, -1, 1.04300273125), (0.433, -1, 0.8754185488499999), (0.433, -1, 0.88448960535), (0.433, -1, 0.90195095515), (0.433, -1, 0.9073537521499999), (0.433, -1, 0.9479669489), (0.433, -1, 0.9675755724), (0.433, -1, 1.0256175512499999), (0.433, -1, 1.03131224385), (0.433, -1, 1.0384171867499998), (0.433, -1, 1.0510121932), (0.434, -1, 0.8670297353), (0.434, -1, 0.8731776802), (0.434, -1, 0.89602015225), (0.434, -1, 0.9700853030000001), (0.434, -1, 0.97746813445), (0.434, -1, 1.02372295165), (0.434, -1, 1.0280996182), (0.434, -1,

[(0.426, 1, -1.0224915101), (0.427, 1, -1.0300985887), (0.427, 1, -1.0160203553), (0.428, 1, -1.0520400747999998), (0.428, 1, -1.03714622145), (0.428, 1, -1.01083144795), (0.428, 1, -1.0029356515), (0.428, 1, -0.9964304435), (0.429, -1, 0.4682152686), (0.429, -1, 0.6562905594999999), (0.429, 1, -1.0530673960499999), (0.429, 1, -1.0458627274999999), (0.429, 1, -1.0071883234), (0.429, 1, -1.0050646589999999), (0.429, 1, -1), (0.429, 1, -0.9993418363), (0.429, 1, -0.9939500733), (0.43, -1, 0.46293977295), (0.43, -1, 0.47248703795), (0.43, -1, 0.4787988906), (0.43, -1, 0.4913830173), (0.43, -1, 0.5137741624500001), (0.43, -1, 0.65280961515), (0.43, -1, 0.65990277965), (0.43, 1, -1.056281331), (0.43, 1, -1.0058279956), (0.43, 1, -0.9919452978500001), (0.43, 1, -0.9887883179500001), (0.431, -1, 0.4563829025), (0.431, -1, 0.45924989645000003), (0.431, -1, 0.47605918095), (0.431, -1, 0.4816537634), (0.431, -1, 0.4890204788), (0.431, -1, 0.49744226810000003), (0.431, -1, 0.5109001909), (0.431, 

[(0.423, -1, 0.87778575735), (0.423, -1, 0.8842592355000001), (0.423, -1, 0.88927860145), (0.424, -1, 0.84079827235), (0.424, -1, 0.87050198845), (0.424, -1, 0.87595409475), (0.424, -1, 0.8800276851), (0.424, -1, 0.8873175233499999), (0.424, -1, 0.8915509976), (0.425, -1, 0.83506819135), (0.425, -1, 0.8500094162), (0.425, -1, 0.8657967740000001), (0.425, -1, 0.8748388506), (0.425, -1, 0.8926598483), (0.426, -1, 0.82501814695), (0.426, -1, 0.8292705579999999), (0.426, -1, 0.8621577170000001), (0.426, -1, 0.8953150232), (0.427, -1, 0.8235975231499999), (0.427, -1, 0.82760631695), (0.427, -1, 0.90186584515), (0.428, -1, 0.7039539385), (0.428, -1, 0.8209955856), (0.428, -1, 0.90908148785), (0.429, -1, 0.6924965944999999), (0.429, -1, 0.6973328491499999), (0.429, -1, 0.701292145), (0.429, -1, 0.7060507752), (0.429, -1, 0.7196789871), (0.429, -1, 0.7968572794), (0.429, -1, 0.81307516335), (0.429, -1, 0.8179595641499999), (0.429, -1, 0.9131513395), (0.429, -1, 0.91973990855), (0.43, -1, 0.686

[(0.409, 1, -0.6633739839), (0.41, 1, -0.7162279728500001), (0.41, 1, -0.66550209865), (0.41, 1, -0.6577215471), (0.41, 1, -0.4924276599), (0.41, 1, -0.49108969020000004), (0.41, 1, -0.48937059169999997), (0.41, 1, -0.4852434997), (0.411, 1, -0.7384110591999999), (0.411, 1, -0.72944676265), (0.411, 1, -0.7189768622), (0.411, 1, -0.7134546937499999), (0.411, 1, -0.70961990215), (0.411, 1, -0.69754433645), (0.411, 1, -0.6819074708499999), (0.411, 1, -0.6682301859499999), (0.411, 1, -0.65197365605), (0.411, 1, -0.5342175807), (0.411, 1, -0.49268893565), (0.411, 1, -0.491883364), (0.411, 1, -0.490589649), (0.411, 1, -0.48716759625), (0.411, 1, -0.48405498290000004), (0.411, 1, -0.48126497960000003), (0.412, 1, -0.7416962753), (0.412, 1, -0.7357929653499999), (0.412, 1, -0.72157693445), (0.412, 1, -0.7119241124), (0.412, 1, -0.706175774), (0.412, 1, -0.6983641082500001), (0.412, 1, -0.6947535169000001), (0.412, 1, -0.6827982985), (0.412, 1, -0.6816066174), (0.412, 1, -0.67442935835), (0.412

[(0.398, 1, -0.62000697185), (0.398, 1, -0.6158785932), (0.398, 1, -0.61371479905), (0.399, 1, -0.62239751365), (0.399, 1, -0.61759701695), (0.399, 1, -0.6148046908), (0.399, 1, -0.6116176596), (0.399, 1, -0.56724756555), (0.399, 1, -0.5507174825500001), (0.4, 1, -0.62875821645), (0.4, 1, -0.6239571037), (0.4, 1, -0.60972170595), (0.4, 1, -0.5752590369499999), (0.4, 1, -0.55897670945), (0.4, 1, -0.5535511104499999), (0.4, 1, -0.5488932789000001), (0.401, 1, -0.6919233408000001), (0.401, 1, -0.63206685155), (0.401, 1, -0.62576918305), (0.401, 1, -0.607492394), (0.401, 1, -0.5840792448000001), (0.401, 1, -0.57901407185), (0.401, 1, -0.5557863849), (0.401, 1, -0.5468220279), (0.402, 1, -0.69781739445), (0.402, 1, -0.6922224075000001), (0.402, 1, -0.68946787685), (0.402, 1, -0.6527768257), (0.402, 1, -0.64419313365), (0.402, 1, -0.6344056547), (0.402, 1, -0.59568041375), (0.402, 1, -0.5816072628), (0.402, 1, -0.5437922266499999), (0.402, 1, -0.5299852216000001), (0.403, 1, -0.7355823054), 

[(0.426, 1, -0.48590203060000003), (0.426, 1, -0.46901451315), (0.426, 1, -0.4658704299), (0.426, 1, -0.46263046335), (0.426, 1, -0.42818721270000004), (0.426, 1, -0.41021743860000004), (0.426, 1, -0.2939086516), (0.427, 1, -0.5295111839), (0.427, 1, -0.487279284), (0.427, 1, -0.48394581535000003), (0.427, 1, -0.48063441090000003), (0.427, 1, -0.47088970295), (0.427, 1, -0.46803932345), (0.427, 1, -0.46444804195), (0.427, 1, -0.46072699935), (0.427, 1, -0.45761742325), (0.427, 1, -0.428606261), (0.427, 1, -0.42786904815000004), (0.427, 1, -0.41104241404999997), (0.427, 1, -0.40952062115), (0.427, 1, -0.4069022335), (0.427, 1, -0.4018643414), (0.427, 1, -0.39271821725), (0.427, 1, -0.3921651095), (0.427, 1, -0.37992930820000004), (0.427, 1, -0.2952444418), (0.427, 1, -0.29273512815), (0.427, 1, -0.29093541815), (0.428, 1, -0.5341318581000001), (0.428, 1, -0.5276707851), (0.428, 1, -0.49132577354999996), (0.428, 1, -0.48740414755), (0.428, 1, -0.48266658665), (0.428, 1, -0.47884553350000

In [110]:
print(a.GetOptParams())

(0.441, 1, 1.49701187655)


In [None]:
#question 14
def mpEin(FeatList, LabelList):
    model = DecisionStumpModel(FeatList, LabelList)
    Ein, _, _ = model.GetOptParams()
    return Ein

pool = mp.Pool(mp.cpu_count())
EinList = pool.map(mpEin, train.GetFeatFull())
print(EinList)