In [None]:
import numpy as np
import torch
from torchvision import datasets
import sklearn.metrics as sm
import matplotlib.pyplot as plt
import tensorflow as tf
import math

In [None]:
class GDAClf:
    """Gaussian discriminant analysis classifier with ridge-regularised covariances.

    One Gaussian (mean vector and full covariance matrix) is fitted per class.
    ``lamda * I`` is added to every class covariance before inversion.
    A sample is assigned to the class whose Gaussian gives it the highest
    log-likelihood; class priors are not used.

    Attributes:
        x_train, y_train, x_test, y_test: the data splits as numpy arrays.
        mu: class means, shape ``(n_classes, n_features, 1)``.
        sigma: regularised class covariances,
            shape ``(n_classes, n_features, n_features)``.
        sigmainv: inverses of ``sigma``, same shape.
        logdets: log-determinants of ``sigma``, shape ``(n_classes,)``.
        lamda: ridge coefficient currently baked into ``sigma``.
    """

    def __init__(self, data=None, lamda=1):
        """Load the data and fit the per-class Gaussians.

        Args:
            data: optional ``((x_train, y_train), (x_test, y_test))``.
                When omitted, MNIST is loaded through ``tf.keras``.
                Labels are expected to be the integers ``0 .. n_classes - 1``.
            lamda: ridge coefficient added to the diagonal of each covariance.
        """
        if data is None:
            data = tf.keras.datasets.mnist.load_data()
        (x_train, y_train), (x_test, y_test) = data
        self.x_train, self.y_train = np.asarray(x_train), np.asarray(y_train)
        self.x_test, self.y_test = np.asarray(x_test), np.asarray(y_test)

        # Sizes come from the data (10 classes x 784 features for MNIST).
        n_classes = int(self.y_train.max()) + 1
        n_features = int(np.prod(self.x_train.shape[1:]))

        self.mu = np.zeros((n_classes, n_features, 1))
        self.sigma = np.zeros((n_classes, n_features, n_features))
        self.sigmainv = np.zeros_like(self.sigma)
        self.logdets = np.zeros(n_classes)
        self.lamda = lamda
        self.compile()

    def init_lamda(self , lamda=1):
        """Set the ridge coefficient and refit the class Gaussians."""
        self.lamda = lamda
        self.compile()

    def compile(self):
        """Fit mean, regularised covariance, its inverse and log-determinant per class.

        Raises:
            ValueError: if a class has fewer than two training samples, so
                that no sample covariance can be estimated for it.
            numpy.linalg.LinAlgError: if a regularised covariance is singular
                or not positive definite (e.g. ``lamda <= 0``).
        """
        n_classes, n_features = self.mu.shape[0], self.mu.shape[1]
        x_flat = self.x_train.reshape(self.x_train.shape[0], -1)
        ridge = self.lamda * np.eye(n_features)

        for k in range(n_classes):
            x_k = x_flat[self.y_train == k]
            if x_k.shape[0] < 2:
                raise ValueError(
                    f'class {k} has {x_k.shape[0]} training sample(s); '
                    'at least 2 are needed to estimate a covariance'
                )

            self.mu[k] = x_k.mean(axis=0).reshape(n_features, 1)
            # The reshape keeps the single-feature case (0-d np.cov result) working.
            class_cov = np.cov(x_k, rowvar=False).reshape(n_features, n_features)
            self.sigma[k] = class_cov + ridge
            self.sigmainv[k] = np.linalg.inv(self.sigma[k])

            # slogdet returns (sign, log|det|).  log|det| is the log-determinant
            # only when sign == +1; multiplying by the sign would negate it or
            # produce NaN for the other cases.
            sign, logabsdet = np.linalg.slogdet(self.sigma[k])
            if sign <= 0:
                raise np.linalg.LinAlgError(
                    f'covariance of class {k} is not positive definite '
                    f'(lamda={self.lamda})'
                )
            self.logdets[k] = logabsdet

    def display_means(self):
        """Draw each class mean as an image, one subplot row per class.

        Raises:
            ValueError: if the training samples are not 2-D images.
        """
        image_shape = self.x_train.shape[1:]
        if len(image_shape) != 2:
            raise ValueError('display_means needs 2-D training images')
        n_classes = self.mu.shape[0]
        for k in range(n_classes):
            plt.subplot(n_classes, 1, k + 1)
            plt.imshow(self.mu[k].reshape(image_shape))

    def plot_lamda_curve(self, lamdas=None):
        """Refit and score the model for several ridge values and plot the curve.

        The instance is refitted with the ``lamda`` it had on entry before
        returning, so later calls to ``predict`` are unaffected by the sweep.

        Args:
            lamdas: ridge values to try. Defaults to five evenly spaced values
                from 100000 down to 0.001.

        Returns:
            ``(lamdas, acc)``: the values tried as an array and the list of
            test accuracies, in the same order.
        """
        if lamdas is None:
            lamdas = np.linspace(100000 , 0.001 , 5)
        lamdas = np.asarray(lamdas, dtype=float)

        lamda_on_entry = self.lamda
        acc = []
        try:
            for lamda in lamdas:
                self.init_lamda(lamda)
                acc.append(self.predict())
        finally:
            # Undo the sweep's mutation of the fitted state.
            self.init_lamda(lamda_on_entry)

        plt.plot(lamdas , acc)
        plt.xlabel('Lambda')
        plt.ylabel('Accuracy')
        return lamdas , acc

    def _log_likelihoods(self, x):
        """Return the ``(n_samples, n_classes)`` Gaussian log-likelihoods of ``x``."""
        # Cast to float first: unsigned pixel data must not be subtracted as integers.
        x_flat = np.asarray(x, dtype=np.float64).reshape(len(x), -1)
        n_classes, n_features = self.mu.shape[0], self.mu.shape[1]
        norm_const = 0.5 * n_features * np.log(2 * np.pi)

        scores = np.empty((x_flat.shape[0], n_classes))
        for k in range(n_classes):
            diff = x_flat - self.mu[k].T
            # Row-wise quadratic form diff_i^T @ sigmainv @ diff_i for all samples at once.
            maha = np.einsum('ij,ij->i', diff @ self.sigmainv[k], diff)
            scores[:, k] = -0.5 * self.logdets[k] - norm_const - 0.5 * maha
        return scores

    def predict_labels(self, x):
        """Return the most likely class index for every sample in ``x``."""
        return np.argmax(self._log_likelihoods(x), axis=1)

    def predict(self):
        """Classify the test split and return its accuracy score."""
        predictions = self.predict_labels(self.x_test)
        return sm.accuracy_score(self.y_test , predictions)

def compute_LL(x_vec , mu , sigmainv , sigma, logdet):
    """Log-likelihood of one column vector under a multivariate Gaussian.

    Args:
        x_vec: sample as a column vector of shape ``(d, 1)``.
        mu: mean as a column vector of shape ``(d, 1)``.
        sigmainv: inverse covariance matrix of shape ``(d, d)``.
        sigma: covariance matrix; not used, kept so existing callers still work.
        logdet: log-determinant of the covariance matrix.

    Returns:
        A ``(1, 1)`` array holding
        ``-0.5*logdet - (d/2)*log(2*pi) - 0.5*(x-mu)^T sigmainv (x-mu)``.
    """
    diff = x_vec - mu
    # Take the dimension from the data instead of assuming 784 features.
    n_dims = diff.shape[0]
    t1 = -( 0.5 * logdet + (n_dims/2) * np.log(2*math.pi) )
    t2 = -0.5 * diff.T @ sigmainv @ diff
    return t1 + t2


In [None]:
# Build the classifier: the constructor loads MNIST and fits the per-class Gaussians with lamda = 1.
a = GDAClf()

In [None]:
# Show the fitted per-class mean vectors as images, one subplot per class.
a.display_means()

In [None]:

# predict() returns the test-set accuracy of the current model, so 1 - Acc is the error rate.
# NOTE: the message hard-codes "lamda = 1"; it is only accurate while a.lamda is still 1.
Acc = a.predict()
print(f'The error rate with lamda = 1 is {1-Acc}')

In [None]:

# Refit and score the model for each lambda in the sweep and plot accuracy against lambda.
lamdas , acc = a.plot_lamda_curve()

In [None]:
 
# Smallest error rate (1 - accuracy) over the swept lambda values.
print(f'The best error rate is {min(1-np.array(acc))}')

In [None]:

# Lambda with the highest accuracy; np.argmax picks the first one if several tie.
print(f'The best lambda value is {lamdas[np.argmax(acc)]}')