In [1]:
import System.Random
import Data.List
import Data.Maybe
import Data.Ord
import Graphics.EasyPlot

In [2]:
-- | Width of the interval a cluster coordinate is stretched over by
-- 'reposition' (it multiplies values that were normalised to [-0.5, 0.5]).
type Sigma = Double
-- | Centre that 'reposition' shifts the scaled values to.
type Mean = Double
-- | Squared Euclidean distance between two centroids (see 'closestCentroids').
type Distance = Double
-- | Integer identifier of the cluster a point is assigned to.
type Label = Int
-- | Unlabelled 2-D point. Not referenced by the definitions visible in this notebook.
type Point = (Double, Double)
-- | A 2-D point together with its cluster label: (x, y, label).
type LabeledPoint = (Double, Double, Label)
-- | Mean position of a cluster plus its label: (mean x, mean y, label).
-- Structurally the same tuple as 'LabeledPoint', so the two are interchangeable.
type Centroid = (Double, Double, Label)

In [3]:

-- | Linearly rescale @values@ so that they span an interval of width @sigma@
-- centred on @mu@.
--
-- The smallest input maps to @mu - sigma / 2@ and the largest to
-- @mu + sigma / 2@; everything in between keeps its relative position.
-- Note that @sigma@ is therefore the full width of the output range, not a
-- standard deviation.
--
-- An empty list yields an empty list. When all inputs are equal (including a
-- single-element list) there is no range to normalise by, so every value is
-- placed at the centre @mu@ instead of producing NaN from a 0/0 division.
reposition :: [Double] -> Sigma -> Mean -> [Double]
reposition values sigma mu = map place values
  where
    lo = minimum values
    hi = maximum values
    spread = hi - lo

    -- Position of x inside the input range, shifted to [-0.5, 0.5].
    centred x
      | spread == 0 = 0
      | otherwise   = (x - lo) / spread - 0.5

    place x = sigma * centred x + mu

In [4]:
-- Width 0.5 centred on 0: the ten evenly spaced inputs land on [-0.25, 0.25].
reposition [0..9] 0.5 0

[-0.25,-0.19444444444444445,-0.1388888888888889,-8.333333333333334e-2,-2.777777777777779e-2,2.777777777777779e-2,8.333333333333331e-2,0.1388888888888889,0.19444444444444442,0.25]

In [5]:
-- Width 1.0 centred on 0: the same inputs now span [-0.5, 0.5].
reposition [0..9] 1.0 0

[-0.5,-0.3888888888888889,-0.2777777777777778,-0.16666666666666669,-5.555555555555558e-2,5.555555555555558e-2,0.16666666666666663,0.2777777777777778,0.38888888888888884,0.5]

In [6]:
-- Width 1.0 centred on 1.0: the output range moves to [0.5, 1.5].
reposition [0..9] 1.0 1.0

[0.5,0.6111111111111112,0.7222222222222222,0.8333333333333333,0.9444444444444444,1.0555555555555556,1.1666666666666665,1.2777777777777777,1.3888888888888888,1.5]

In [7]:

-- | Build clustered 2-D sample data out of a flat list of numbers.
--
-- @values@ is cut into consecutive slices of
-- @length values `div` (2 * length sigmas)@ elements. Slices @2*i@ and
-- @2*i + 1@ become the x and y coordinates of cluster @i@: each slice is run
-- through 'reposition' with that cluster's sigma and the matching entry of
-- @xmus@ (for x) or @ymus@ (for y). Values left over after the last slice are
-- ignored.
--
-- Returns @(xs, ys)@: all x coordinates and all y coordinates, cluster by
-- cluster. Only as many clusters as the shortest of @sigmas@, @xmus@ and
-- @ymus@ are produced, and empty inputs give empty results rather than an
-- index error.
randomClusters :: [Double] -> [Sigma] -> [Mean] -> [Mean] -> ([Double], [Double])
randomClusters values sigmas xmus ymus = (concat xss, concat yss)
  where
    -- Plain Int arithmetic: the count no longer has to be used at both a
    -- fractional type and an index type, which only type-checked when the
    -- monomorphism restriction was switched off (as it is in GHCi).
    clusterCount :: Int
    clusterCount = length sigmas

    pointsPerCluster :: Int
    pointsPerCluster
      | clusterCount == 0 = 0
      | otherwise         = length values `div` (2 * clusterCount)

    -- Endless supply of consecutive slices; once the values run out the
    -- remaining slices are empty. Only finitely many are ever consumed.
    slices :: [[Double]]
    slices = go values
      where
        go rest =
          let (slice, rest') = splitAt pointsPerCluster rest
           in slice : go rest'

    -- Pair up neighbouring slices as (x slice, y slice).
    coordinateSlices (xSlice : ySlice : rest) = (xSlice, ySlice) : coordinateSlices rest
    coordinateSlices _ = []

    (xss, yss) =
      unzip
        [ (reposition xSlice sigma xmu, reposition ySlice sigma ymu)
        | ((xSlice, ySlice), (sigma, xmu, ymu)) <-
            zip (coordinateSlices slices) (zip3 sigmas xmus ymus)
        ]

In [8]:
-- Obtain a random generator to draw the sample data from.
g <- newStdGen

In [9]:
-- 80 random Doubles: enough for 4 clusters x 2 coordinates x 10 points.
values = take 80 (randoms g) :: [Double]

In [10]:
values

[0.9703859152498816,0.9432946947709178,0.8074359945323409,0.9481624440665145,0.46195381697186033,0.35222011287961874,5.9686054865328964e-2,0.664614801885259,0.7815948902831122,0.5626263596524446,0.21689983686254632,0.2874412410857048,0.9293646817753709,0.5292871663360826,0.9271286375347073,0.6584038089377753,0.8772962411370492,2.7134895370750645e-4,0.3352330431844749,0.9861546637587967,0.2291775449492538,0.6816691006635447,0.44120739016163946,0.8832140635676957,0.5015493737197922,0.41702326469034046,0.35167747688624806,0.27185467422545795,0.8848039311413529,0.41075841924449674,0.9017675576212282,0.19663888267956098,0.19148598021793273,0.42425367682497983,8.515296535715777e-2,0.6663752322212272,0.3456989990588544,0.7517330591880879,0.5046829257024983,0.2150959121777376,0.790359676792437,1.0631168602126762e-2,0.5310004836486982,0.29097469175388824,0.47411997021277086,0.9768631582774693,0.717483229846673,0.40680640696646175,0.33740310022159126,0.4960124332670398,0.4051573225327687,0.10884

In [11]:
-- Four clusters of 10 points each (80 values / (2 coordinates * 4 clusters)),
-- with widths 0.25, 0.5, 0.75 and 1.0, centred on (1,1), (1,-1), (-1,1), (-1,-1).
(xs, ys) = randomClusters values [0.25, 0.5, 0.75, 1.0] [1, 1, -1, -1] [1, -1, 1, -1]

In [12]:
xs

[1.125,1.117563076580597,1.0802679406778615,1.1188993426511609,0.9854281936357906,0.9553047388990387,0.875,1.041061502075034,1.0731741918553028,1.0130642313304912,0.75,1.0950833929537045,0.9117002073725726,1.2487875201432241,0.9577187819975364,0.8932566807081249,0.843422057529196,0.7825468362584325,1.25,0.8884789249788061,-0.7697660733741319,-1.375,-0.971083559170854,-1.157394244228585,-1.0152348453348314,-0.625,-0.8263335807567929,-1.0674843806164103,-1.1213559980073255,-0.9982416724051931,-0.86878944347915,-0.5224431031893046,-0.5829215234563655,-1.025753697715556,-1.5,-1.3001537283137012,-0.5,-1.1405683835700018,-1.3877350658622882,-0.778757588122133]

In [13]:
ys

[0.929932588029362,0.9478204564930616,1.1105992131293314,1.0091476748409527,1.110032197690718,1.0418890349651018,1.0973957133194638,0.875,0.9599394875642535,1.125,-0.75,-1.1817389632890911,-1.184894002710651,-1.0423740803310515,-1.25,-0.894126940437945,-1.0904718479378805,-0.8418637138341898,-0.9931285429383583,-1.1704380205475438,0.9321579638062009,0.6962967207226156,0.625,1.2931602038806727,0.801221603705932,0.8671907283260685,1.3074862314590694,1.375,0.6567658749275994,1.2756701511549988,-0.952997415227447,-0.6447969662684835,-0.5310437669156237,-1.4510794141239172,-1.5,-0.5,-0.9769038764142558,-0.5025044820350462,-0.6550760566720479,-1.134519878653211]

In [14]:
-- Scatter-plot the generated points with EasyPlot's X11 terminal.
-- NOTE(review): the echoed True is presumably EasyPlot's success flag — confirm.
plot X11 (zip xs ys)

True

In [15]:
-- Random initial cluster label in 0..3 for each of the 40 points.
-- NOTE(review): this reuses the same generator value `g` that `values` was
-- drawn from, so both streams start from the same generator state; consider
-- splitting the generator (System.Random.split) to keep them independent.
cs = take 40 (randomRs (0,3) g)

In [16]:
cs

[0,2,1,0,1,3,3,0,0,0,3,2,1,3,3,0,0,0,2,2,3,1,0,2,1,0,1,0,3,3,3,2,3,2,3,3,1,3,2,1]

In [17]:

-- | Compute the centroid (mean x, mean y) of every cluster.
--
-- @xs@, @ys@ and @cs@ are parallel lists: the i-th point is
-- @(xs !! i, ys !! i)@ and carries label @cs !! i@. The result has one entry
-- per distinct label, in order of the label's first occurrence.
--
-- The three lists are zipped, so if their lengths differ the surplus entries
-- are ignored (the same convention 'kmeans' uses when it zips them) instead
-- of failing with an out-of-range index. Each label is handled with a single
-- pass over the zipped points rather than repeated positional lookups.
getCentroids :: [Double] -> [Double] -> [Label] -> [Centroid]
getCentroids xs ys cs = map centroidOf labels
  where
    observations = zip3 xs ys cs

    -- Distinct labels, first occurrence first.
    labels = nub [label | (_, _, label) <- observations]

    centroidOf label = (average memberXs, average memberYs, label)
      where
        memberXs = [x | (x, _, l) <- observations, l == label]
        memberYs = [y | (_, y, l) <- observations, l == label]

    -- Every label comes from at least one observation, so vs is never empty.
    average vs = sum vs / genericLength vs

In [18]:
-- Centroids of the random initial labelling; entries follow the order in
-- which each label first appears in cs.
centroids = getCentroids xs ys cs

In [19]:
centroids

[(0.47516880932913436,0.4490679978682382,0),(3.2224910439671794e-2,-0.48190966293321347,2),(-0.18974120906594155,0.21616477611619353,1),(-0.26884506012046694,-0.1557723851027541,3)]

In [20]:
-- Take the first two centroids apart to check the distance formula by hand.
(ax, ay, a) = centroids !! 0

In [21]:
(bx, by, b) = centroids !! 1

In [22]:
-- Squared Euclidean distance between the two centroids (no square root taken).
(ax - bx)^2 + (ay - by)^2

1.0629187024747413

In [23]:

-- | Return the centroid closest to @observation@.
--
-- Closeness is the squared Euclidean distance on the (x, y) coordinates; the
-- label of the observation is ignored. When several centroids are equally
-- close, the one that comes first in the list wins.
--
-- The selection is a single 'minimumBy' pass, avoiding the earlier
-- minimum / index lookup / positional access chain and its partial functions.
-- Calling this with an empty centroid list is a programming error and is
-- reported as such.
findNearestCentroid :: [Centroid] -> LabeledPoint -> Centroid
findNearestCentroid [] _ = error "findNearestCentroid: empty centroid list"
findNearestCentroid centroids (px, py, _) =
  minimumBy (comparing distanceSqrd) centroids
  where
    distanceSqrd (cx, cy, _) = (px - cx) ^ 2 + (py - cy) ^ 2

In [24]:

-- | Perform one k-means reassignment step.
--
-- The centroids of the current labelling @cs@ are computed once with
-- 'getCentroids'; every point @(x, y)@ is then given the label of the
-- centroid that 'findNearestCentroid' selects for it. The result is the new
-- label list, in the same order as the points.
kmeans :: [Double] -> [Double] -> [Label] -> [Label]
kmeans xs ys cs = map (labelOf . nearest) (zip3 xs ys cs)
  where
    nearest = findNearestCentroid (getCentroids xs ys cs)
    labelOf (_, _, label) = label


In [25]:
-- Rebind g to a newly obtained generator before drawing new labels.
g <- newStdGen

In [26]:
-- New random initial labels in 0..3 for the 40 points.
cs = take 40 (randomRs (0,3) g)

In [27]:
-- A single reassignment step; in the echoed result one point in the last
-- block of ten still carries label 3, so this is not yet a fixed point.
kmeans xs ys cs

[1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,3,3,3,3,3,3,3,3,3,3,2,2,2,2,2,3,2,2,2,2]

In [28]:

-- | Run 'kmeans' repeatedly until the labelling stops changing.
--
-- Starting from the labels @cs@, each round reassigns every point to its
-- nearest centroid. The first labelling that a further round leaves
-- unchanged is returned.
kmeansClustering :: [Double] -> [Double] -> [Label] -> [Label]
kmeansClustering xs ys cs = converge cs
  where
    converge labels =
      let labels' = kmeans xs ys labels
       in if labels == labels'
            then labels
            else converge labels'

In [29]:
-- Iterate to a fixed point: each block of ten generated points ends up with
-- one label of its own.
kmeansClustering xs ys cs

[1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,3,3,3,3,3,3,3,3,3,3,2,2,2,2,2,2,2,2,2,2]

In [30]:

-- | Find the two centroids that lie closest to each other.
--
-- Every ordered pair of centroids is considered, including a centroid paired
-- with itself; pairs at squared distance 0 are discarded, which removes the
-- self-pairs (and any distinct centroids that coincide exactly). The result
-- is @(squared distance, label of first, label of second)@ for the nearest
-- remaining pair.
--
-- The comparison never reports equality, so among equally distant candidates
-- the one generated last is returned. Because each pair occurs in both
-- orders, this is what decides which of the two labels is listed first.
--
-- Fails (empty 'minimumBy') when no pair at non-zero distance exists, e.g.
-- for fewer than two centroids.
closestCentroids :: [Centroid] -> (Distance, Label, Label)
closestCentroids ps = minimumBy nearerFirst candidates
  where
    candidates =
      [ (d, a, b)
      | (ax, ay, a) <- ps
      , (bx, by, b) <- ps
      , let d = (ax - bx) ^ 2 + (ay - by) ^ 2
      , d > 0
      ]

    nearerFirst (d1, _, _) (d2, _, _)
      | d1 < d2   = LT
      | otherwise = GT

In [31]:
-- Give each of the 40 points its own label, so every "centroid" is simply the
-- point itself: the starting state for bottom-up (agglomerative) merging.
centroids = getCentroids xs ys [0..39]

In [32]:
centroids

[(1.125,0.929932588029362,0),(1.117563076580597,0.9478204564930616,1),(1.0802679406778615,1.1105992131293314,2),(1.1188993426511609,1.0091476748409527,3),(0.9854281936357906,1.110032197690718,4),(0.9553047388990387,1.0418890349651018,5),(0.875,1.0973957133194638,6),(1.041061502075034,0.875,7),(1.0731741918553028,0.9599394875642535,8),(1.0130642313304912,1.125,9),(0.75,-0.75,10),(1.0950833929537045,-1.1817389632890911,11),(0.9117002073725726,-1.184894002710651,12),(1.2487875201432241,-1.0423740803310515,13),(0.9577187819975364,-1.25,14),(0.8932566807081249,-0.894126940437945,15),(0.843422057529196,-1.0904718479378805,16),(0.7825468362584325,-0.8418637138341898,17),(1.25,-0.9931285429383583,18),(0.8884789249788061,-1.1704380205475438,19),(-0.7697660733741319,0.9321579638062009,20),(-1.375,0.6962967207226156,21),(-0.971083559170854,0.625,22),(-1.157394244228585,1.2931602038806727,23),(-1.0152348453348314,0.801221603705932,24),(-0.625,0.8671907283260685,25),(-0.8263335807567929,1.307486231

In [33]:
-- Nearest pair of singleton clusters: (squared distance, label, label).
closestCentroids centroids

(3.7528366812068293e-4,1,0)

In [34]:

-- | Merge the two clusters whose centroids are nearest to each other.
--
-- Centroids are computed from the labels currently attached to @points@, the
-- closest pair is looked up with 'closestCentroids', and every point carrying
-- the larger of the two labels is relabelled with the smaller one.
-- Coordinates are left untouched.
relabelClosestClusters :: [LabeledPoint] -> [LabeledPoint]
relabelClosestClusters points =
  case compare a b of
    LT -> map (mergeInto a b) points
    _  -> map (mergeInto b a) points
  where
    (_, a, b) = closestCentroids (getCentroids pxs pys labels)
    (pxs, pys, labels) = unzip3 points

    -- Rewrite label `gone` to `keep`; all other labels pass through.
    mergeInto keep gone (x, y, label) = (x, y, rename label)
      where
        rename l
          | l == gone = keep
          | otherwise = l

In [35]:
-- Centroid and LabeledPoint are the same tuple type, so the singleton
-- centroids can be passed as points. After one merge label 1 has become 0.
map (\(_,_,x) -> x) $ relabelClosestClusters centroids

[0,0,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39]

In [36]:

-- | Agglomerative (bottom-up) clustering down to @n@ clusters.
--
-- While @points@ carries more than @n@ distinct labels, the two clusters with
-- the nearest centroids are merged via 'relabelClosestClusters'. The points
-- are returned with their final labels; coordinates are unchanged.
--
-- The cluster count is checked before merging, so input that already has @n@
-- or fewer clusters (including an empty list) is returned as is. Merging
-- never goes below a single cluster, even for @n < 1@. The labels used for
-- the "nothing changed" check are taken from @points@ itself rather than
-- from an unrelated top-level binding.
hierarchicalCluster :: [LabeledPoint] -> Int -> [LabeledPoint]
hierarchicalCluster points n
  | clusterCount <= max 1 n = points
  | labels' == labels       = points'  -- merge made no progress; stop
  | otherwise               = hierarchicalCluster points' n
  where
    labelOf (_, _, label) = label
    labels       = map labelOf points
    clusterCount = length (nub labels)
    points'      = relabelClosestClusters points
    labels'      = map labelOf points'

In [37]:
-- Merge the 40 singleton clusters until only 4 clusters remain.
points = hierarchicalCluster centroids 4

In [38]:
-- Final labels: each block of ten points has collapsed onto the smallest
-- label in that block.
map (\(_, _, x) -> x) points

[0,0,0,0,0,0,0,0,0,0,10,10,10,10,10,10,10,10,10,10,20,20,20,20,20,20,20,20,20,20,30,30,30,30,30,30,30,30,30,30]