In [None]:
import os
import random
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
warnings.filterwarnings('ignore')

In [None]:
def train(data, label, lr, n_class, shift):
    """Learn one KC->output weight row per class for a single task.

    Each sample is encoded with ``get_KC`` (using the notebook-level ``R``
    and ``thresh``); its KC vector, scaled by ``lr``, is added to the weight
    row of the sample's true class and that row is then capped at 1.
    Samples are presented class by class, in ascending label order.

    The update does not depend on the model's current prediction: the
    true-class row is reinforced for every sample and no other row is ever
    decreased.

    data: DataFrame with one sample per row
    label: array-like of integer class labels, one per row of ``data``
    lr: learning rate multiplying every KC vector before it is added
    n_class: number of distinct classes expected in ``label``
    shift: smallest label of this task; label ``shift + k`` trains row ``k``

    Returns an array of shape (n_class, n_kc), ``n_kc`` being the
    notebook-level KC count.

    Raises ValueError if ``label`` does not hold exactly ``n_class`` distinct
    values, if those values are not ``shift .. shift + n_class - 1``, or if
    the KC encoding of any sample contains NaN.
    """
    label = np.asarray(label)
    distinct_label = np.unique(label)
    if len(distinct_label) != n_class:
        raise ValueError('Number of classes to be trained does not match!!!')
    # A wrong shift would otherwise yield negative row indices that wrap
    # around silently and train the wrong weight row.
    if not np.array_equal(distinct_label - shift, np.arange(n_class)):
        raise ValueError(
            'Labels %s are not shift..shift+n_class-1 for shift=%s, n_class=%s'
            % (distinct_label.tolist(), shift, n_class)
        )

    kcs = np.array([get_KC(sample, R, thresh) for sample in data.to_numpy()])
    if np.isnan(kcs).any():
        raise ValueError('KC encoding produced NaN for at least one sample')

    weights = np.zeros([n_class, n_kc])
    # After the check above, distinct_label[row] - shift == row.
    for row, cls in enumerate(distinct_label):
        target = weights[row]  # view: in-place updates land in `weights`
        for kc in kcs[label == cls]:
            target += lr * kc
            # Only this row changed, so only this row can exceed the cap.
            np.minimum(target, 1, out=target)
    return weights

In [None]:
# Training features: one headerless CSV per source dataset.
mnist_train = pd.read_csv('./processed_data/mnist_kmnist_train.csv', header=None)
fmnist_train = pd.read_csv('./processed_data/fmnist_kmnist_train.csv', header=None)

# Training labels: keep only the first column of each label file, as a flat array.
mnist_train_label = pd.read_csv(
    './processed_data/mnist_kmnist_train_label.csv', header=None
).to_numpy()[:, 0]
fmnist_train_label = pd.read_csv(
    './processed_data/fmnist_kmnist_train_label.csv', header=None
).to_numpy()[:, 0]

In [None]:
# Test features: one headerless CSV per source dataset.
mnist_test = pd.read_csv('./processed_data/mnist_kmnist_test.csv', header=None)
fmnist_test = pd.read_csv('./processed_data/fmnist_kmnist_test.csv', header=None)

# Test labels: keep only the first column of each label file, as a flat array.
mnist_test_label = pd.read_csv(
    './processed_data/mnist_kmnist_test_label.csv', header=None
).to_numpy()[:, 0]
fmnist_test_label = pd.read_csv(
    './processed_data/fmnist_kmnist_test_label.csv', header=None
).to_numpy()[:, 0]

In [None]:
# Merge the two sources into one label vector per split. The second source's
# labels are shifted by 10 -- presumably so they do not collide with the
# first source's labels; confirm against the label files.
train_label = np.concatenate((mnist_train_label, fmnist_train_label + 10))
test_label = np.concatenate((mnist_test_label, fmnist_test_label + 10))

# Stack the feature rows in the same order as the labels, with a fresh 0..n-1 index.
train_data = pd.concat((mnist_train, fmnist_train), axis=0, ignore_index=True)
test_data = pd.concat((mnist_test, fmnist_test), axis=0, ignore_index=True)

In [None]:
from sklearn.preprocessing import minmax_scale

# axis=1 rescales every sample (row) on its own, using that row's min and max.
train_scaled = minmax_scale(train_data, axis=1)
test_scaled = minmax_scale(test_data, axis=1)

# Wrap the scaled arrays back into DataFrames under the original names.
train_data = pd.DataFrame(train_scaled)
test_data = pd.DataFrame(test_scaled)

In [None]:
def generate_transformation_matrix(n_kc,n_orn,n_response):
    '''
    Build a reproducible binary projection matrix.

    n_kc: number of rows of the matrix
    n_orn: number of columns of the matrix
    n_response: number of columns set to 1 in every row

    Row i is drawn with a private generator seeded with i, so the matrix is
    identical on every call and the module-level `random` state is left
    untouched.

    Returns an (n_kc, n_orn) float array of zeros and ones.
    Raises ValueError (from `sample`) if n_response > n_orn.
    '''
    R = np.zeros((n_kc, n_orn))
    for i in range(n_kc):
        # random.Random(i) yields the same draw as random.seed(i) followed by
        # random.sample, without reseeding the shared global generator.
        chosen = random.Random(i).sample(range(n_orn), n_response)
        R[i, chosen] = 1
    return R

def get_KC(p,R,thresh):
    '''
    Encode one input vector as a sparse, peak-normalised KC activity vector.

    p: 1-D input vector; its length must equal the number of columns of R
    R: projection matrix, one row per KC
    thresh: KC activations less than or equal to this value are set to 0

    Steps: KC = R @ p; zero every entry <= thresh; zero every entry below the
    0.95 quantile of the thresholded vector; divide by the largest entry.

    Returns a 1-D array with one entry per row of R. If the largest entry is
    0 (no KC is active, e.g. an all-zero input) the vector is returned
    without dividing, so the result holds zeros instead of NaNs.
    '''
    KC = np.matmul(R,p)
    KC[KC<=thresh] = 0
    threshold = np.quantile(KC,0.95)
    KC[KC<threshold] = 0
    peak = np.max(KC)
    # Dividing by 1 when nothing is active avoids 0/0 while keeping the same
    # output dtype as the normal path.
    KC = KC/(peak if peak != 0 else 1)
    return KC

In [None]:
def split_data(data, labels, split):
    """Partition samples into consecutive groups of ``split`` label values.

    Group ``i`` holds every row whose label lies in
    ``[i * split, (i + 1) * split)``. The number of groups is
    ``int(n_distinct_labels / split)``, so labels beyond the last full group
    are not returned.

    Rows are selected directly instead of transposing the whole frame twice,
    which avoids two full copies and keeps each column's dtype.

    data: DataFrame with one sample per row
    labels: array-like of numeric labels, one per row of ``data``
    split: number of consecutive label values per group

    Returns ``(datasets, datalabels)``: two dicts keyed by group number.
    ``datasets[i]`` is a DataFrame of the selected rows, indexed by their
    labels; ``datalabels[i]`` is the matching label array.

    Raises ValueError if ``labels`` and ``data`` differ in length.
    """
    labels = np.asarray(labels)
    if len(labels) != len(data):
        raise ValueError(
            'Length mismatch: %d labels for %d rows' % (len(labels), len(data))
        )
    n_split = int(len(np.unique(labels)) / split)
    datasets = {}
    datalabels = {}
    for i in range(n_split):
        in_group = (labels >= i * split) & (labels < (i + 1) * split)
        group_labels = labels[in_group]
        group = data.iloc[in_group].copy()
        group.index = group_labels
        datasets[i] = group
        datalabels[i] = group_labels
    return (datasets, datalabels)

In [None]:
def accu(weights, data, label):
    """Return the fraction of samples classified correctly.

    Every row of ``data`` is encoded with ``get_KC`` (using the notebook-level
    ``R`` and ``thresh``) and scored against every row of ``weights``. The
    index of the best-scoring weight row is the prediction, which is compared
    with ``label`` element by element.

    weights: array with one row per output class
    data: DataFrame with one sample per row
    label: array of expected row indices, one per sample
    """
    encoded = []
    for row_idx in range(data.shape[0]):
        sample = np.array(data.iloc[row_idx, :])
        encoded.append(get_KC(sample, R, thresh))
    encoded = np.array(encoded)

    # scores[c, s]: response of weight row c to sample s
    scores = np.matmul(weights, encoded.T)
    winners = np.argmax(scores, axis=0)
    n_correct = np.sum(winners == label)
    return n_correct / len(label)

In [None]:
# Dimensions of the random projection.
n_orn = 84        # columns of R
n_kc = 3200       # rows of R
n_response = 10   # entries set to 1 in each row of R

# Cut-off passed to get_KC: activations <= thresh are zeroed.
thresh = 0

R = generate_transformation_matrix(n_kc=n_kc, n_orn=n_orn, n_response=n_response)

In [None]:
# Group both sets into tasks of two consecutive label values each
# (labels 0-1, 2-3, ...), keyed by task number.
(train_datasets, train_labels), (test_datasets, test_labels) = (
    split_data(train_data, train_label, 2),
    split_data(test_data, test_label, 2),
)

In [None]:
# Number of tasks to train in sequence; must not exceed the number of groups
# produced by split_data.
n_task = 10
lrs = [0.1, 0.01, 0.001, 0.0001]

# Create the output folder up front so a missing directory cannot discard a
# finished training run when the results are written.
os.makedirs('./accuracy', exist_ok=True)

for lr in lrs:
    # accuracy[i, j]: accuracy on task j's test set after training tasks 0..i.
    # Entries with j > i are never evaluated and stay 0.
    accuracy = np.zeros([n_task, n_task])
    trained_weights = {}
    for i in range(n_task):
        # Task i covers labels 2*i and 2*i + 1, hence n_class=2 and shift=i*2.
        trained_weights[i] = train(train_datasets[i], train_labels[i], lr, 2, i*2)

        # Stack the rows of every task seen so far; row k then corresponds to
        # label k, which is what accu compares the argmax against.
        weights_to_test = np.concatenate([trained_weights[j] for j in range(i+1)])

        for j in range(i+1):
            accuracy[i,j] = accu(weights_to_test,test_datasets[j],test_labels[j])

    accuracy = pd.DataFrame(accuracy)
    accuracy.to_csv('./accuracy/Sparse-coding_v4_lr'+str(lr)+'.csv',index=False,header=False)

    print('learning rate '+str(lr)+' done!')