In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import RandomizedSearchCV
import math
import torch.nn as nn

# Move model on GPU if available
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')

class Reshape(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)


class Trim(nn.Module):
    def __init__(self, *args):
        super().__init__()

    def forward(self, x):
        return x[:, :, :28, :28]

class AutoEncoder(nn.Module):
    def __init__(self, d_l):
        super().__init__()

        self.encoder = nn.Sequential( #784
                nn.Conv2d(1, 32, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.Flatten(),
                nn.Linear(3136, d_l)
        )
        self.decoder = nn.Sequential(
                torch.nn.Linear(d_l, 3136),
                Reshape(-1, 64, 7, 7),
                nn.ConvTranspose2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1), # 64x7x7 -> 64x7x7
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1), # 64x7x7 -> 64x13x13
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 32, stride=(2, 2), kernel_size=(3, 3), padding=0), # 64x13x13 -> 32x27x27
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(32, 1, stride=(1, 1), kernel_size=(3, 3), padding=0), # 32x27x27 -> 1x29x29
                Trim(),  # 1x29x29 -> 1x28x28
                nn.Sigmoid()
                )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
    def get_latent_space(self, x):
        return self.encoder(x)

    def get_decoded_images(self, x):
        return self.decoder(x)

class CTDataset_AE(Dataset):
    def __init__(self, filepath, AE_model):
        self.x, self.y = torch.load(filepath, weights_only=False)
        self.x = self.x / 255.
        self.x = self.x.reshape(-1, 1, 28, 28).cuda().detach()
        AE_model.eval()
        with torch.no_grad():
            self.x = AE_model.get_latent_space(self.x)
        self.x = self.x.detach()
        self.y = F.one_hot(self.y, num_classes=10).to(float)
    def __len__(self):
        return self.x.shape[0]
    def __getitem__(self, ix):
        return self.x[ix], self.y[ix]

torch.manual_seed(42)
for latent_space_dim in [3, 4, 5, 6, 7, 8, 10, 12, 15, 16, 17, 20, 30, 40, 100]:
    AE_model = AutoEncoder(d_l = latent_space_dim)
    AE_model.load_state_dict(torch.load('./AE_'+ str(latent_space_dim) +'.pth', weights_only=True ))
    AE_model.to(device)

    train_ds = CTDataset_AE('./training.pt', AE_model)
    test_ds = CTDataset_AE('./test.pt', AE_model)

    torch.manual_seed(42)
    train_AE_set, train_cond_gen_set = torch.utils.data.random_split(train_ds, [30000, 30000])

    x_train, y_train = train_cond_gen_set[:]
    x_test, y_test = test_ds[:]

    x_train, y_train = x_train.cpu().detach().numpy(), y_train.cpu().detach().numpy()
    x_test, y_test = x_test.cpu().detach().numpy(), y_test.cpu().detach().numpy()

    classifiers = []
    param_dist = {
        "n_neighbors":
        [i for i in range(1, int(math.sqrt(x_train.shape[0])))]
    }
    random_search = RandomizedSearchCV(
        KNeighborsClassifier(),
        param_distributions=param_dist,
        n_iter=60,
        cv=5,
        n_jobs=-1,
        random_state=42 )
    classifiers.append(random_search.fit(x_train, y_train))

    # Assuming x_test and y_test are already defined
    test_accuracy = classifiers[0].score(x_test, y_test)
    print(f" latent_space_dim: {latent_space_dim}; Test Accuracy: {test_accuracy}")

 latent_space_dim: 3; Test Accuracy: 0.7768
 latent_space_dim: 4; Test Accuracy: 0.8843
 latent_space_dim: 5; Test Accuracy: 0.9229
 latent_space_dim: 6; Test Accuracy: 0.923
 latent_space_dim: 7; Test Accuracy: 0.9335
 latent_space_dim: 8; Test Accuracy: 0.9434
 latent_space_dim: 10; Test Accuracy: 0.9607
 latent_space_dim: 12; Test Accuracy: 0.9642
 latent_space_dim: 15; Test Accuracy: 0.9684
 latent_space_dim: 16; Test Accuracy: 0.9709
 latent_space_dim: 17; Test Accuracy: 0.9719
 latent_space_dim: 20; Test Accuracy: 0.9774
 latent_space_dim: 30; Test Accuracy: 0.9815
 latent_space_dim: 40; Test Accuracy: 0.9797
 latent_space_dim: 100; Test Accuracy: 0.9665


In [None]:
# 0.7768, 0.8843, 0.9229, 0.923, 0.9335, 0.9434, 0.9607, 0.9642, 0.9684, 0.9709, 0.9719, 0.9774, 0.9815, 0.9797, 0.9665