In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.nn import Linear
from sklearn.cluster import KMeans
"""
原文链接：IDEC,https://www.ijcai.org/proceedings/2017/0243.pdf
Pytorch版复现代码链接：https://github.com/dawnranger/IDEC-pytorch
"""

In [2]:
import os

# Default to GPU index 1, but keep any CUDA_VISIBLE_DEVICES value that is
# already set in the environment instead of overwriting it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
# Use CUDA when a device is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
def RI_metric(pred, target):
    """
    Rand index between a predicted labeling and a reference labeling.

    The index is the fraction of sample pairs on which the two labelings
    agree: pairs placed together in both (TP) plus pairs separated in both
    (TN), divided by the number of pairs n * (n - 1) / 2.

    Pair counts are derived from label frequencies rather than by visiting
    every pair, so the cost is O(n log n) instead of O(n^2).

    # Arguments
        pred: predicted labels, 1-D sequence of length n
        target: reference labels, 1-D sequence of length n

    # Return
        Rand index as a float in [0, 1]; 1.0 when there are fewer than two
        samples, because no pair exists that could disagree

    # Raises
        ValueError: if `pred` and `target` have different lengths
    """
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()
    n = target.shape[0]
    if pred.shape[0] != n:
        raise ValueError(
            "pred and target must have the same length, got {} and {}".format(
                pred.shape[0], n))
    if n < 2:
        # No pairs to compare: treat as perfect agreement instead of dividing by zero.
        return 1.0

    def count_pairs(sizes):
        # Number of unordered pairs inside groups of the given sizes.
        sizes = sizes.astype(np.int64)
        return int((sizes * (sizes - 1) // 2).sum())

    _, pred_ids, pred_sizes = np.unique(
        pred, return_inverse=True, return_counts=True)
    _, target_ids, target_sizes = np.unique(
        target, return_inverse=True, return_counts=True)

    # Encode each (target, pred) label combination as one integer so that the
    # contingency-table cell sizes can be counted without a dense table.
    joint_ids = target_ids.ravel().astype(np.int64) * pred_sizes.size + pred_ids.ravel()
    _, joint_sizes = np.unique(joint_ids, return_counts=True)

    total_pairs = n * (n - 1) // 2
    # Pairs that share a label in both labelings.
    TP = count_pairs(joint_sizes)
    # Inclusion-exclusion: remove pairs that are together in either labeling.
    TN = total_pairs - count_pairs(target_sizes) - count_pairs(pred_sizes) + TP

    return (TP + TN) / (n * (n - 1) / 2)


def cluster_acc(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one matching between
    predicted cluster ids and true labels. Requires scipy.

    # Arguments
        y_true: true labels, array-like of non-negative integers with
            `n_samples` entries
        y_pred: predicted labels, array-like of non-negative integers with
            `n_samples` entries

    # Return
        accuracy, in [0,1]
    """
    # Cast both label arrays so either can be used as an index below.
    y_true = np.asarray(y_true).ravel().astype(np.int64)
    y_pred = np.asarray(y_pred).ravel().astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    # w[p, t] = number of samples predicted as cluster p whose true label is t.
    w = np.zeros((D, D), dtype=np.int64)
    # np.add.at accumulates repeated (p, t) index pairs.
    np.add.at(w, (y_pred, y_true), 1)
    from scipy.optimize import linear_sum_assignment
    # The solver minimises cost, so turn counts into costs to maximise matches.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size


def AccRIkeans(data, num_class, real_label, random_state=None) -> tuple:
    """
    Cluster `data` with k-means and score the clusters against `real_label`.

    # Arguments
        data: numpy array of samples to cluster; it is copied, and NaN
            entries in the copy are replaced by 0 before fitting
        num_class: number of k-means clusters
        real_label: ground-truth labels used for scoring
        random_state: forwarded to `KMeans`; the default `None` leaves the
            initialisation unseeded, pass an int for repeatable results

    # Return
        tuple `(RI, accuracy)`, each rounded to 4 decimals
    """
    # Work on a copy so the caller's array is not modified.
    data = data.copy()
    data[np.isnan(data)] = 0
    kmeans = KMeans(n_clusters=num_class, n_init=20, random_state=random_state)
    kmeans.fit(data)
    label_bank = kmeans.labels_
    RI_kmeans = RI_metric(pred=label_bank, target=real_label)
    acc_kmeans = cluster_acc(y_true=real_label, y_pred=label_bank)
    return round(RI_kmeans, 4), round(acc_kmeans, 4)

In [4]:
def next_batch(samples_num, batch = 256):
    """
    Yield consecutive index arrays that together cover `range(samples_num)`.

    Every array holds `batch` indices, except possibly the last one, which
    holds whatever is left over when `samples_num` is not a multiple of
    `batch`.
    """
    full_batches = int(samples_num / batch)
    covered = full_batches * batch

    # Full-size batches first.
    for k in range(full_batches):
        start = k * batch
        yield np.arange(start, start + batch)

    # Then the shorter tail, if any samples remain.
    remainder = samples_num - covered
    if remainder != 0:
        yield np.arange(samples_num - remainder, samples_num)

In [5]:
class AE(nn.Module):
    """
    Fully connected autoencoder.

    The encoder applies three ReLU-activated linear layers followed by a
    linear projection to the `n_z`-dimensional code; the decoder applies
    three ReLU-activated linear layers followed by a linear projection back
    to `n_input` dimensions.
    """

    def __init__(self, n_enc_1, n_enc_2, n_enc_3, n_dec_1, n_dec_2, n_dec_3,
                 n_input, n_z):
        super().__init__()

        # Encoder stack: n_input -> n_enc_1 -> n_enc_2 -> n_enc_3
        self.enc_1 = nn.Linear(in_features=n_input, out_features=n_enc_1)
        self.enc_2 = nn.Linear(in_features=n_enc_1, out_features=n_enc_2)
        self.enc_3 = nn.Linear(in_features=n_enc_2, out_features=n_enc_3)

        # Bottleneck projection (no activation is applied to its output).
        self.z_layer = nn.Linear(in_features=n_enc_3, out_features=n_z)

        # Decoder stack: n_z -> n_dec_1 -> n_dec_2 -> n_dec_3
        self.dec_1 = nn.Linear(in_features=n_z, out_features=n_dec_1)
        self.dec_2 = nn.Linear(in_features=n_dec_1, out_features=n_dec_2)
        self.dec_3 = nn.Linear(in_features=n_dec_2, out_features=n_dec_3)

        # Reconstruction layer (no activation is applied to its output).
        self.x_bar_layer = nn.Linear(in_features=n_dec_3, out_features=n_input)

    def forward(self, x):
        """Return `(x_bar, z)`: the reconstruction of `x` and its latent code."""
        hidden = x
        for layer in (self.enc_1, self.enc_2, self.enc_3):
            hidden = F.relu(layer(hidden))
        z = self.z_layer(hidden)

        hidden = z
        for layer in (self.dec_1, self.dec_2, self.dec_3):
            hidden = F.relu(layer(hidden))
        x_bar = self.x_bar_layer(hidden)

        return x_bar, z

In [6]:
def model_ae_cluster(X, Y, num_class=10, epochs=201, lr=0.001, batch=256,
                     eval_every=10):
    """
    Train an `AE` autoencoder on `X` with an MSE reconstruction loss and
    report how well k-means on the latent codes recovers the labels `Y`.

    # Arguments
        X: 2-D numpy array of samples, one flattened sample per row
        Y: numpy array of integer labels, one per row of `X`
        num_class: number of classes / k-means clusters
        epochs: number of passes over `X`
        lr: Adam learning rate
        batch: mini-batch size passed to `next_batch`
        eval_every: run the k-means evaluation every this many epochs

    # Return
        the trained `AE` model, on the module-level `device`
    """
    print("Start: ", X.shape, Y.shape)
    for i in range(num_class):
        temp = np.where(Y == i)[0]
        print("The number of ", i, " = ", len(temp))
    # Take the input width from the data rather than assuming 28 * 28 images.
    feature_dim = X.shape[1]
    model = AE(
        n_enc_1=500,
        n_enc_2=500,
        n_enc_3=2000,
        n_dec_1=2000,
        n_dec_2=500,
        n_dec_3=500,
        n_input=feature_dim,
        n_z=10)
    print(model)
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    max_acc = 0.0
    max_ri = 0.0

    for epoch in range(epochs):
        total_loss = 0.0
        count_batch = 0
        representation = []
        for idxs in next_batch(X.shape[0], batch):
            x1 = torch.tensor(X[idxs]).to(device)
            optimizer.zero_grad()
            x_bar, z = model(x1)
            loss = F.mse_loss(x_bar, x1)
            total_loss = total_loss + loss.item()
            loss.backward()
            optimizer.step()
            count_batch = count_batch + 1
            # Keep the latent codes as plain arrays, detached from autograd.
            representation.append(z.detach().cpu().numpy())

        print("epoch {} loss={:.4f}".format(epoch, total_loss / count_batch))
        if epoch % eval_every == 0:
            # Codes were collected batch by batch while the weights were
            # still being updated within this epoch.
            ae_representation = np.concatenate(representation)
            representation_RI, representation_Acc = AccRIkeans(
                data=ae_representation, num_class=num_class, real_label=Y)
            max_acc = max(representation_Acc, max_acc)
            max_ri = max(representation_RI, max_ri)
            print("epoch = ", epoch,
                  ", representation RI_kmeans = ", representation_RI, ", representation_Acc_kmeans = ", representation_Acc)

    print("*"*30)
    print("max_acc = ", max_acc)
    print("max_ri = ", max_ri)
    print("End training of AE.")
    # Inference only: no autograd graph is needed for the final encoding.
    with torch.no_grad():
        temp_X = torch.tensor(X).to(device)
        _, hidden_x = model(temp_X)
    hidden_RI, hidden_Acc = AccRIkeans(data=hidden_x.cpu().numpy(), num_class=num_class, real_label=Y)
    print("hidden_RI = ", hidden_RI, ", hidden_Acc = ", hidden_Acc)
    # Baseline: k-means directly on the raw input features.
    Row_RI, Row_Acc = AccRIkeans(data=X, num_class=num_class, real_label=Y)
    print("Row_RI = ", Row_RI, ", Row_Acc = ", Row_Acc)
    return model

In [7]:
def load_mnist(path='./dataset/mnist.npz'):
    """
    Load an MNIST `.npz` archive and return all samples as flat float vectors.

    # Arguments
        path: location of an archive holding the arrays `x_train`,
            `y_train`, `x_test` and `y_test`

    # Return
        x: float32 array of shape `(n_samples, n_features)`, train samples
            followed by test samples, each flattened and divided by 255
        y: int32 array of the matching labels
    """
    # The context manager closes the archive even when a key lookup fails.
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    x = np.concatenate((x_train, x_test))
    y = np.concatenate((y_train, y_test)).astype(np.int32)
    # Flatten every sample to one row before scaling.
    x = x.reshape((x.shape[0], -1)).astype(np.float32)
    x = np.divide(x, 255.)
    print('MNIST samples', x.shape)
    return x, y

In [8]:
# Load the MNIST samples and labels from the default archive path and show
# the resulting array shapes.
x, y = load_mnist()
print(x.shape, y.shape)

MNIST samples (70000, 784)
(70000, 784) (70000,)


In [9]:
# Train an autoencoder on every loaded sample.
model_all = model_ae_cluster(X=x, Y=y)

Start:  (70000, 784) (70000,)
The number of  0  =  6903
The number of  1  =  7877
The number of  2  =  6990
The number of  3  =  7141
The number of  4  =  6824
The number of  5  =  6313
The number of  6  =  6876
The number of  7  =  7293
The number of  8  =  6825
The number of  9  =  6958
AE(
  (enc_1): Linear(in_features=784, out_features=500, bias=True)
  (enc_2): Linear(in_features=500, out_features=500, bias=True)
  (enc_3): Linear(in_features=500, out_features=2000, bias=True)
  (z_layer): Linear(in_features=2000, out_features=10, bias=True)
  (dec_1): Linear(in_features=10, out_features=2000, bias=True)
  (dec_2): Linear(in_features=2000, out_features=500, bias=True)
  (dec_3): Linear(in_features=500, out_features=500, bias=True)
  (x_bar_layer): Linear(in_features=500, out_features=784, bias=True)
)
epoch 0 loss=0.0561
epoch =  0 , representation RI_kmeans =  0.7949 , representation_Acc_kmeans =  0.2782
epoch 1 loss=0.0329
epoch 2 loss=0.0253
epoch 3 loss=0.0222
epoch 4 loss=0.0

In [10]:
# Score the model trained on all samples using the first 6000 samples.
# Inference only, so autograd is switched off for the forward pass.
with torch.no_grad():
    temp_X = torch.tensor(x[:6000]).to(device)
    _, hidden_x = model_all(temp_X)
hidden_RI, hidden_Acc = AccRIkeans(data=hidden_x.cpu().numpy(), num_class=10, real_label=y[:6000])
print("hidden_RI = ", hidden_RI, ", hidden_Acc = ", hidden_Acc)

hidden_RI =  0.9327 , hidden_Acc =  0.7843


In [11]:
# Train an autoencoder on the first 30000 samples only.
model_half = model_ae_cluster(X=x[:30000], Y=y[:30000])

Start:  (30000, 784) (30000,)
The number of  0  =  2961
The number of  1  =  3423
The number of  2  =  2948
The number of  3  =  3073
The number of  4  =  2926
The number of  5  =  2709
The number of  6  =  2975
The number of  7  =  3107
The number of  8  =  2875
The number of  9  =  3003
AE(
  (enc_1): Linear(in_features=784, out_features=500, bias=True)
  (enc_2): Linear(in_features=500, out_features=500, bias=True)
  (enc_3): Linear(in_features=500, out_features=2000, bias=True)
  (z_layer): Linear(in_features=2000, out_features=10, bias=True)
  (dec_1): Linear(in_features=10, out_features=2000, bias=True)
  (dec_2): Linear(in_features=2000, out_features=500, bias=True)
  (dec_3): Linear(in_features=500, out_features=500, bias=True)
  (x_bar_layer): Linear(in_features=500, out_features=784, bias=True)
)
epoch 0 loss=0.0643
epoch =  0 , representation RI_kmeans =  0.7787 , representation_Acc_kmeans =  0.1734
epoch 1 loss=0.0516
epoch 2 loss=0.0390
epoch 3 loss=0.0314
epoch 4 loss=0.0

In [12]:
# Score the model trained on 30000 samples using the first 6000 samples.
# Inference only, so autograd is switched off for the forward pass.
with torch.no_grad():
    temp_X = torch.tensor(x[:6000]).to(device)
    _, hidden_x = model_half(temp_X)
hidden_RI, hidden_Acc = AccRIkeans(data=hidden_x.cpu().numpy(), num_class=10, real_label=y[:6000])
print("hidden_RI = ", hidden_RI, ", hidden_Acc = ", hidden_Acc)

hidden_RI =  0.8918 , hidden_Acc =  0.6107


In [13]:
# Train an autoencoder on the first 12000 samples only.
model_part = model_ae_cluster(X=x[:12000], Y=y[:12000])

Start:  (12000, 784) (12000,)
The number of  0  =  1206
The number of  1  =  1351
The number of  2  =  1176
The number of  3  =  1228
The number of  4  =  1184
The number of  5  =  1048
The number of  6  =  1208
The number of  7  =  1279
The number of  8  =  1127
The number of  9  =  1193
AE(
  (enc_1): Linear(in_features=784, out_features=500, bias=True)
  (enc_2): Linear(in_features=500, out_features=500, bias=True)
  (enc_3): Linear(in_features=500, out_features=2000, bias=True)
  (z_layer): Linear(in_features=2000, out_features=10, bias=True)
  (dec_1): Linear(in_features=10, out_features=2000, bias=True)
  (dec_2): Linear(in_features=2000, out_features=500, bias=True)
  (dec_3): Linear(in_features=500, out_features=500, bias=True)
  (x_bar_layer): Linear(in_features=500, out_features=784, bias=True)
)
epoch 0 loss=0.0707
epoch =  0 , representation RI_kmeans =  0.7138 , representation_Acc_kmeans =  0.155
epoch 1 loss=0.0617
epoch 2 loss=0.0568
epoch 3 loss=0.0510
epoch 4 loss=0.04

In [14]:
# Score the model trained on 12000 samples using the first 6000 samples.
# Inference only, so autograd is switched off for the forward pass.
with torch.no_grad():
    temp_X = torch.tensor(x[:6000]).to(device)
    _, hidden_x = model_part(temp_X)
hidden_RI, hidden_Acc = AccRIkeans(data=hidden_x.cpu().numpy(), num_class=10, real_label=y[:6000])
print("hidden_RI = ", hidden_RI, ", hidden_Acc = ", hidden_Acc)

hidden_RI =  0.8659 , hidden_Acc =  0.5182
