In [1]:
import numpy as np
import torch
from sklearn import svm
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier, NeighborhoodComponentsAnalysis
from sklearn.pipeline import Pipeline
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text

from src.base.training.models.architectures.lenet import LeNet
from src.base.training.models.architectures.lenet_light import LeNetLight

In [2]:
DATA_PATH = "C:\\Users\\micdu\\Code\\pythonProject\\dmtl\\data"

def load_samples(dataset_fn, n_samples, train=True):
    dataset = dataset_fn(
        DATA_PATH,
        train=train,
        download=True,
        transform=ToTensor()
    )
    loader = DataLoader(dataset, batch_size=n_samples)
    return next(iter(loader))

def shuffle(x, y):
    shuffle_index = torch.randperm(x.shape[0])
    return x[shuffle_index], y[shuffle_index]

def load_model(model_fn, path):
    model = model_fn()
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

def load_and_prepare(n_samples=100, train=True, model_fn=LeNet):
    fmnist_x, fmnist_y = load_samples(datasets.FashionMNIST, int(n_samples/2), train=train)
    mnist_x, mnist_y = load_samples(datasets.MNIST, int(n_samples/2), train=train)
    # Off setting the fmnist labels
    fmnist_y = fmnist_y + 10
    x, y = shuffle(
        torch.cat((mnist_x, fmnist_x), dim=0),
        torch.cat((mnist_y, fmnist_y), dim=0)
    )
    mnist_cluster = load_model(model_fn, "C:\\Users\\micdu\\Code\\pythonProject\\dmtl\\notebooks\\models\\daeclust_15\\5aa285fe2dad84e59107a2652432eeac66db9c709fe2719ba74bd80caa7f493a\\final_model.state")
    fmnist_cluster = load_model(model_fn, "C:\\Users\\micdu\\Code\\pythonProject\\dmtl\\notebooks\\models\\daeclust_15\\e5307874a84923007d15c8c019aa67d7756478bd3466d17a14b856a76e6ee29d\\final_model.state")

    out_mnist = mnist_cluster(x)
    out_fmnist = fmnist_cluster(x)
    x_out = torch.cat((out_fmnist, out_mnist), dim=1)
    return x_out.detach().numpy(), y.detach().numpy()

In [3]:
x_train, y_train = load_and_prepare(model_fn=LeNetLight, n_samples=400, train=True)
x_test, y_test = load_and_prepare(model_fn=LeNetLight, n_samples=400, train=False)

In [4]:
np.bincount(y_train)

array([21, 26, 20, 21, 21, 13, 19, 21, 15, 23, 24, 26, 18, 17, 18, 20, 21,
       21, 16, 19], dtype=int64)

In [5]:
pca = PCA(n_components=10)
x_pca_train = pca.fit_transform(x_train)
x_pca_test = pca.transform(x_test)

In [6]:
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=3)
decision_tree = decision_tree.fit(x_train, y_train)

In [14]:
tree_pred = decision_tree.predict(x_test)
classification_report(tree_pred, y_test, output_dict=True)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'0': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 17},
 '1': {'precision': 1.0,
  'recall': 0.9655172413793104,
  'f1-score': 0.9824561403508771,
  'support': 29},
 '2': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0},
 '3': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0},
 '4': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0},
 '5': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0},
 '6': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0},
 '7': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0},
 '8': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0},
 '9': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0},
 '10': {'precision': 1.0,
  'recall': 0.06097560975609756,
  'f1-score': 0.11494252873563218,
  'support': 328},
 '11': {'precision': 0.9629629629629629,
  'recall': 1.0,
  'f1-score': 0.9811320754716981,
  'support': 26},
 '12': {'precision'

In [8]:
# LinearSVC, ovo, ovr
svm_clf = svm.SVC()
svm_clf.fit(x_train, y_train)

In [9]:
svm_pred = svm_clf.predict(x_test)
classification_report(svm_pred, y_test)

'              precision    recall  f1-score   support\n\n           0       1.00      1.00      1.00        17\n           1       1.00      0.97      0.98        29\n           2       0.94      1.00      0.97        15\n           3       0.94      1.00      0.97        15\n           4       0.96      1.00      0.98        27\n           5       0.95      0.95      0.95        20\n           6       1.00      1.00      1.00        20\n           7       0.96      1.00      0.98        23\n           8       1.00      0.67      0.80        15\n           9       0.95      1.00      0.98        20\n          10       0.95      0.90      0.93        21\n          11       0.96      1.00      0.98        26\n          12       0.74      0.67      0.70        30\n          13       0.76      0.93      0.84        14\n          14       0.57      0.80      0.67        15\n          15       0.94      0.79      0.86        19\n          16       0.62      0.56      0.59        18\n       

In [10]:
# https://scikit-learn.org/stable/modules/naive_bayes.html
gnb = GaussianNB()
gnb = gnb.fit(x_train, y_train)
gnb_pred = gnb.predict(x_test)
classification_report(gnb_pred, y_test)

'              precision    recall  f1-score   support\n\n           0       1.00      1.00      1.00        17\n           1       1.00      1.00      1.00        28\n           2       1.00      0.84      0.91        19\n           3       0.81      1.00      0.90        13\n           4       0.96      0.96      0.96        28\n           5       0.95      0.86      0.90        22\n           6       1.00      1.00      1.00        20\n           7       0.96      1.00      0.98        23\n           8       0.90      0.69      0.78        13\n           9       0.95      1.00      0.98        20\n          10       0.95      0.83      0.88        23\n          11       0.96      1.00      0.98        26\n          12       0.74      0.71      0.73        28\n          13       0.76      0.93      0.84        14\n          14       0.52      0.73      0.61        15\n          15       0.94      0.83      0.88        18\n          16       0.69      0.52      0.59        21\n       

In [11]:
# https://scikit-learn.org/stable/modules/neighbors.html
nbrs = NearestNeighbors(n_neighbors=3, algorithm='ball_tree').fit(x_train)
distances, indices = nbrs.kneighbors(x_test)

In [12]:
nca = NeighborhoodComponentsAnalysis(random_state=42)
knn = KNeighborsClassifier(n_neighbors=3)
nca_pipe = Pipeline([('nca', nca), ('knn', knn)])
nca_pipe.fit(x_train, y_train)
nca_knn_preds = nca_pipe.predict(x_test)
classification_report(nca_knn_preds, y_test)

'              precision    recall  f1-score   support\n\n           0       1.00      1.00      1.00        17\n           1       1.00      1.00      1.00        28\n           2       0.94      0.94      0.94        16\n           3       0.94      0.94      0.94        16\n           4       0.93      1.00      0.96        26\n           5       0.95      0.90      0.93        21\n           6       1.00      1.00      1.00        20\n           7       0.96      0.96      0.96        24\n           8       1.00      0.62      0.77        16\n           9       0.81      0.94      0.87        18\n          10       0.75      0.83      0.79        18\n          11       0.96      1.00      0.98        26\n          12       0.85      0.72      0.78        32\n          13       0.82      0.88      0.85        16\n          14       0.52      0.73      0.61        15\n          15       0.94      1.00      0.97        15\n          16       0.69      0.58      0.63        19\n       

In [13]:
rnd_forest = RandomForestClassifier(max_depth=3, random_state=0)
rnd_forest.fit(x_train, y_train)
forest_pred = rnd_forest.predict(x_test)
classification_report(forest_pred, y_test)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


'              precision    recall  f1-score   support\n\n           0       1.00      1.00      1.00        17\n           1       1.00      0.97      0.98        29\n           2       1.00      0.94      0.97        17\n           3       0.94      1.00      0.97        15\n           4       1.00      0.93      0.97        30\n           5       0.00      0.00      0.00         0\n           6       0.95      1.00      0.97        19\n           7       0.96      1.00      0.98        23\n           8       0.30      1.00      0.46         3\n           9       1.00      1.00      1.00        21\n          10       1.00      0.22      0.37        89\n          11       0.96      1.00      0.98        26\n          12       0.78      0.75      0.76        28\n          13       0.00      0.00      0.00         0\n          14       0.00      0.00      0.00         0\n          15       0.94      1.00      0.97        15\n          16       0.75      0.67      0.71        18\n       