In [None]:
class LDA:
  """Fisher Linear Discriminant Analysis (LDA).

  Projects samples onto the eigenvectors of Sw^{-1} Sb (within-class scatter
  inverse times between-class scatter) and classifies samples by the nearest
  class mean in that projected space.

  Args:
    n_dimensions: number of linear discriminants to keep. When None, ``fit``
      derives the number from its ``variance_threshold`` argument, or keeps
      ``min(n_features, n_classes - 1)`` discriminants if no threshold is given.

  Attributes set by ``fit``:
    n_dimensions_: number of discriminants actually kept.
    linear_discriminants_: array of shape (n_dimensions_, n_features); each
      row is one discriminant direction, ordered by decreasing eigenvalue.
    explained_variances_: cumulative fraction of the eigenvalue sum captured
      by the first 1, 2, ... discriminants.
    classes: sorted unique labels seen in ``Y``.
    class_means_: array of shape (n_classes, n_dimensions_) with each class's
      mean in the projected space; ``predict`` compares samples against it.
  """

  def __init__(self, n_dimensions=None):
    self.n_dimensions_ = n_dimensions
    # fit() overwrites n_dimensions_ with the value it derives; keep the
    # caller's request separately so a later refit derives it again.
    self._requested_dimensions = n_dimensions
    self.linear_discriminants_ = []
    self.explained_variances_ = None
    self.classes = None
    self.class_means_ = None

  def scatter_matrix_within_class(self, X, Y):
    """Return the within-class scatter matrix Sw.

    Sw = sum over classes c of (X_c - mean_c)^T (X_c - mean_c), where X_c are
    the rows of X whose label in Y equals c. X is expected to be 2-D
    (samples x features).
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y)
    n_features = X.shape[1]

    Sw = np.zeros((n_features, n_features))
    for c in np.unique(Y):
      X_c = X[Y == c]
      centered = X_c - X_c.mean(axis=0)
      Sw += np.dot(centered.T, centered)
    return Sw

  def scatter_matrix_between_class(self, X, Y):
    """Return the between-class scatter matrix Sb and record the class labels.

    Sb = sum over classes c of n_c (mean_c - mean)(mean_c - mean)^T, where n_c
    is the number of samples in class c and mean is the overall mean of X.
    Side effect: sets ``self.classes`` to the sorted unique labels of Y.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y)
    self.classes = np.unique(Y)

    mean_overall = X.mean(axis=0)
    n_features = X.shape[1]

    Sb = np.zeros((n_features, n_features))
    for c in self.classes:
      X_c = X[Y == c]
      diff = (X_c.mean(axis=0) - mean_overall).reshape(-1, 1)  # column vector
      Sb += X_c.shape[0] * np.dot(diff, diff.T)
    return Sb

  def fit(self, X, Y, variance_threshold=None):
    """Compute the linear discriminants and projected class means.

    Args:
      X: 2-D array of samples (rows) by features (columns).
      Y: 1-D array of class labels, one per row of X.
      variance_threshold: used only when no ``n_dimensions`` was given; keep
        the smallest number of discriminants whose cumulative share of the
        eigenvalue sum is >= this value.

    Returns:
      self, so calls can be chained.

    Raises:
      numpy.linalg.LinAlgError: if the within-class scatter matrix is
        singular (e.g. fewer samples than features).
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y)

    Sw = self.scatter_matrix_within_class(X, Y)   # within-class scatter
    Sb = self.scatter_matrix_between_class(X, Y)  # between-class scatter; sets self.classes

    # solve() yields Sw^{-1} Sb without forming the explicit inverse.
    mat = np.linalg.solve(Sw, Sb)
    eigenvalues, eigenvectors = np.linalg.eig(mat)
    # Sw^{-1} Sb is similar to a symmetric PSD matrix, so its eigenvalues are
    # real and non-negative; any imaginary part is floating-point noise.
    eigenvalues = np.real(eigenvalues)
    eigenvectors = np.real(eigenvectors.T)  # one eigenvector per row

    order = np.argsort(eigenvalues)[::-1]  # descending eigenvalue
    eigenvalues = eigenvalues[order]
    eigenvectors = eigenvectors[order]

    # Cumulative share of the eigenvalue sum captured by the first k directions.
    self.explained_variances_ = np.cumsum(eigenvalues) / np.sum(eigenvalues)

    self.n_dimensions_ = self._choose_dimensions(variance_threshold)
    self.linear_discriminants_ = eigenvectors[:self.n_dimensions_]

    # Projected class means are fixed by the training data, so compute them
    # once here instead of on every predict() call.
    self.class_means_ = np.array(
        [self.transform(X[Y == c]).mean(axis=0) for c in self.classes])
    return self

  def _choose_dimensions(self, variance_threshold):
    """Return how many discriminants to keep after the eigen-decomposition."""
    if self._requested_dimensions is not None:
      return self._requested_dimensions

    n_directions = len(self.explained_variances_)
    if variance_threshold is None:
      # At most (n_classes - 1) eigenvalues of Sw^{-1} Sb are non-zero.
      return max(1, min(n_directions, len(self.classes) - 1))

    # First k whose cumulative share reaches the threshold. If none does
    # (e.g. threshold 1.0 with round-off), keep every direction rather than one.
    reached = np.flatnonzero(self.explained_variances_ >= variance_threshold)
    return int(reached[0]) + 1 if reached.size else n_directions

  def transform(self, X):
    """Project the rows of X onto the fitted linear discriminants."""
    return np.dot(X, np.transpose(self.linear_discriminants_))

  def predict(self, X):
    """Return, for each row of X, the label of the nearest projected class mean.

    Raises:
      ValueError: if called before ``fit``.
    """
    if self.class_means_ is None:
      raise ValueError("LDA instance is not fitted yet; call fit() before predict().")

    X_lda = np.atleast_2d(self.transform(X))
    # (n_samples, n_classes) Euclidean distances in the discriminant space.
    distances = np.linalg.norm(
        X_lda[:, np.newaxis, :] - self.class_means_[np.newaxis, :, :], axis=2)
    return self.classes[np.argmin(distances, axis=1)]