In [10]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf

In [2]:
conf_file = OmegaConf.load('./config.yaml')

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# For Small Dataset

In [4]:
from dataset import cls_small_data as Cdata
import model.cls_model as Cmodel

Supported small dataset for classification:  
'zebra',
'zebra_special',
'bal',
'digits',
'iris',
'wine',
'breast_cancer'

In [7]:
dataset_name = 'iris'

criterion = torch.nn.CrossEntropyLoss()

In [8]:
cfg = conf_file['dataset'][dataset_name]
Xs, ys = Cdata.Cls_small_data(dataset_name)

## Classification with NNKNN

In [None]:
# This section is used to reload the imported module. For example, if you made any changes in the model.cls_model, you should run importlib.reload(Cmodel) as long as you set import model.cls_model as Cmodel.
# import importlib
# importlib.reload(Cmodel)

In [13]:
def train_cls(X_train,y_train, X_test, y_test, cfg:DictConfig):
  X_train = X_train.to(device)
  y_train = y_train.to(device)
  X_test = X_test.to(device)

  train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X_train, y_train), batch_size=cfg.batch_size, shuffle=True)

  # Train model
  model = Cmodel.NN_k_NN(X_train,
                         y_train,
                         cfg.ca_weight_sharing,
                         cfg.top_case_enabled,
                         cfg.top_k,
                         cfg.discount,
                         device=device)
  
  optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate) #, weight_decay=1e-5)

  patience_counter = 0
  for epoch in range(cfg.training_epochs):
    epoch_msg = True
    
    for X_train_batch, y_train_batch in train_loader:
      model.train()
      _, _, output, predicted_class = model(X_train_batch)
      loss = criterion(output, y_train_batch)

      # Backward and optimize
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      if epoch_msg and (epoch + 1) % 2 == 0:
        print(f'Epoch [{epoch + 1}/{cfg.training_epochs}], Loss: {loss.item():.4f}')

        epoch_msg = False
      # print("evaluating")
    model.eval()
    with torch.no_grad():
      _, _, output, predicted_class = model(X_test)

      # Calculate accuracy
      accuracy_temp = accuracy_score(y_test.numpy(), predicted_class.cpu().numpy())
    if epoch == 0:
      best_accuracy = accuracy_temp
      torch.save(model.state_dict(), cfg.PATH)
      
    elif accuracy_temp > best_accuracy:
      #memorize best model
      torch.save(model.state_dict(), cfg.PATH)
      best_accuracy = accuracy_temp
      patience_counter = 0

    elif patience_counter > cfg.patience:
      model.eval()
      print("patience exceeded, loading best model")
      break
    else:
      patience_counter += 1

  X_train = X_train.cpu()
  y_train = y_train.cpu()
  X_test = X_test.cpu()
  ##compare with a normal k-nn
  knn =  KNeighborsClassifier(n_neighbors=cfg.top_k)
  knn.fit(X_train, y_train)
  knn_acc  = accuracy_score(knn.predict(X_test), y_test)
  return best_accuracy, knn_acc

In [14]:
accuracies = []
knn_accuracies = []
PATH = f'checkpoints/classifier_{dataset_name}.h5'
cfg.PATH = PATH
k_fold = KFold(n_splits=10, shuffle=True, random_state = None)

for train_index, test_index in k_fold.split(Xs):
  # Get training and testing data
  X_train, X_test = Xs[train_index], Xs[test_index]
  y_train, y_test = ys[train_index], ys[test_index]
  best_accuracy, knn_acc = train_cls(X_train,y_train, X_test, y_test, cfg)
  accuracies.append(best_accuracy)
  knn_accuracies.append(knn_acc)

print(f"Average accuracy:{np.mean(accuracies):.3f}")
print(f"KNN accuracy:{np.mean(knn_accuracies):.3f}")

Epoch [2/1000], Loss: 0.3380
Epoch [4/1000], Loss: 0.2282
Epoch [6/1000], Loss: 0.1642
Epoch [8/1000], Loss: 0.2054
Epoch [10/1000], Loss: 0.1512
Epoch [12/1000], Loss: 0.1841
Epoch [14/1000], Loss: 0.1446
Epoch [16/1000], Loss: 0.1176
Epoch [18/1000], Loss: 0.1426
Epoch [20/1000], Loss: 0.0232
Epoch [22/1000], Loss: 0.1193
Epoch [24/1000], Loss: 0.1217
Epoch [26/1000], Loss: 0.1078
Epoch [28/1000], Loss: 0.1094
Epoch [30/1000], Loss: 0.1156
Epoch [32/1000], Loss: 0.0281
Epoch [34/1000], Loss: 0.1048
Epoch [36/1000], Loss: 0.1053
Epoch [38/1000], Loss: 0.0950
Epoch [40/1000], Loss: 0.0510
Epoch [42/1000], Loss: 0.0247
Epoch [44/1000], Loss: 0.0649
patience exceeded, loading best model
Epoch [2/1000], Loss: 0.3440
Epoch [4/1000], Loss: 0.2220
Epoch [6/1000], Loss: 0.2854
Epoch [8/1000], Loss: 0.1732
Epoch [10/1000], Loss: 0.2715
Epoch [12/1000], Loss: 0.1484
Epoch [14/1000], Loss: 0.0790
Epoch [16/1000], Loss: 0.1540
Epoch [18/1000], Loss: 0.1062
Epoch [20/1000], Loss: 0.1557
Epoch [22/