In [1]:
import qualified Data.ByteString.Lazy as BS
import qualified Data.Csv.HMatrix as CSV
import Numeric.LinearAlgebra
import Data.Csv (HasHeader (..))
import Numeric.LinearAlgebra.Data
import Prelude hiding ((<>))
import Data.Function ((&))
import Data.Foldable (foldl')
import Data.Vector.Storable (ifoldl', generate)
import qualified Data.Vector.Storable as V
import Control.Monad (foldM)

In [2]:
hypothesis :: Matrix Double -> Vector Double -> Double -> Double
hypothesis values weights bias =
  sumElements (values * asRow weights) + bias
  
activationSigmoid :: Double -> Double
activationSigmoid value =
  1 / (1 + exp(-value))
  
sigmoidNeuron :: Vector Double -> Double -> Matrix Double -> Double
sigmoidNeuron weights bias values =
  activationSigmoid $ hypothesis values weights bias  
  
sigmoidQuantizer :: Vector Double -> Vector Double
sigmoidQuantizer = cmap (fromIntegral . round)

In [3]:
binaryCrossentropyLoss :: Vector Double -> Vector Double -> Double
binaryCrossentropyLoss truths preds =
  let xs = truths * cmap log preds + (1 - truths) * cmap log (1 - preds)
  in (-1) * sumElements xs / fromIntegral (size xs)

In [4]:
sigmoidDeltaRule :: 
     Matrix Double  -- ^ train_x
  -> Vector Double  -- ^ train_y
  -> Double         -- ^ learning_rate
  -> Vector Double  -- ^ weights
  -> Double         -- ^ bias
  -> (Vector Double, Double)
sigmoidDeltaRule trainX trainY lr ws b =
  ifoldl' (\(ws', b') i y -> go (trainX ? [i]) y ws' b') (ws, b) trainY
  where
    go :: Matrix Double -> Double -> Vector Double -> Double -> (Vector Double, Double)
    go x y ws' b'  = 
      let o = sigmoidNeuron ws' b' x
          e = y - o
          change = x * scalar e
          ws'' = ws' + (scalar lr * flatten change)
          b'' = b' + lr * e
      in (ws'', b'')
          

In [5]:
-- Some known data to validate learning

weights :: Vector Double
weights = fromList [0.09762701, 0.43037873]

bias :: Double
bias = 0.20552675214328775

inputs :: Matrix Double
inputs = fromLists [[-7.66054695e-01,  1.83324682e-01],
       [-9.20383253e-01, -7.23168038e-02],
       [-9.86585088e-01, -2.86920000e-01],
       [ 1.70910242e+00, -1.10453952e+00],
       [ 1.98764670e+00,  1.77624479e+00],
       [ 3.86274219e+00,  2.63325914e+00],
       [-1.12836011e+00, -4.22761581e-01],
       [-1.10074198e+00, -2.56042975e+00],
       [-1.53716448e+00,  1.10502647e+00],
       [-3.23726922e-01,  5.56269743e-01],
       [-1.28532883e+00, -1.30819171e+00],
       [ 3.35973253e+00, -1.79506345e+00],
       [-1.27034986e+00,  1.26780440e+00],
       [-7.10233633e-01, -1.13058206e+00],
       [-1.12933108e+00,  6.87661760e-01],
       [ 2.48206729e-01, -5.09792713e-01],
       [-2.47263494e+00, -4.86612462e-01],
       [-1.11573423e+00,  1.43370121e+00],
       [-1.21414740e+00,  1.97698901e+00],
       [-1.25860859e+00, -1.82896522e-01],
       [-5.35834091e-01,  1.10698637e+00],
       [ 9.23281451e-01, -1.30813451e+00],
       [ 2.02751248e+00, -4.03257104e-01],
       [ 2.18718140e+00,  2.03196825e+00],
       [ 5.12814562e-01,  4.32994532e-01],
       [-2.81180710e+00, -3.45538051e-01],
       [ 3.13380666e-01,  1.12073484e+00],
       [ 2.72990606e-01, -3.21105367e-01],
       [-1.18986266e+00,  4.24345081e-01],
       [-1.12132974e-01, -9.07197428e-01],
       [ 1.23567148e+00,  1.55525060e+00],
       [-7.03291920e-01, -6.05415797e-01],
       [ 1.33828180e+00, -9.86132567e-01],
       [-3.19826339e+00, -1.25732069e+00],
       [ 7.64389529e-01, -6.79598011e-01],
       [ 9.52529622e-01, -7.93470192e-01],
       [-4.37933163e-01, -1.24378126e+00],
       [ 2.40620516e+00, -1.00171129e-01],
       [ 6.81520677e-01,  3.93906076e-01],
       [ 4.51394467e-01, -2.47553402e-03],
       [-1.45709006e+00,  4.86681188e-01],
       [ 1.69989125e+00, -1.66130052e+00],
       [-9.80358459e-01, -1.40246886e+00],
       [-2.68225264e-01,  4.58931008e-02],
       [-1.85016853e+00, -3.58754622e+00],
       [-1.07894567e-01,  1.34057624e+00],
       [ 1.27458364e+00, -2.52159550e+00],
       [-4.47369310e-01, -2.68051210e-01],
       [-8.11815767e-03,  2.11564734e+00],
       [-2.22244349e+00, -1.62073375e+00],
       [-8.94365876e-01,  9.29950318e-01],
       [-2.41527100e-01,  4.55946498e-01],
       [ 8.50411665e-01, -2.08311803e-01],
       [-1.06938289e+00,  4.17180364e-01],
       [-9.95105317e-01,  1.23195055e+00],
       [-8.31839865e-01,  7.21669496e-01],
       [-1.24743076e+00,  7.09216593e-01],
       [ 1.22964368e+00, -9.61541555e-01],
       [-9.00142145e-01,  1.78037474e+00],
       [ 5.71068197e-01, -1.23267396e+00],
       [ 2.16543447e-01, -7.15602562e-02],
       [-1.59027524e-01, -2.38076394e+00],
       [-5.97321845e-01, -1.20114435e+00],
       [ 1.04421447e+00,  2.02899023e+00],
       [ 9.55083877e-01, -1.62184212e+00],
       [-9.03343950e-01,  1.79445113e+00],
       [-1.27395967e+00, -4.37843295e-02],
       [ 2.50202269e-01, -1.04862023e-01],
       [-8.95526209e-01,  3.51173413e-01],
       [-1.17921312e+00,  5.44818813e-01],
       [-1.20374176e+00, -2.89788250e-01],
       [ 1.22006997e+00, -8.27182474e-01],
       [ 2.34137626e+00,  1.47049892e+00],
       [-6.15293010e-01, -8.64008129e-01],
       [ 1.79574591e+00,  1.87834887e+00],
       [-7.80910684e-01,  1.85082392e+00],
       [-1.36508549e-01,  7.33899962e-01],
       [ 1.73590335e+00, -8.69536952e-01],
       [-6.35812195e-01,  1.35329628e+00],
       [-6.83049356e-01, -8.00221059e-01],
       [ 8.51440362e-01,  2.42548085e+00],
       [-1.11745336e+00,  1.26661394e+00],
       [ 6.73114962e-01,  5.59313196e-01],
       [-1.14855777e+00,  8.65825460e-01],
       [-1.45125944e+00, -7.62884416e-01],
       [ 8.56699698e-01, -4.43584187e-01],
       [ 1.91002859e+00, -2.20119016e+00],
       [-1.04999632e+00,  1.83240861e+00],
       [-7.81871442e-02, -3.54528872e-01],
       [ 5.23902618e-01, -1.97242756e+00],
       [-1.36672011e+00, -1.05286598e+00],
       [ 2.17261785e-01, -4.71824167e-01],
       [ 1.35081491e+00, -1.94643859e-01],
       [-1.41363563e+00, -3.59967304e-01],
       [ 1.86752028e+00, -2.05721672e-01],
       [-9.89449288e-01,  3.73436160e-01],
       [-8.69385004e-01, -8.60699619e-01],
       [ 2.53026908e+00,  3.80251566e-01],
       [-4.61220130e-01, -3.69743600e+00],
       [-2.05321581e+00,  4.22341441e-01]]
       
outputs :: Vector Double
outputs = fromList [0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1,
       1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1,
       0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0,
       0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0,
       1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0]
       

In [6]:
epochs = 40
learningRate = 0.1

data Epoch = Epoch
  { epochAccuracy :: Double
  , epochLoss     :: Double
  , epochWeights  :: Vector Double
  , epochBias     :: Double
  } deriving Show

epoch ws b = 
  let
    (ws', b') = sigmoidDeltaRule inputs outputs learningRate ws b
    scores = generate (rows inputs) (\i -> sigmoidNeuron ws' b' (inputs ? [i]))
    prediction = sigmoidQuantizer scores
  in Epoch
      { epochAccuracy = sumElements (V.zipWith (*) prediction outputs) / fromIntegral (size outputs) 
      , epochLoss     = binaryCrossentropyLoss outputs scores
      , epochWeights  = ws'
      , epochBias     = b'
      }

In [7]:
runEpoch ws b = do
  let e = epoch ws b
  print e
  pure (epochWeights e, epochBias e)

In [8]:
runEpoch weights bias

Epoch {epochAccuracy = 0.42, epochLoss = 0.3049197828843639, epochWeights = [1.8017761337143614,0.1826053291481143], epochBias = 0.2150211159369129}
([1.8017761337143614,0.1826053291481143],0.2150211159369129)

In [9]:
foldM (\(ws, b) _ -> runEpoch ws b) (weights, bias) [1..40] 

Epoch {epochAccuracy = 0.42, epochLoss = 0.3049197828843639, epochWeights = [1.8017761337143614,0.1826053291481143], epochBias = 0.2150211159369129}
Epoch {epochAccuracy = 0.42, epochLoss = 0.28241583548779375, epochWeights = [2.2577378595154616,0.26189666139534695], epochBias = 0.37499645762186795}
Epoch {epochAccuracy = 0.43, epochLoss = 0.2735595893222411, epochWeights = [2.525092161558543,0.325508471936362], epochBias = 0.49971852558569835}
Epoch {epochAccuracy = 0.43, epochLoss = 0.2691475501806344, epochWeights = [2.7092401535545396,0.3717292746089176], epochBias = 0.5919080781848157}
Epoch {epochAccuracy = 0.43, epochLoss = 0.2666993139281294, epochWeights = [2.845117332137158,0.4061081452188554], epochBias = 0.6612890913104311}
Epoch {epochAccuracy = 0.44, epochLoss = 0.2652472447154231, epochWeights = [2.9492208981416295,0.43242186790087767], epochBias = 0.7147331926489454}
Epoch {epochAccuracy = 0.45, epochLoss = 0.2643457915371622, epochWeights = [3.0308992217540016,0.453018

### Expected output
```
epoch: 1, loss: 0.30491978287317717 accuracy: 0.9, weights: [1.80177613 0.18260533], bias: 0.21502111594549478
epoch: 2, loss: 0.2824158354765139 accuracy: 0.9, weights: [2.25773786 0.26189666], bias: 0.37499645765349787
epoch: 3, loss: 0.2735595893083779 accuracy: 0.91, weights: [2.52509216 0.32550847], bias: 0.4997185256367553
epoch: 4, loss: 0.26914755016451136 accuracy: 0.91, weights: [2.70924015 0.37172927], bias: 0.5919080782538896
epoch: 5, loss: 0.26669931391025636 accuracy: 0.91, weights: [2.84511733 0.40610815], bias: 0.6612890913979244
epoch: 6, loss: 0.26524724469621735 accuracy: 0.92, weights: [2.9492209  0.43242187], bias: 0.7147331927555293
epoch: 7, loss: 0.2643457915169445 accuracy: 0.93, weights: [3.03089922 0.45301878], bias: 0.7567104522631428
epoch: 8, loss: 0.26376732094399513 accuracy: 0.93, weights: [3.09604158 0.46940962], bias: 0.7901835288541802
epoch: 9, loss: 0.26338669669361087 accuracy: 0.93, weights: [3.14861879 0.48261436], bias: 0.8171870785313552
epoch: 10, loss: 0.26313129205966246 accuracy: 0.94, weights: [3.1914376  0.49335164], bias: 0.8391679208642522
epoch: 11, loss: 0.2629571741258358 accuracy: 0.94, weights: [3.22655215 0.50214554], bias: 0.8571862369992933
epoch: 12, loss: 0.26283689548252687 accuracy: 0.94, weights: [3.25550636 0.50938866], bias: 0.8720384834630405
epoch: 13, loss: 0.26275286320970326 accuracy: 0.94, weights: [3.27948519 0.5153815 ], bias: 0.8843353236766414
epoch: 14, loss: 0.26269356544432027 accuracy: 0.94, weights: [3.29941343 0.520358  ], bias: 0.8945528366562984
epoch: 15, loss: 0.2626513410055058 accuracy: 0.94, weights: [3.31602263 0.52450281], bias: 0.9030672752847765
epoch: 16, loss: 0.262621019262518 accuracy: 0.94, weights: [3.32989795 0.52796334], bias: 0.9101793474379831
epoch: 17, loss: 0.26259906893721024 accuracy: 0.94, weights: [3.34151167 0.53085837], bias: 0.9161316126382242
epoch: 18, loss: 0.26258305376623253 accuracy: 0.94, weights: [3.35124785 0.53328435], bias: 0.921121225436274
epoch: 19, loss: 0.2625712780208495 accuracy: 0.94, weights: [3.3594208  0.53532006], bias: 0.9253094527680777
epoch: 20, loss: 0.2625625520975088 accuracy: 0.94, weights: [3.36628902 0.53703027], bias: 0.9288289032249266
epoch: 21, loss: 0.2625560354764207 accuracy: 0.94, weights: [3.37206608 0.5384684 ], bias: 0.9317890999378903
epoch: 22, loss: 0.2625511303351723 accuracy: 0.94, weights: [3.37692904 0.5396787 ], bias: 0.9342808321451919
epoch: 23, loss: 0.2625474087848395 accuracy: 0.94, weights: [3.38102517 0.54069796], bias: 0.936379591263977
epoch: 24, loss: 0.26254456268738147 accuracy: 0.94, weights: [3.38447724 0.54155682], bias: 0.9381483104898026
epoch: 25, loss: 0.26254236879363296 accuracy: 0.94, weights: [3.38738782 0.54228086], bias: 0.9396395674944606
epoch: 26, loss: 0.26254066436757084 accuracy: 0.94, weights: [3.38984278 0.54289149], bias: 0.9408973683210835
epoch: 27, loss: 0.262539330043142 accuracy: 0.94, weights: [3.39191411 0.54340665], bias: 0.9419586011532618
epoch: 28, loss: 0.26253827770314436 accuracy: 0.94, weights: [3.39366223 0.54384138], bias: 0.9428542274296596
epoch: 29, loss: 0.2625374418660702 accuracy: 0.94, weights: [3.39513789 0.54420834], bias: 0.943610262266029
epoch: 30, loss: 0.26253677353641447 accuracy: 0.94, weights: [3.39638381 0.54451814], bias: 0.9442485846453914
epoch: 31, loss: 0.26253623579337104 accuracy: 0.94, weights: [3.39743592 0.54477974], bias: 0.9447876091979449
epoch: 32, loss: 0.2625358006117872 accuracy: 0.94, weights: [3.39832448 0.54500067], bias: 0.9452428448241198
epoch: 33, loss: 0.2625354465603247 accuracy: 0.94, weights: [3.39907502 0.54518727], bias: 0.945627360364407
epoch: 34, loss: 0.26253515712664954 accuracy: 0.94, weights: [3.39970902 0.54534489], bias: 0.9459521735964433
epoch: 35, loss: 0.26253491949263774 accuracy: 0.94, weights: [3.40024463 0.54547805], bias: 0.9462265767627037
epoch: 36, loss: 0.26253472363387975 accuracy: 0.94, weights: [3.40069714 0.54559055], bias: 0.9464584093971398
epoch: 37, loss: 0.2625345616538668 accuracy: 0.94, weights: [3.40107948 0.5456856 ], bias: 0.9466542872768524
epoch: 38, loss: 0.26253442728874693 accuracy: 0.94, weights: [3.40140253 0.54576591], bias: 0.9468197947642695
epoch: 39, loss: 0.26253431553662876 accuracy: 0.94, weights: [3.40167551 0.54583377], bias: 0.9469596465432697
epoch: 40, loss: 0.2625342223782718 accuracy: 0.94, weights: [3.40190619 0.54589111], bias: 0.9470778237260611
```