In [1]:
import System.Random
import Data.List
import Data.Maybe
import Data.Ord
import Graphics.EasyPlot

In [2]:
-- Spread of a cluster. 'reposition' multiplies values normalised to
-- [-0.5, 0.5] by it, so it is the full width of the cluster along an axis.
type Sigma = Double
-- Centre of a cluster along one axis (the offset added by 'reposition').
type Mean = Double
-- Not referenced by any definition visible in this file.
type Distance = Double
-- Cluster identifier assigned to a point or a centroid.
type Label = Int
-- Not referenced by any definition visible in this file.
type Point = (Double, Double)
-- (x, y, label currently assigned to the point).
type LabeledPoint = (Double, Double, Label)
-- (mean x, mean y, label of the cluster the means were computed for).
type Centroid = (Double, Double, Label)

In [3]:

-- | Linearly rescale @values@ so that they span a band of total width
-- @sigma@ centred on @mu@.
--
-- The smallest input maps to @mu - sigma / 2@ and the largest to
-- @mu + sigma / 2@; relative spacing between values is preserved.
--
-- Degenerate inputs: an empty list yields an empty list. When every value is
-- identical (including the single-element case) there is no range to
-- normalise by, so each value is placed at @mu@ instead of becoming NaN from
-- the 0/0 division.
reposition :: [Double] -> Sigma -> Mean -> [Double]
reposition values sigma mu = map place values
  where
    lo = minimum values
    hi = maximum values
    range = hi - lo
    place x = sigma * normalise x + mu
    -- Map x into [-0.5, 0.5]; a zero range collapses to the centre.
    normalise x
      | range == 0 = 0
      | otherwise = (x - lo) / range - 0.5

In [4]:
-- Ten evenly spaced inputs mapped onto a band of width 0.5 centred on 0.
reposition [0..9] 0.5 0

[-0.25,-0.19444444444444445,-0.1388888888888889,-8.333333333333334e-2,-2.777777777777779e-2,2.777777777777779e-2,8.333333333333331e-2,0.1388888888888889,0.19444444444444442,0.25]

In [5]:
-- Same inputs mapped onto a band of width 1.0 centred on 0.
reposition [0..9] 1.0 0

[-0.5,-0.3888888888888889,-0.2777777777777778,-0.16666666666666669,-5.555555555555558e-2,5.555555555555558e-2,0.16666666666666663,0.2777777777777778,0.38888888888888884,0.5]

In [6]:
-- Same inputs mapped onto a band of width 1.0 centred on 1.0.
reposition [0..9] 1.0 1.0

[0.5,0.6111111111111112,0.7222222222222222,0.8333333333333333,0.9444444444444444,1.0555555555555556,1.1666666666666665,1.2777777777777777,1.3888888888888888,1.5]

In [7]:

-- | Turn a flat list of numbers into the x and y coordinates of several
-- point clusters.
--
-- One cluster is produced per entry of @sigmas@; cluster @i@ is centred on
-- @(xmus !! i, ymus !! i)@ and spans a width of @sigmas !! i@ on both axes.
-- @values@ is consumed in consecutive chunks of
-- @length values `div` (2 * number of clusters)@ elements: first the chunk
-- for cluster 0's x coordinates, then cluster 0's y coordinates, then
-- cluster 1's x coordinates, and so on. Left-over values are ignored.
--
-- Returns @(xs, ys)@, each holding the clusters' coordinates concatenated in
-- cluster order, so the two lists line up index by index.
--
-- If @sigmas@, @xmus@ and @ymus@ differ in length, only as many clusters as
-- the shortest list are produced. With no clusters the result is @([], [])@;
-- with too few values every cluster is empty.
randomClusters :: [Double] -> [Sigma] -> [Mean] -> [Mean] -> ([Double], [Double])
randomClusters values sigmas xmus ymus = (concat xss, concat yss)
  where
    clusterCount = length sigmas
    -- Only forced when at least one cluster exists, so the division is safe.
    pointsPerCluster = length values `div` (2 * clusterCount)
    (xss, yss) = unzip (build values (zip3 sigmas xmus ymus))
    -- Peel two chunks off the remaining values per cluster: x first, then y.
    build _ [] = []
    build remaining ((sigma, xmu, ymu) : rest) =
      let (xChunk, afterX) = splitAt pointsPerCluster remaining
          (yChunk, afterY) = splitAt pointsPerCluster afterX
      in (reposition xChunk sigma xmu, reposition yChunk sigma ymu) : build afterY rest

In [8]:
-- Obtain a random number generator for this session.
g <- newStdGen

In [9]:
-- 80 random Doubles: 4 clusters x 2 axes x 10 points for the call below.
values = take 80 (randoms g) :: [Double]

In [10]:
-- Echo the raw random values.
values

[0.46024982104050893,0.3581128628311353,0.21064922966635735,6.59847348368352e-2,0.46833308773194704,6.5654913611069565e-3,0.7710906571199988,0.8816635231623543,0.3411197833133226,0.33198411346320456,0.9391616885142349,0.5155086335318453,0.3297459395058129,0.4688596950141187,0.2786100003171603,0.2530356290126645,4.273512499336429e-2,0.36648988140360894,0.8808815495515794,0.26208544117496113,0.45574987906313247,0.42648250240338925,0.8219263014978118,0.19326670140319613,0.8110173461338968,0.312696249654978,0.31625160293419374,0.6739699993016531,0.12692985871290186,0.1515905404782959,0.453433279662825,0.8286466343753662,9.378228605309391e-2,0.9546535496836943,0.5240667892379306,0.7159292655065608,0.1334234726627499,0.8999898564085034,6.821807640004762e-2,0.44663917673639564,0.9239922971905942,0.9502924698727468,0.5496747506525373,0.523656438537601,6.737744714622063e-2,0.724858429694705,0.29710525083739847,0.7595897593080743,0.7220623461459864,0.46724750670274595,0.9532255605412281,0.865443

In [11]:
-- Four clusters of increasing width centred on (1,1), (1,-1), (-1,1), (-1,-1).
(xs, ys) = randomClusters values [0.25, 0.5, 0.75, 1.0] [1, 1, -1, -1] [1, -1, 1, -1]

In [12]:
-- x coordinates, 10 per cluster, in cluster order.
xs

[1.0046095732113483,0.9754308542285328,0.933303107448767,0.8919750248876184,1.0069188192608451,0.875,1.0934112916427319,1.125,0.9705762325460812,0.9679663335638741,0.9865623765156415,0.9655066020842887,1.25,0.7977245915277471,1.2421517905039903,0.8836455696073022,0.8862033908135296,1.1435560723136329,0.75,0.7677415884796306,-0.6473409150415195,-0.625,-0.9653082762000104,-0.9874097622818101,-1.375,-0.8164969456646121,-1.1798556788213692,-0.7869941095597437,-0.8188721036442137,-1.0353268299351548,-0.5,-1.3139461451649692,-1.2673898810088793,-0.9073719553121642,-1.5,-0.6472742097087445,-0.8937489986297411,-0.6369053870989794,-1.1723073997395417,-0.6448857635241789]

In [13]:
-- y coordinates, aligned index by index with xs.
ys

[1.125,1.0068494809774438,0.9550430359250969,0.9938397877086207,0.9407819850845776,0.9336496743228214,0.875,0.9652903733515665,1.1087465383852109,0.936173531973456,-1.0327167262182015,-0.821075063614928,-1.2355803324531067,-0.75,-0.992875411365663,-0.8846540675390735,-1.2132204496390695,-0.7808334305895379,-1.25,-1.0365489865074147,1.375,1.29472665008145,1.1760757112463873,1.187520929819767,0.8810854516290796,1.3293958837850377,1.3336347728730211,0.77725328915571,0.625,0.715172794254449,-1.5,-0.8609466517201134,-0.6513767448854376,-1.1858938426769772,-0.5,-0.6326198995929813,-0.5876444675426847,-1.3206423090394597,-0.5968147728182008,-0.8272535700288102]

In [14]:
-- Scatter plot of the generated points using EasyPlot's X11 terminal.
plot X11 (zip xs ys)

True

In [15]:
-- Random initial label (0..3) for each of the 40 points. The labels are drawn
-- from a generator split off `g` rather than from `g` itself: `values` already
-- consumed `g`, and reusing it would replay the same random stream, tying the
-- initial labels to the point coordinates.
cs = take 40 (randomRs (0,3) (snd (split g))) :: [Label]

In [16]:
-- Echo the initial random labels.
cs

[2,2,1,1,2,0,1,2,1,2,3,1,2,1,0,3,2,0,0,3,2,0,2,3,0,2,0,2,3,2,2,0,0,0,2,0,3,0,1,0]

In [17]:

-- | Compute one centroid per distinct label.
--
-- @xs@, @ys@ and @cs@ are parallel lists: point @i@ has coordinates
-- @(xs !! i, ys !! i)@ and currently carries label @cs !! i@. For every label
-- that occurs, the result holds the mean x and mean y of the points carrying
-- it, together with the label. Centroids are returned in order of each
-- label's first appearance.
--
-- The three lists are zipped, so if their lengths differ the surplus entries
-- of the longer ones are ignored. Labels are collected from the zipped points,
-- which guarantees every centroid averages at least one point.
getCentroids :: [Double] -> [Double] -> [Label] -> [Centroid]
getCentroids xs ys cs = map centroidFor labels
  where
    points = zip3 xs ys cs
    labels = nub [c | (_, _, c) <- points]
    centroidFor label = (average memberXs, average memberYs, label)
      where
        memberXs = [x | (x, _, c) <- points, c == label]
        memberYs = [y | (_, y, c) <- points, c == label]
    average vs = sum vs / genericLength vs

In [18]:
-- Centroids implied by the random initial labelling.
centroids = getCentroids xs ys cs

In [19]:
-- Echo the centroids as (mean x, mean y, label).
centroids

[(6.890442104836346e-2,0.42129944769546085,2),(0.6400270643425275,0.25210564651225703,1),(-0.35284008906320474,-0.312257331614816,0),(-1.034688832553173e-2,-0.2881738863312679,3)]

In [20]:
-- Manual check: squared Euclidean distance between the first two centroids.
(ax, ay, a) = centroids !! 0

In [21]:
(bx, by, b) = centroids !! 1

In [22]:
(ax - bx)^2 + (ay - by)^2

0.35480761604213445

In [23]:

-- | Pick the centroid closest to a point.
--
-- Closeness is squared Euclidean distance on the x and y coordinates; the
-- point's current label and the centroids' labels play no part in the
-- comparison. When several centroids are equally close, the one that comes
-- first in the list wins.
--
-- Calls 'error' when the centroid list is empty, since there is nothing to
-- return in that case.
findNearestCentroid :: [Centroid] -> LabeledPoint -> Centroid
findNearestCentroid [] _ = error "findNearestCentroid: empty centroid list"
findNearestCentroid centroids (px, py, _) = minimumBy (comparing distanceSqrd) centroids
  where
    -- Squared distance is enough for ranking and avoids the square root.
    distanceSqrd (cx, cy, _) = (px - cx) ^ 2 + (py - cy) ^ 2

In [24]:

-- | Perform a single k-means reassignment step.
--
-- Centroids are computed from the current labelling @cs@ of the points
-- @(xs, ys)@, and every point is then relabelled with the label of its
-- nearest centroid. The result is the new label list, in point order.
kmeans :: [Double] -> [Double] -> [Label] -> [Label]
kmeans xs ys cs = map (labelOf . nearest) (zip3 xs ys cs)
  where
    nearest = findNearestCentroid (getCentroids xs ys cs)
    labelOf (_, _, label) = label


In [28]:
-- Fresh generator, distinct from the one that produced `values`.
g <- newStdGen

In [29]:
-- Random initial label in 0..3 for each of the 40 points.
cs = take 40 (randomRs (0,3) g)

In [30]:
-- One k-means step: reassign every point to its nearest current centroid.
kmeans xs ys cs

[2,2,2,2,2,2,2,2,2,2,3,2,3,3,2,3,3,2,3,3,0,0,1,1,1,0,1,1,1,1,3,1,1,3,1,3,1,3,1,3]

In [31]:

-- | Run k-means reassignment steps until the labelling stops changing.
--
-- Starting from the labels @cs@, 'kmeans' is applied repeatedly; the first
-- labelling that a step leaves unchanged is returned.
kmeansClustering :: [Double] -> [Double] -> [Label] -> [Label]
kmeansClustering xs ys cs = settle cs (step cs)
  where
    step = kmeans xs ys
    -- Compare each labelling with its successor; stop at the fixed point.
    settle current next
      | current == next = current
      | otherwise = settle next (step next)

In [32]:
-- Iterate to convergence; in the output each run of 10 points shares a label.
kmeansClustering xs ys cs

[2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1]