In [1]:
%matplotlib inline
import sys, os, pdb, warnings
sys.path.insert(0, './core/')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

np.set_printoptions(suppress=True, linewidth=120, precision=4)
pd.set_option('display.max_columns', 15)
pd.set_option('display.width', 1000)

In [2]:
from sklearn.cluster import KMeans
from models.ours.Gaussians import MultivariateGaussain, MixMG

class MixMGLearner:
    def __init__(self, n_components = 2, reg_covar = 1e-6, tol = 1e-3, max_iter = 100):
        self.n_components = n_components
        self.reg_covar = reg_covar
        self.tol = tol
        self.max_iter = max_iter
    
    def fit(self, train, weight = None):
        N, D = train.shape
        if weight is None:
            weight = np.ones( shape = (N, ) )
        
        # init the component weights
        self.w = np.ones( shape = (self.n_components, ) ) / self.n_components
        
        # use kmeans to find the initial center 
        clf = KMeans(n_clusters = self.n_components, init='k-means++', n_init = 10).fit(train)
        
        sub_models = []
        for i in range(self.n_components):
            mg = MultivariateGaussain()
            mg.mu = clf.cluster_centers_[i,:]
            mg.S = np.identity(D)
            sub_models.append(mg)
        
        self.mgs = sub_models
        
        # define the Q, Q[i,j] the probility that sample i falls into j component
        self.Q = np.ones(shape = (N, self.n_components)) / self.n_components
        
        # define the V, V[i,j] is the density of sample i under j component
        self.V = np.ones(shape = (N, self.n_components)) / self.n_components
        
        # need to update convergence criteria
        n_iter = 0
        converged = False
        while not converged:
            n_iter += 1
            
            # compute V for the e step
            for i in range(self.n_components):
                masses = self.mgs[i].mass(train, logmode = 1)
                # originally, it is power
                masses = masses * weight.flatten()
                self.V[:,i] = masses
            
            self._estep()
            self._mstep(train, weight)
            if n_iter >= self.max_iter:
                break
        return self
    
    def _estep(self):
        w = self.w.reshape(1, -1)
        self.Q = (self.V + np.log(w))
        # basically numertical stable softmax here
        self.Q -= np.max(self.Q, axis = 1, keepdims = True)
        self.Q = np.exp(self.Q)
        row_sum = self.Q.sum(axis = 1, keepdims = True)
        self.Q = self.Q / row_sum
        
    def _mstep(self, data, weight):
        # update w
        self.w = self.Q.mean(axis=0)
        
        # update mu 
        weight = weight.reshape(-1, 1)
        wQ = weight * self.Q
        
        Qcol_sum = wQ.sum(axis = 0)
        for i in range(self.n_components):
            self.mgs[i].mu = np.sum( wQ[:, i:i+1] * data, axis = 0) / Qcol_sum[i]
        
        # update cov matrix 
        for i in range(self.n_components):
            mu = self.mgs[i].mu.reshape(1, -1)
            mat = data - mu
            mat2 = mat.copy()
            mat2 = wQ[:, i:i+1] * mat2 / (self.w[i] * data.shape[0])
            S = mat.T @ mat2
            S += np.identity(mu.size) * self.reg_covar
            self.mgs[i].S = S
        
    def get_model(self):
        model = MixMG()
        model.W = self.w
        model.models = self.mgs
        return model


In [3]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.utils.extmath import row_norms
from sklearn.datasets._samples_generator import make_blobs
from timeit import default_timer as timer

X, _ = make_blobs(n_samples=4000, centers=5, cluster_std=2.0, n_features = 15)
md1 = MixMG().fit(X, n_comps = 5)
print(np.percentile(md1.mass(X, logmode = 1), [10,25,50,75,90]))

[-36.7454 -34.842  -32.8224 -31.1787 -29.9868]


In [4]:
lm = MixMGLearner(n_components = 5).fit(X)
md2 = lm.get_model()
print(np.percentile(md2.mass(X, logmode = 1), [10,25,50,75,90]))

[-36.7454 -34.842  -32.8224 -31.1787 -29.9868]


In [6]:
for i in range(5):
    print(md1.models[i].mu)
    print(md2.models[i].mu)

[-5.7519  0.0906  2.7573 -3.9051 -6.4964  8.8368 -3.9939  9.1159  8.6867 -2.9059 -8.6172 -3.16    8.9408  0.3486
  6.2774]
[-1.4998  8.1111  1.6714  8.7511 -0.6194 -4.39   -4.75   -2.1825  3.6725 -2.6952  9.1607  1.2819  5.6106  0.6879
  8.5418]
[-1.4998  8.1111  1.6714  8.7511 -0.6194 -4.39   -4.75   -2.1825  3.6725 -2.6952  9.1607  1.2819  5.6106  0.6879
  8.5418]
[-5.7519  0.0906  2.7573 -3.9051 -6.4964  8.8368 -3.9939  9.1159  8.6867 -2.9059 -8.6172 -3.16    8.9408  0.3486
  6.2774]
[-3.3451 -5.527   9.924   2.5989  0.5635  6.4402  0.5545  5.1935 -2.4956  7.7788  9.8388 -1.5097 -2.0615 -7.1917
  0.851 ]
[-3.3451 -5.527   9.924   2.5989  0.5635  6.4402  0.5545  5.1935 -2.4956  7.7788  9.8388 -1.5097 -2.0615 -7.1917
  0.851 ]
[ 6.4255 -4.8286 -0.4609 -5.0437 -5.5984  7.2656 -2.3088 -7.5799  8.8677  3.9372 -9.459   5.3013 -1.5698  4.9344
  7.4433]
[ 6.4255 -4.8286 -0.4609 -5.0437 -5.5984  7.2656 -2.3088 -7.5799  8.8677  3.9372 -9.459   5.3013 -1.5698  4.9344
  7.4433]
[ 1.3517  3.3626

In [7]:
# load all digits datasets
import loader
from functools import partial

mnist_dir = './data/digits/mnist'
ch74_dir = './data/digits/chars74k'
dida_dir = './data/digits/dida'
NUM_PER_CLASS = 30
DOWN_SAMPLE = True
N_MNIST = 25000

def visualize_imgs(img_array, rows = 3, cols = 8, selected = None):
    if DOWN_SAMPLE:
        process = lambda x:x.reshape(14,14)
    else:
        process = lambda x:x.reshape(28,28)
    if selected is None:
        assert(img_array.shape[0] >= rows * cols)
        selected = np.random.choice(img_array.shape[0], rows * cols, replace = False)
    else:
        assert(selected.size >= rows * cols)
        
    k = 1
    fid = plt.figure()
    for i in range(rows):
        for j in range(cols):
            plt.subplot(rows, cols, k)
            plt.imshow(process(img_array[selected[k-1]]), cmap='gray' )
            k += 1
            plt.axis('off')
    plt.show()
    
mnist_train, mnist_test = loader.read_mnist(mnist_dir, down_sample = DOWN_SAMPLE, with_label = False)
ch74 = loader.read_chars74k(ch74_dir, NUM_PER_CLASS)
ch74 = np.array(list(map(partial(loader.transform_to_mnist, down_sample = DOWN_SAMPLE, normalize = False) ,ch74)))
dida = loader.read_dida(dida_dir, NUM_PER_CLASS)
dida = np.array(list(map(partial(loader.transform_to_mnist, down_sample = DOWN_SAMPLE, normalize = True) ,dida)))

# sub-sample mnist for train and augment with reversed color 
np.random.seed(3)
sub_train = mnist_train[np.random.choice(np.arange(mnist_train.shape[0]), size = N_MNIST, replace = False)]
mnist_train = np.vstack([sub_train, 1-sub_train])
np.random.shuffle(mnist_train)

In [8]:
# do some test on mnist dataset
import time
data = mnist_train[0:5000, :]
st = time.time()
md1 = MixMG().fit(data, n_comps = 10)
ed = time.time()
print(ed - st)
print(np.percentile(md1.mass(data, logmode = 1), [10,25,50,75,90]))
st = time.time()
lm = MixMGLearner(n_components = 10).fit(data)
md2 = lm.get_model()
ed = time.time()
print(ed - st)
print(np.percentile(md2.mass(data, logmode = 1), [10,25,50,75,90]))

64.34634160995483
[456.0693 487.9154 519.8477 614.0098 652.6449]
100.84434461593628
[460.7315 495.122  530.2283 616.4325 642.0958]


In [9]:
# test if we have some weight on data
weight = np.random.rand(5000)
weight = weight / weight.sum() * 5000

st = time.time()
lm = MixMGLearner(n_components = 10).fit(data, weight)
md3 = lm.get_model()
ed = time.time()
print(ed - st)
print(np.percentile(md3.mass(data, logmode = 1), [10,25,50,75,90]))


99.47815942764282
[443.6595 489.8257 531.9295 610.68   652.9033]


In [10]:
print(np.mean(md1.mass(data, logmode = 1) ))
print(np.mean(md2.mass(data, logmode = 1) ))
print(np.mean(md3.mass(data, logmode = 1) ))


print(np.mean(md1.mass(data, logmode = 1) * weight))
print(np.mean(md2.mass(data, logmode = 1) * weight))
print(np.mean(md3.mass(data, logmode = 1) * weight))

print(np.percentile(md1.mass(data, logmode = 1), [10,25,50,75,90]))
print(np.percentile(md2.mass(data, logmode = 1), [10,25,50,75,90]))
print(np.percentile(md3.mass(data, logmode = 1), [10,25,50,75,90]))

print(np.percentile(md1.mass(data, logmode = 1) * weight, [10,25,50,75,90]))
print(np.percentile(md2.mass(data, logmode = 1) * weight, [10,25,50,75,90]))
print(np.percentile(md3.mass(data, logmode = 1) * weight, [10,25,50,75,90]))


540.7888940204662
545.0785485788617
529.2828656869956
541.0112098229312
545.031702986258
547.8828863405189
[456.0693 487.9154 519.8477 614.0098 652.6449]
[460.7315 495.122  530.2283 616.4325 642.0958]
[443.6595 489.8257 531.9295 610.68   652.9033]
[106.0825 265.5519 534.4411 795.7023 973.9185]
[105.0261 267.9611 539.3853 804.9909 977.7571]
[ 95.879  258.633  541.4898 820.5335 995.3232]
