In [16]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import math

from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score

from itertools import combinations

class PrototypicalNetwork:
    """Brute-force search for the best labelled batch of a nearest-prototype classifier.

    Workflow: ``GenerateData`` builds a synthetic 2-feature dataset,
    ``Clsutering`` fits k-means and rasterises its decision regions,
    ``GetBestBatch`` tries every ``n_classes``-sized subset of samples as the
    labelled batch, and ``PlotData`` draws the result.
    """

    def __init__(self) -> None:
        # Populated by GenerateData().
        self.X = None
        self.Y = None
        self.data = None
        self.n_samples = None
        self.n_classes = None
        # Populated by Clsutering(); PlotData() skips the overlays while unset.
        self.xx = None
        self.yy = None
        self.Z = None
        self.cluster_center = None
        # Populated by GetBestBatch().
        self.best_acc = None
        self.best_comb = None

    def GenerateData(self, n_samples, n_classes, random_state):
        """Create the synthetic dataset and store it on the instance.

        Args:
            n_samples: Number of points to generate.
            n_classes: Number of class labels.
            random_state: Seed forwarded to ``make_classification``.

        Sets ``self.X`` / ``self.Y`` and ``self.data``, a DataFrame with the
        columns ``x1``, ``x2`` (the two features) and ``Y`` (the label).
        """
        self.n_samples = n_samples
        self.n_classes = n_classes
        self.X, self.Y = make_classification(
            n_samples=self.n_samples, n_features=2, n_informative=2,
            n_redundant=0, n_clusters_per_class=1,
            n_classes=self.n_classes, random_state=random_state,
        )
        self.data = pd.DataFrame({"x1": self.X[:, 0], "x2": self.X[:, 1], "Y": self.Y})

    def PlotData(self):
        """Scatter-plot the data with optional clustering and best-batch overlays.

        The k-means centres and decision regions are drawn only if
        ``Clsutering`` has been run; the best batch is circled in red only if
        ``GetBestBatch`` has found one.

        Raises:
            RuntimeError: If ``GenerateData`` has not been called yet.
        """
        if self.data is None:
            raise RuntimeError("PlotData() requires GenerateData() to be called first")

        fig, ax = plt.subplots(figsize=(20, 10))
        sns.scatterplot(data=self.data, x="x1", y="x2", hue="Y", s=80, ax=ax)

        if self.cluster_center is not None and self.Z is not None:
            ax.scatter(self.cluster_center[:, 0], self.cluster_center[:, 1], marker="x", linewidths=2)
            # Plot the clustering boundary as a faint background image.
            ax.imshow(
                self.Z,
                interpolation="nearest",
                extent=(self.xx.min(), self.xx.max(), self.yy.min(), self.yy.max()),
                cmap=plt.cm.Pastel2,
                aspect="auto",
                origin="lower",
                alpha=0.25,
            )

        if self.best_comb:
            rows = list(self.best_comb)
            points = self.X
            # Dimension reduction if the number of features is larger than 2.
            if points.shape[1] > 2:
                points = PCA(n_components=2).fit_transform(points)
            # Highlight the best batch.
            ax.scatter(points[rows, 0], points[rows, 1], s=100, edgecolors="red", facecolor="none", linewidths=2)

    def Clsutering(self, random_state, h=0.02):
        """Fit k-means on ``self.X`` and rasterise its decision regions.

        Args:
            random_state: Seed forwarded to ``KMeans``.
            h: Step size of the evaluation mesh; smaller values give a finer
                boundary image at the cost of more predictions.

        Sets ``self.xx`` / ``self.yy`` (mesh coordinates), ``self.Z`` (cluster
        label per mesh point) and ``self.cluster_center``. The mesh is built
        from feature columns 0 and 1 only, so the data must have exactly two
        features for the prediction on the mesh to work.
        """
        kmeans = KMeans(n_clusters=self.n_classes, random_state=random_state).fit(self.X)

        # Mesh covering [x_min, x_max] x [y_min, y_max] with a margin of 1.
        x_min, x_max = self.X[:, 0].min() - 1, self.X[:, 0].max() + 1
        y_min, y_max = self.X[:, 1].min() - 1, self.X[:, 1].max() + 1
        self.xx, self.yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

        # Obtain the cluster label for each point in the mesh.
        Z = kmeans.predict(np.c_[self.xx.ravel(), self.yy.ravel()])
        self.Z = Z.reshape(self.xx.shape)

        # Get the clustering centers.
        self.cluster_center = kmeans.cluster_centers_

    # Correctly spelled alias; the original name is kept for existing callers.
    Clustering = Clsutering

    def PNPrediction(self, prototyps, x):
        """Return the label of the prototype closest to ``x``.

        Args:
            prototyps: DataFrame with one row per class, indexed by class
                label, holding that class's prototype coordinates.
            x: Feature vector of a single sample.

        Returns:
            The index label of the row with the smallest squared Euclidean
            distance to ``x`` (the first such row on ties).
        """
        dist = np.sum((prototyps.to_numpy() - x) ** 2, axis=1)
        return prototyps.index[np.argmin(dist)]

    def _predict_all(self, prototyps):
        """Vectorised ``PNPrediction`` over every row of ``self.X``."""
        centers = prototyps.to_numpy()
        labels = prototyps.index.to_numpy()
        # (n_samples, n_prototypes) matrix of squared distances.
        diff = centers[np.newaxis, :, :] - np.asarray(self.X)[:, np.newaxis, :]
        dist = np.sum(diff ** 2, axis=2)
        return labels[np.argmin(dist, axis=1)]

    def binomial_coef(self):
        """Return C(n_samples, n_classes), the number of candidate batches, as an int."""
        return math.comb(self.n_samples, self.n_classes)

    def GetBestBatch(self):
        """Exhaustively search for the labelled batch with the highest accuracy.

        Every combination of ``n_classes`` sample indices is used as the
        labelled batch: prototypes are the per-class means of the batch, and
        all samples are classified by their nearest prototype. A batch need
        not contain every class; absent classes simply have no prototype.

        Returns:
            The best accuracy found (``0.0`` if no combination exists or none
            scores above zero). Also stored in ``self.best_acc``; the winning
            index tuple is stored in ``self.best_comb`` (``None`` if no
            combination scored above zero).
        """
        y_true = np.asarray(self.Y)
        total = self.binomial_coef()  # loop-invariant, only used for progress output

        best_acc = 0.0
        best_comb = None
        for index, comb in enumerate(combinations(range(self.n_samples), self.n_classes)):
            # Calculate the prototypes by averaging the labelled data.
            # list(...) so pandas always treats the key as a list of row labels.
            prototyps = self.data.loc[list(comb), :].groupby("Y").mean()
            y_pred = self._predict_all(prototyps)
            acc = float(np.mean(y_pred == y_true))
            if acc > best_acc:
                best_acc = acc
                best_comb = comb
            if index % 2000 == 0:
                print(f"progress: {index}/{total}")

        self.best_acc = best_acc
        self.best_comb = best_comb
        return best_acc

In [None]:
# Settings for the demo run.
N_SAMPLES = 80
N_CLASSES = 4
DATA_SEED = 125
KMEANS_SEED = 1

# Build the dataset, fit the k-means overlay, then search all batches.
PN = PrototypicalNetwork()
PN.GenerateData(
    n_samples=N_SAMPLES,
    n_classes=N_CLASSES,
    random_state=DATA_SEED,
)
PN.Clsutering(random_state=KMEANS_SEED)

best_acc = PN.GetBestBatch()
print(f"Best Accuracy: {best_acc}")

# Draw the data with the clustering regions and the best batch highlighted.
PN.PlotData()