In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

from src import *

from tqdm import tqdm
import os
import pickle

# Folder prefix for the artifacts this notebook writes (the pickled index map and the
# per-digit latent tensors). Point it at an existing folder before running.
external_path = ""

In [2]:
# Rebuild the architecture and restore the trained weights.
# map_location='cpu' lets a checkpoint saved from a GPU session load on a CPU-only
# machine; every later cell feeds CPU tensors to the model, so CPU is the right target.
model = CNN()
model.load_state_dict(torch.load('cnn_mnist.pth', map_location='cpu'))
model.eval()  # switch to inference mode; as the last expression it also displays the architecture

CNN(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=3136, out_features=64, bias=True)
  )
  (out_layer): Linear(in_features=64, out_features=10, bias=True)
)

In [3]:
batch_size = 1  # samples per DataLoader batch; shared by the test loader and the per-digit loaders

In [4]:
# Normalize with the commonly quoted MNIST mean/std (0.1307, 0.3081). This is presumably
# the same preprocessing used when cnn_mnist.pth was trained — TODO confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# download=True lets a fresh environment run end to end; torchvision skips the download
# when the files are already present under ./data.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# shuffle=False keeps loader order identical to dataset order, so positions counted while
# iterating are valid test_dataset indices for the Subset calls later on.
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Record, per true digit, the test_dataset indices of every sample the model classifies
# correctly, then persist the mapping {digit: [dataset index, ...]} with pickle.

# Dataset indices are reconstructed from loader position, which is only valid when the
# loader walks the dataset in order.
assert isinstance(test_loader.sampler, torch.utils.data.SequentialSampler), \
    "test_loader must not shuffle: indices are derived from iteration order"

correctly_classified_test_indices = {digit: [] for digit in range(10)}
with torch.no_grad():  # pure inference: no autograd graph needed
    pbar = tqdm(enumerate(test_loader), total=len(test_loader))
    for batch_idx, (img, label) in pbar:
        predictions = model(img).argmax(dim=1)  # per-sample predicted class
        offset = batch_idx * test_loader.batch_size  # dataset index of the batch's first sample
        for j, (pred, true) in enumerate(zip(predictions.tolist(), label.tolist())):
            if pred == true:
                correctly_classified_test_indices[true].append(offset + j)

# os.path.join avoids a hardcoded '\\' (which made an empty prefix resolve to the drive
# root on Windows); the with-block closes the file even if pickling fails.
with open(os.path.join(external_path, 'correctly_classified_test_indices'), 'wb') as indices_file:
    pickle.dump(correctly_classified_test_indices, indices_file)

100%|██████████| 10000/10000 [00:16<00:00, 609.41it/s]


In [6]:
# For each digit, run the correctly classified test samples through model.encoder and save
# the stacked encoder outputs to <external_path>/latent_activations/<digit>.pt.
latent_dir = os.path.join(external_path, 'latent_activations')
os.makedirs(latent_dir, exist_ok=True)  # the target folder may not exist yet

pbar = tqdm(range(10))
with torch.no_grad():  # pure inference: saved tensors carry no autograd history
    for digit in pbar:
        digit_subset = Subset(test_dataset, indices=correctly_classified_test_indices[digit])
        digit_subset_loader = DataLoader(digit_subset, batch_size=batch_size)

        # Collect per-batch outputs and concatenate once (repeated torch.cat is quadratic).
        latent_batches = []
        for i, (img, _) in enumerate(digit_subset_loader):
            pbar.set_description(f'Digit {digit}: {i+1}/{len(digit_subset_loader)}')
            latent_batches.append(model.encoder(img))

        if latent_batches:
            digit_latents = torch.cat(latent_batches)
        else:
            # No correct samples for this digit: save an empty tensor rather than the
            # previous digit's latents. out_layer consumes the encoder output, so its
            # in_features is the latent width.
            digit_latents = torch.empty(0, model.out_layer.in_features)

        torch.save(digit_latents, os.path.join(latent_dir, f'{digit}.pt'))

Digit 9: 985/985: 100%|██████████| 10/10 [00:47<00:00,  4.74s/it] 
