# Gaussian Mixture Model (GMM)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import cross_validate
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import KFold

In [None]:
import pandas as pd

# Input table; the file is read with the Korean EUC-KR encoding.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
DATA_PATH = 'C:/Users/user/Desktop/project/data_pca.csv'
csv = pd.read_csv(DATA_PATH, encoding='euc-kr')

# Working DataFrame used by the following cells.
df = pd.DataFrame(csv)

## 학습/테스트 데이터 분할

In [None]:
# Feature matrix: every column except the first (used as point names in the
# plot cell) and the last.
X = df.iloc[:, 1:-1]
# Target labels. The last column is excluded from the features above, so it is
# presumably the class label.
# NOTE(review): confirm the last column of data_pca.csv really holds the classes.
Y = df.iloc[:, -1]

indices = KFold(n_splits = 5)
for train_index, test_index in indices.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    # Labels must come from Y, not X: slicing X here made Y_train a copy of the
    # features, which broke the class count and the per-class means downstream.
    Y_train, Y_test = Y.iloc[train_index], Y.iloc[test_index]
# NOTE: every iteration overwrites the split variables, so after the loop only
# the last of the five folds remains in X_train / X_test / Y_train / Y_test.

## 클래스 수 추출

In [None]:
# Number of distinct label values in the training fold (np.unique returns a
# flat, de-duplicated array, so its size is the class count).
num_classes = np.unique(Y_train).size

## 가우시안 혼합 모델

In [None]:
# Fit a full-covariance Gaussian mixture and label every row with its component.
#
# Initial component means are the per-class means of the training fold.
# GaussianMixture.fit() re-initialises `means_`, so assigning it before fit()
# (as this cell used to) had no effect. Initial means have to be passed through
# the `means_init` constructor argument. That argument must have one row per
# component, so n_components follows the number of classes.
# NOTE(review): assumes Y_train holds 1-D class labels aligned with X_train's index.
class_labels = np.unique(Y_train)
class_means = np.array([X_train[Y_train == label].mean(axis = 0) for label in class_labels])

# Fit on the feature columns only. 'gmm_cluster' is written back into X below,
# so re-running this cell would otherwise feed the previous cluster ids to the model.
features = X.drop(columns = 'gmm_cluster', errors = 'ignore')

classifier = GaussianMixture(
    n_components = len(class_labels),
    covariance_type = "full",
    reg_covar = 1e-5,
    means_init = class_means,
)
classifier.fit(features)
gmm_cluster_labels = classifier.predict(features)
probs = classifier.predict_proba(features)
X['gmm_cluster'] = gmm_cluster_labels

## 시각화 (PLT)

In [None]:
import platform

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib import font_manager, rc

# Select a font family able to render the Korean text in the plots:
# Malgun Gothic on Windows (its family name is read from the .ttf file),
# AppleGothic on every other platform.
# NOTE(review): AppleGothic is a macOS font; on Linux it is likely missing.
if platform.system() != 'Windows':
    korean_family = 'AppleGothic'
else:
    font_name = font_manager.FontProperties(fname="c:/Windows/Fonts/malgun.ttf").get_name()
    korean_family = font_name
rc('font', family=korean_family)

# Render negative tick labels with an ASCII hyphen instead of the Unicode minus
# sign (presumably because the chosen font may not provide that glyph).
matplotlib.rcParams['axes.unicode_minus'] = False

In [None]:
from sklearn.decomposition import PCA

# Project the features onto their first two principal components and draw
# them as a scatter plot coloured by GMM cluster.
# The GMM cell stores its result in X['gmm_cluster']. Drop that column before
# PCA; otherwise the cluster ids are treated as an extra feature and distort
# the projection.
plot_features = X.drop(columns = 'gmm_cluster', errors = 'ignore')

pca = PCA(n_components = 2)
pca_transformed = pca.fit_transform(plot_features)
x = pca_transformed[:, 0]
y = pca_transformed[:, 1]
# The first column of df supplies the text drawn next to each point.
annotations = df.iloc[:, 0]

plt.figure(figsize=(15, 10))
plt.grid()
plt.scatter(x, y, c = gmm_cluster_labels, s = 120, cmap = 'viridis')
for label, px, py in zip(annotations, x, y):
    plt.annotate(label, (px, py), fontsize=20)
plt.xlabel('Dim 1')
plt.ylabel('Dim 2')
plt.title('GMM', fontsize = 20)
plt.show()