# In [None]:
import torch
import numpy as np
import skfuzzy as fuzz
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from scipy.spatial.distance import cdist

# Load pre-extracted MNIST features and labels.
# NOTE(security): torch.load unpickles its input; only load these files from a
# trusted source (torch>=1.13 also accepts weights_only=True to restrict this).
# map_location='cpu' lets tensors that were saved from a GPU load on a CPU-only
# machine, and .numpy() only works on CPU tensors anyway.
features_path = 'imbalanced_train_features.pt'
labels_path = 'imbalanced_train_labels.pt'
features = torch.load(features_path, map_location='cpu').numpy()
labels = torch.load(labels_path, map_location='cpu').numpy()

# Define majority and minority classes.
# The majority class is fixed here by assumption, not derived from label counts.
majority_class = 0
minority_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9]
print("Majority class:", majority_class)
print("Minority classes:", minority_indices)

# Step 2: Perform Fuzzy C-Means Clustering.
# skfuzzy takes data as (n_features, n_samples), hence the transpose; the
# returned centres `cntr` are (n_clusters, n_features). The third argument (2)
# is the fuzzifier exponent m.
n_clusters = 10
cntr, u, u0, d, jm, p, fpc = fuzz.cluster.cmeans(features.T, n_clusters, 2, error=0.005, maxiter=1000, init=None)

# Print centroids
print("Centroids:\n", cntr)

# Plot the first two feature dimensions of the data together with the centroids.
plt.scatter(features[:, 0], features[:, 1], c='blue', marker='o', alpha=0.5)
plt.scatter(cntr[:, 0], cntr[:, 1], c='red', marker='x')
plt.title('Fuzzy C-Means Clustering')
plt.show()

# Step 3: Responsibility matrix is initialized (done during FCM)

# Step 4: Initialize GMM Parameters
def initialize_gmm_parameters(data, cntr, n_clusters):
    """Fit a Gaussian mixture whose component means are seeded with ``cntr``.

    Args:
        data: samples to fit, one row per sample.
        cntr: initial component means, one row per component (here the
            fuzzy c-means centres).
        n_clusters: number of mixture components.

    Returns:
        The fitted ``GaussianMixture`` instance (``fit`` returns the estimator).
    """
    return GaussianMixture(n_components=n_clusters, means_init=cntr).fit(data)

gmm = initialize_gmm_parameters(features, cntr, n_clusters)
mu, sigma = gmm.means_, gmm.covariances_

# These are the parameters after EM fitting, which was seeded from the FCM centres.
print("Initial GMM parameters:")
for _title, _array in (("Means:\n", mu), ("Covariances:\n", sigma)):
    print(_title, _array)

# Step 5: means, covariances and mixing weights are updated inside GaussianMixture.fit (EM).

# Step 6: the minority "components" used below are the minority class labels
# defined earlier; they are not selected by mixing coefficient.
print("Minority components indices:", minority_indices)

# Step 7: Calculate m_comp and apply GMM for each minority cluster
def apply_gmm_to_minority(data, majority_data, minority_indices, labels):
    """Fit one Gaussian mixture per minority class.

    Each class gets ``round(m_comp)`` components, where
    ``m_comp = len(majority_data) / n_class`` is its imbalance ratio. The
    count is clipped to ``[1, n_class]`` so a mixture never has more
    components than samples.

    Args:
        data: all samples, one row per sample, aligned with ``labels``.
        majority_data: majority-class samples (only their count is used).
        minority_indices: class labels to model.
        labels: class label of each row of ``data``.

    Returns:
        dict mapping class label -> (means, covariances, m_comp). Classes with
        fewer than two samples are left out.
    """
    n_majority = len(majority_data)
    fitted = {}
    for cls in minority_indices:
        cls_samples = data[labels == cls]
        n_cls = len(cls_samples)
        if n_cls >= 2:  # skip classes with fewer than two samples
            ratio = n_majority / n_cls
            n_components = min(n_cls, max(1, int(np.round(ratio))))
            mixture = GaussianMixture(n_components=n_components, init_params='kmeans')
            mixture.fit(cls_samples)
            fitted[cls] = (mixture.means_, mixture.covariances_, ratio)
    return fitted

# Boolean row mask of the majority class, and the rows it selects.
majority_class_indices = (labels == majority_class)
majority_data = features[majority_class_indices]
# Step 7: one Gaussian mixture per minority class.
gmm_params = apply_gmm_to_minority(
    data=features,
    majority_data=majority_data,
    minority_indices=minority_indices,
    labels=labels,
)

# Step 8: Find k nearest neighbors among means for each element in minority classes
def find_k_nearest_neighbours(data, means, k):
    """Return, for each row of ``data``, its ``k`` closest rows of ``means``.

    Distances are Euclidean (the ``cdist`` default). Neighbours are ordered
    nearest first. If ``means`` has fewer than ``k`` rows, all of them are
    returned.

    Returns:
        Array of shape (len(data), min(k, len(means)), n_features).
    """
    nearest_first = cdist(data, means).argsort(axis=1)
    return means[nearest_first[:, :k]]

# Number of nearest GMM component means blended with each minority sample in
# Steps 8-9. Classes whose m_comp is below k take the direct-GMM-sampling branch instead.
k = 3

# Step 9: Calculate mux and sigmax, and generate new elements
def generate_new_elements(minority_data, k_nearest_means, sigma, k, n_c):
    """Draw synthetic samples around each minority sample (Step 9).

    For sample ``x_i`` with neighbour means ``m_i1 .. m_ik``::

        mux_i  = (x_i + sum_j m_ij) / (k + 1)
        sigmax = (sum(sigma, axis=0) + cov(minority_data)) / (k + 1)

    and ``n_c`` points are drawn from N(mux_i, sigmax) for every sample.

    Args:
        minority_data: (n_samples, n_features) samples of one class. At least
            two rows are needed for a meaningful empirical covariance.
        k_nearest_means: per-sample neighbour means of shape
            (n_samples, n_neighbours, n_features), as returned by
            ``find_k_nearest_neighbours``. A 2-D (n_neighbours, n_features)
            array is treated as one neighbour set shared by every sample.
        sigma: stack of component covariances, summed over axis 0.
        k: neighbour count used in the averaging denominator.
        n_c: number of points drawn per minority sample.

    Returns:
        (n_samples * n_c, n_features) array of generated points, or an empty
        (0, n_features) array when ``minority_data`` has no rows.
    """
    minority_data = np.asarray(minority_data)
    neighbour_means = np.asarray(k_nearest_means)
    if len(minority_data) == 0:
        return np.empty((0, neighbour_means.shape[-1]))
    if neighbour_means.ndim == 2:
        # Shared neighbour set: reuse it for every sample.
        neighbour_means = [neighbour_means] * len(minority_data)

    # sigmax does not depend on x, so compute it once rather than per sample.
    sigmax = (np.sum(sigma, axis=0) + np.cov(minority_data.T)) / (k + 1)

    new_elements = []
    # Average each x only with ITS OWN neighbour means (row i), not with the
    # neighbours of every sample in the class.
    for x, neighbours in zip(minority_data, neighbour_means):
        mux = (np.sum(neighbours, axis=0) + x) / (k + 1)
        new_elements.append(np.random.multivariate_normal(mux, sigmax, n_c))
    return np.vstack(new_elements)

# Per-class synthetic samples and their class labels, collected for stacking.
new_elements = []
new_labels = []

for index in minority_indices:
    minority_data = features[labels == index]
    if len(minority_data) < 2:  # Check if there are at least two samples
        continue
    # Synthetic samples needed to bring this class up to the majority size.
    n_needed = len(majority_data) - len(minority_data)
    means, covariances, m_comp = gmm_params[index]
    if m_comp >= k:
        # Blend each sample with its k nearest component means.
        k_nearest_means = find_k_nearest_neighbours(minority_data, means, k)
        n_c = max(1, n_needed // len(minority_data))  # draws per minority sample
        new_elements_class = generate_new_elements(minority_data, k_nearest_means, covariances, k, n_c)
    else:
        # Fallback: sample directly from a small GMM fitted to this class.
        n_comp = min(len(minority_data), max(1, len(majority_data) // len(minority_data)))  # Ensure n_comp is at least 1 and <= number of samples
        gmm = GaussianMixture(n_components=n_comp, init_params='kmeans')
        gmm.fit(minority_data)
        means, covariances = gmm.means_, gmm.covariances_
        # Split the same balancing target evenly across components, so this
        # branch adds about as many points as the k-NN branch instead of a
        # full majority-sized batch.
        n_c = max(1, n_needed // n_comp)  # draws per component
        new_elements_class = np.vstack([
            np.random.multivariate_normal(mean, cov, n_c)
            for mean, cov in zip(means, covariances)
        ])

    new_elements.append(new_elements_class)
    # Keep the class label of every generated point so the data is usable for training.
    new_labels.append(np.full(len(new_elements_class), index, dtype=labels.dtype))

# np.vstack([]) raises, so handle the case where no class qualified.
if new_elements:
    new_elements = np.vstack(new_elements)
    new_labels = np.concatenate(new_labels)
else:
    new_elements = np.empty((0, features.shape[1]), dtype=features.dtype)
    new_labels = np.empty(0, dtype=labels.dtype)

# Plot the original and generated points (first two feature dimensions only).
plt.scatter(features[:, 0], features[:, 1], c='blue', marker='o', alpha=0.5, label='Original Data')
plt.scatter(new_elements[:, 0], new_elements[:, 1], c='green', marker='s', alpha=0.5, label='Generated Data')
plt.title('Original and Generated Data Points')
plt.legend()
plt.show()
