In [1]:
# Import libraries
from torch import nn,cuda,backends, optim, save, load, no_grad
import torch
from torchvision.datasets import ImageFolder,CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import time
from sklearn.neighbors import KNeighborsClassifier, NearestCentroid
import numpy as np
import sys
import pickle

In [2]:
# Image transformations
mean=[0.4914, 0.4822, 0.4465]
std=[0.2023, 0.1994, 0.2010]

#mean=[0.0178, 0.0574, 0.2181]
#std=[1.0364, 1.0332, 1.0443]

transformations=transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomInvert(p=0.2),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize(mean=mean,std=std),
    ])

In [3]:
batch_size=50
# select classes you want to include in your subset
classes = torch.tensor([3, 8])
classes_names= ["cat","ship"]

data_train = CIFAR10(root='./data', train=True, download=True, transform=transformations)
indices = (torch.tensor(data_train.targets)[..., None] == classes).any(-1).nonzero(as_tuple=True)[0]
data_train = torch.utils.data.Subset(data_train, indices)
#classes_names=data_train.classes
dataloader_train = torch.utils.data.DataLoader(data_train,batch_size=batch_size, shuffle=True, num_workers=2)

data_test = CIFAR10(root='./data', train=False, download=True, transform=transformations)
indices = (torch.tensor(data_test.targets)[..., None] == classes).any(-1).nonzero(as_tuple=True)[0]
data_test = torch.utils.data.Subset(data_test, indices)
dataloader_test = torch.utils.data.DataLoader(data_test, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Get device for pytorch to run on
device = (
    "cuda"
    if cuda.is_available()
    else "mps"
    if backends.mps.is_available()
    else "cpu"
)

print(f"Using {device} device")

Using cuda device


In [5]:
knn_1=KNeighborsClassifier(n_neighbors=1)
knn_3=KNeighborsClassifier(n_neighbors=3)
nc=NearestCentroid()

In [6]:
x_train=[]
y_train=[]
for x_batch,y_batch in dataloader_train:
    for i in range(x_batch.shape[0]):
        x_train.append(x_batch[i].detach().numpy().flatten())
        y_train.append(y_batch[i].detach().numpy().flatten())
    
print((len(x_train),x_train[0].shape))

(10000, (3072,))


In [7]:
len(x_train)

10000

In [8]:
x_test=[]
y_test=[]
for x_batch,y_batch in dataloader_test:
    for i in range(x_batch.shape[0]):
        x_test.append(x_batch[i].detach().numpy().flatten())
        y_test.append(y_batch[i].detach().numpy().flatten())
    
print((len(x_test),x_test[0].shape))

(2000, (3072,))


In [9]:
start_training_knn_1=time.time()
knn_1.fit(x_train,np.ravel(y_train))
stop_training_knn_1=time.time()

start_testing_knn_1=time.time()
acc_knn_1=knn_1.score(x_test, y_test)
stop_testing_knn_1=time.time()

p = pickle.dumps(knn_1)
memory_knn_1=sys.getsizeof(p)
del p

In [10]:
start_training_knn_3=time.time()
knn_3.fit(x_train,np.ravel(y_train))
stop_training_knn_3=time.time()

start_testing_knn_3=time.time()
acc_knn_3=knn_3.score(x_test, y_test)
stop_testing_knn_3=time.time()

p = pickle.dumps(knn_3)
memory_knn_3=sys.getsizeof(p)
del p

In [11]:
start_training_nc=time.time()
nc.fit(x_train,np.ravel(y_train))
stop_training_nc=time.time()

start_testing_nc=time.time()
acc_nc=nc.score(x_test, y_test)
stop_testing_nc=time.time()

p = pickle.dumps(nc)
memory_nc=sys.getsizeof(p)
del p

In [12]:
# Print stats
print(f"Classifier: Training time | Testing time | Accuracy | Memory used while idle (Bytes)")
print(f"KNN(K=1): {stop_training_knn_1-start_training_knn_1} | {stop_testing_knn_1-start_testing_knn_1} | {acc_knn_1} | {memory_knn_1}")
print(f"KNN(K=3): {stop_training_knn_3-start_training_knn_3} | {stop_testing_knn_3-start_testing_knn_3} | {acc_knn_3} | {memory_knn_3}")
print(f"NC: {stop_training_nc-start_training_nc} | {stop_testing_nc-start_testing_nc} | {acc_nc} | {memory_nc}")

Classifier: Training time | Testing time | Accuracy | Memory used while idle (Bytes)
KNN(K=1): 0.15329933166503906 | 1.061927080154419 | 0.747 | 122960721
KNN(K=3): 0.05618762969970703 | 0.706660270690918 | 0.761 | 122960721
NC: 0.10158920288085938 | 0.029914379119873047 | 0.6615 | 49597
