In [2]:
import torch 
import dlc_practical_prologue as prologue

In [3]:
def nearest_classification(train_input, train_target, x):
    dist = (train_input - x).pow(2).sum(1).view(-1)
    _, n = torch.min(dist, 0)
    
    return train_target[n.item()]

In [4]:
def compute_nb_errors(train_input, train_target, test_input, test_target, mean=None, proj=None):
    if mean is not None:
        train_input = train_input - mean
        test_input = test_input - mean
        
    if proj is not None:
        train_input = train_input @ proj.t()
        test_input = test_input @ proj.t()
        
    nb_errors = 0
    
    for n in range(test_input.size(0)):
        if test_target[n] != nearest_classification(train_input, train_target, test_input[n]):
            nb_errors = nb_errors + 1
            
    return nb_errors

In [5]:
def PCA(x):
    mean = x.mean()
    b = x - mean
    Sigma = b.t() @ b 
    eigen_values, eigen_vectors = Sigma.eig(True)
    right_order = eigen_values[:, 0].abs().sort(descending=True)[1]
    eigen_vectors = eigen_vectors.t()[right_order]
    
    return mean, eigen_vectors

In [6]:
for c in [False, True]:
    train_input, train_target, test_input, test_target = prologue.load_data(cifar=c)
    
    nb_errors = compute_nb_errors(train_input, train_target, test_input, test_target)
    print('Baseline nb_errors {:d} error {:.02f}%'.format(nb_errors, 100*nb_errors/test_input.size(0)))
    
    basis = train_input.new(100, train_input.size(1)).normal_()
    nb_errors = compute_nb_errors(train_input, train_target, test_input, test_target, None, basis)
    print('Random {:d}d nb_errors {:d} error {:.02f}%'.format(basis.size(0), nb_errors, 100 * nb_errors / test_input.size(0)))

    mean, basis = PCA(train_input)
    for d in [3, 10, 50, 100]:
        nb_errors = compute_nb_errors(train_input, train_target, test_input, test_target, mean, basis[:d])
        print('PCA {:d}d nb_errors {:d} error {:.02f}%'.format(basis.size(0), nb_errors, 100 * nb_errors / test_input.size(0)))

* Using MNIST
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples
Baseline nb_errors 172 error 17.20%
Random 100d nb_errors 212 error 21.20%
PCA 784d nb_errors 574 error 57.40%
PCA 784d nb_errors 204 error 20.40%
PCA 784d nb_errors 156 error 15.60%
PCA 784d nb_errors 164 error 16.40%
* Using CIFAR
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10/
Files already downloaded and verified
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples
Baseline nb_errors 746 error 74.60%
Random 100d nb_errors 779 error 77.90%
PCA 3072d nb_errors 830 error 83.00%
PCA 3072d nb_errors 757 error 75.70%
PCA 3072d nb_errors 737 error 73.70%
PCA 3072d nb_errors 743 error 74.30%
