In [1]:
from pl_bolts.datamodules import CIFAR10DataModule, TinyCIFAR10DataModule
from torchvision import transforms
import torchvision.datasets as datasets
import wandb
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torchvision.utils import make_grid
import numpy as np
from matplotlib.pyplot import imshow, figure, clf
from pytorch_lightning.loggers import WandbLogger
from argparse import ArgumentParser
from pl_bolts.models.autoencoders.components import (
    resnet18_decoder,
    resnet18_encoder,
)
import torch
from torch import nn
import pytorch_lightning as pl
from numberclassifier import Number_Classifier
from tqdm import trange

pl.seed_everything(1234)
logger = WandbLogger(name="VAE", project="VAETesting")

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:

tensor_transform = transforms.ToTensor()
mnist_trainset = datasets.MNIST(
    root="./data", train=True, download=True, transform=tensor_transform
)
mnist_testset = datasets.MNIST(
    root="./data", train=False, download=True, transform=tensor_transform
)
maxtrain = torch.max(mnist_trainset.data)
maxtest = torch.max(mnist_testset.data)
# load into torch datasets
# To speed up training, we subset to 10,000 (instead of 60,000) images. You can change this if you want better performance.
train_dataset = torch.utils.data.TensorDataset(mnist_trainset.data.to(
    dtype=torch.float32)[:10000]/maxtrain, mnist_trainset.targets.to(dtype=torch.long)[:10000])
test_dataset = torch.utils.data.TensorDataset(mnist_testset.data.to(
    dtype=torch.float32)/maxtest, mnist_testset.targets.to(dtype=torch.long))




def get_accuracy(output, targets):
    output = output.detach()  # this removes the gradients associated with the tensor
    predicted = output.argmax(-1)
    correct = (predicted == targets).sum().item()
    accuracy = correct / output.size(0) * 100
    return accuracy


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
figure(figsize=(8, 3), dpi=300)


class Number_Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(784, 20)
        # These are PyTorch's predefined layers. Each is a class. In the "init" function, we just initialize and instantiate the classes, creating objects (that behave like functions)
        self.layer2 = nn.Linear(20, 10)
        self.nonlin = nn.ReLU()
        # self.softmax = nn.Softmax() # Converts numbers into probabilities

    def forward(self, x):
        x = self.layer1(x)  # Composing the functions we created below
        x = self.nonlin(x)
        x = self.layer2(x)
        return x


def get_test_accuracy(x_hat):
    weight_1_matrix, weight_2_matrix, bias_1_matrix, bias_2_matrix = torch.split(
                x_hat, (20 * 784, 10 * 20, 20, 10))
    classifier = Number_Classifier()
    classifier.to(device)
    # train_ep_pred = classifier(mnist_trainset.data.to(dtype=torch.float32).reshape(-1,28*28).to(device))
    # test_ep_pred = classifier(mnist_testset.data.to(dtype=torch.float32).reshape(-1,28*28).to(device))

    # train_accuracy = get_accuracy(train_ep_pred.cpu(), mnist_trainset.targets.to(dtype=torch.long))
    # test_accuracy = get_accuracy(test_ep_pred.cpu(), mnist_testset.targets.to(dtype=torch.long))
    # print("reconst test_accuracy", test_accuracy, "reconst train_accuracy", train_accuracy)
    classifier_load = {}
    classifier_load["layer1.weight"] = weight_1_matrix.reshape((20,784))
    classifier_load["layer2.weight"] = weight_2_matrix.reshape((10,20))
    classifier_load["layer1.bias"] = bias_1_matrix
    classifier_load["layer2.bias"] = bias_2_matrix

    classifier.load_state_dict(classifier_load, strict=True)
    test_ep_pred = classifier(mnist_testset.data.to(dtype=torch.float32).reshape(-1,28*28).to(device))

    test_accuracy = get_accuracy(test_ep_pred.cpu(), mnist_testset.targets.to(dtype=torch.long))
    return test_accuracy


model_params = []

for i in trange(300):
    classifier_load = torch.load(
        "new_data/model_{}.pth".format(i+1), map_location='cuda')
    # print(classifier_load)
    weight_1_matrix = classifier_load["layer1.weight"]
    weight_2_matrix = classifier_load["layer2.weight"]
    bias_1_matrix = classifier_load["layer1.bias"]
    bias_2_matrix = classifier_load["layer2.bias"]
   
    
    weight_1_matrix = weight_1_matrix.reshape(20*784)
    weight_2_matrix = weight_2_matrix.reshape(200)

    

    #bias_1_matrix = bias_1_matrix.reshpae(20)
    #bias_2_matrix = bias_2_matrix.reshpae(19)
    combo = torch.cat((weight_1_matrix, weight_2_matrix,
                      bias_1_matrix, bias_2_matrix))
    print("combo", combo.shape)

    model_params.append(combo.reshape(1,15910))
    
    print("reconst test_accuracy", get_test_accuracy(combo))
    del weight_1_matrix
    del weight_2_matrix
    del bias_1_matrix
    del bias_2_matrix
    del classifier_load
    # del combo
print("modelparams", len(model_params))



  0%|          | 0/300 [00:00<?, ?it/s]

combo torch.Size([15910])


  1%|▏         | 4/300 [00:04<03:53,  1.27it/s]

reconst test_accuracy 80.02
combo torch.Size([15910])
reconst test_accuracy 80.13
combo torch.Size([15910])
reconst test_accuracy 80.14
combo torch.Size([15910])
reconst test_accuracy 80.10000000000001
combo torch.Size([15910])
reconst test_accuracy 80.05
combo torch.Size([15910])


  3%|▎         | 8/300 [00:04<01:36,  3.03it/s]

reconst test_accuracy 80.08
combo torch.Size([15910])
reconst test_accuracy 80.01
combo torch.Size([15910])
reconst test_accuracy 80.10000000000001
combo torch.Size([15910])
reconst test_accuracy 80.07
combo torch.Size([15910])
reconst test_accuracy 80.11
combo torch.Size([15910])
reconst test_accuracy 80.01
combo torch.Size([15910])
reconst test_accuracy 80.03


  5%|▌         | 16/300 [00:04<00:36,  7.82it/s]

combo torch.Size([15910])
reconst test_accuracy 80.08
combo torch.Size([15910])
reconst test_accuracy 80.10000000000001
combo torch.Size([15910])
reconst test_accuracy 81.08999999999999
combo torch.Size([15910])
reconst test_accuracy 80.01
combo torch.Size([15910])
reconst test_accuracy 80.07
combo torch.Size([15910])
reconst test_accuracy 81.8
combo torch.Size([15910])


  8%|▊         | 24/300 [00:04<00:19, 13.85it/s]

reconst test_accuracy 80.01
combo torch.Size([15910])
reconst test_accuracy 80.01
combo torch.Size([15910])
reconst test_accuracy 81.67999999999999
combo torch.Size([15910])
reconst test_accuracy 80.52
combo torch.Size([15910])
reconst test_accuracy 80.41
combo torch.Size([15910])
reconst test_accuracy 81.34
combo torch.Size([15910])


  9%|▉         | 28/300 [00:04<00:15, 17.08it/s]

reconst test_accuracy 80.72
combo torch.Size([15910])
reconst test_accuracy 80.12
combo torch.Size([15910])
reconst test_accuracy 80.08999999999999
combo torch.Size([15910])
reconst test_accuracy 80.06
combo torch.Size([15910])
reconst test_accuracy 80.38
combo torch.Size([15910])
reconst test_accuracy 81.87
combo torch.Size([15910])
reconst test_accuracy 80.05
combo torch.Size([15910])


 12%|█▏        | 36/300 [00:05<00:11, 23.03it/s]

reconst test_accuracy 85.11
combo torch.Size([15910])
reconst test_accuracy 85.21
combo torch.Size([15910])
reconst test_accuracy 80.47
combo torch.Size([15910])
reconst test_accuracy 80.34
combo torch.Size([15910])
reconst test_accuracy 86.16
combo torch.Size([15910])
reconst test_accuracy 85.16
combo torch.Size([15910])
reconst test_accuracy 86.46000000000001
combo torch.Size([15910])


 15%|█▍        | 44/300 [00:05<00:09, 27.03it/s]

reconst test_accuracy 85.33
combo torch.Size([15910])
reconst test_accuracy 86.82
combo torch.Size([15910])
reconst test_accuracy 84.91
combo torch.Size([15910])
reconst test_accuracy 85.25
combo torch.Size([15910])
reconst test_accuracy 85.58
combo torch.Size([15910])
reconst test_accuracy 84.78999999999999
combo torch.Size([15910])
reconst test_accuracy 80.27
combo torch.Size([15910])


 17%|█▋        | 52/300 [00:05<00:08, 28.80it/s]

reconst test_accuracy 81.41000000000001
combo torch.Size([15910])
reconst test_accuracy 85.99
combo torch.Size([15910])
reconst test_accuracy 86.13
combo torch.Size([15910])
reconst test_accuracy 81.01
combo torch.Size([15910])
reconst test_accuracy 85.86
combo torch.Size([15910])
reconst test_accuracy 85.95
combo torch.Size([15910])
reconst test_accuracy 84.94


 19%|█▊        | 56/300 [00:05<00:08, 29.50it/s]

combo torch.Size([15910])
reconst test_accuracy 86.42
combo torch.Size([15910])
reconst test_accuracy 81.82000000000001
combo torch.Size([15910])
reconst test_accuracy 80.02
combo torch.Size([15910])
reconst test_accuracy 80.17
combo torch.Size([15910])
reconst test_accuracy 86.3
combo torch.Size([15910])


 20%|██        | 60/300 [00:05<00:08, 29.41it/s]

reconst test_accuracy 85.38
combo torch.Size([15910])
reconst test_accuracy 84.92
combo torch.Size([15910])
reconst test_accuracy 80.42
combo torch.Size([15910])
reconst test_accuracy 85.95
combo torch.Size([15910])
reconst test_accuracy 85.22
combo torch.Size([15910])
reconst test_accuracy 80.39
combo torch.Size([15910])


 22%|██▏       | 67/300 [00:06<00:08, 28.37it/s]

reconst test_accuracy 80.72
combo torch.Size([15910])
reconst test_accuracy 83.26
combo torch.Size([15910])
reconst test_accuracy 86.15
combo torch.Size([15910])
reconst test_accuracy 81.22
combo torch.Size([15910])
reconst test_accuracy 81.06
combo torch.Size([15910])
reconst test_accuracy 85.39
combo torch.Size([15910])


 24%|██▍       | 73/300 [00:06<00:07, 28.60it/s]

reconst test_accuracy 85.95
combo torch.Size([15910])
reconst test_accuracy 80.62
combo torch.Size([15910])
reconst test_accuracy 85.57000000000001
combo torch.Size([15910])
reconst test_accuracy 86.71
combo torch.Size([15910])
reconst test_accuracy 86.67
combo torch.Size([15910])
reconst test_accuracy 80.83
combo torch.Size([15910])
reconst test_accuracy 80.44


 27%|██▋       | 80/300 [00:06<00:07, 29.50it/s]

combo torch.Size([15910])
reconst test_accuracy 84.89
combo torch.Size([15910])
reconst test_accuracy 85.8
combo torch.Size([15910])
reconst test_accuracy 81.28999999999999
combo torch.Size([15910])
reconst test_accuracy 81.98
combo torch.Size([15910])
reconst test_accuracy 80.82000000000001
combo torch.Size([15910])
reconst test_accuracy 85.61
combo torch.Size([15910])


 29%|██▉       | 88/300 [00:06<00:06, 30.41it/s]

reconst test_accuracy 85.31
combo torch.Size([15910])
reconst test_accuracy 80.57
combo torch.Size([15910])
reconst test_accuracy 85.92999999999999
combo torch.Size([15910])
reconst test_accuracy 86.31
combo torch.Size([15910])
reconst test_accuracy 80.03
combo torch.Size([15910])
reconst test_accuracy 86.59
combo torch.Size([15910])


 31%|███       | 92/300 [00:06<00:06, 30.38it/s]

reconst test_accuracy 80.71000000000001
combo torch.Size([15910])
reconst test_accuracy 80.28
combo torch.Size([15910])
reconst test_accuracy 85.00999999999999
combo torch.Size([15910])
reconst test_accuracy 84.67
combo torch.Size([15910])
reconst test_accuracy 85.11999999999999
combo torch.Size([15910])
reconst test_accuracy 82.32000000000001
combo torch.Size([15910])
reconst test_accuracy 85.61999999999999
combo torch.Size([15910])


 33%|███▎      | 100/300 [00:07<00:06, 29.82it/s]

reconst test_accuracy 80.62
combo torch.Size([15910])
reconst test_accuracy 84.88
combo torch.Size([15910])
reconst test_accuracy 86.78
combo torch.Size([15910])
reconst test_accuracy 85.99
combo torch.Size([15910])
reconst test_accuracy 86.0
combo torch.Size([15910])
reconst test_accuracy 86.0
combo torch.Size([15910])
reconst test_accuracy 85.81
combo torch.Size([15910])


 36%|███▌      | 108/300 [00:07<00:06, 29.84it/s]

reconst test_accuracy 80.61
combo torch.Size([15910])
reconst test_accuracy 85.47
combo torch.Size([15910])
reconst test_accuracy 84.89999999999999
combo torch.Size([15910])
reconst test_accuracy 80.25
combo torch.Size([15910])
reconst test_accuracy 86.05000000000001
combo torch.Size([15910])
reconst test_accuracy 86.1
combo torch.Size([15910])
reconst test_accuracy 86.41
combo torch.Size([15910])


 37%|███▋      | 112/300 [00:07<00:06, 30.32it/s]

reconst test_accuracy 81.94
combo torch.Size([15910])
reconst test_accuracy 81.74
combo torch.Size([15910])
reconst test_accuracy 80.42
combo torch.Size([15910])
reconst test_accuracy 83.66
combo torch.Size([15910])
reconst test_accuracy 81.35
combo torch.Size([15910])
reconst test_accuracy 86.56
combo torch.Size([15910])


 40%|████      | 120/300 [00:07<00:05, 30.41it/s]

reconst test_accuracy 85.96000000000001
combo torch.Size([15910])
reconst test_accuracy 85.74000000000001
combo torch.Size([15910])
reconst test_accuracy 81.58
combo torch.Size([15910])
reconst test_accuracy 81.2
combo torch.Size([15910])
reconst test_accuracy 84.27
combo torch.Size([15910])
reconst test_accuracy 80.94
combo torch.Size([15910])
reconst test_accuracy 81.08
combo torch.Size([15910])


 41%|████▏     | 124/300 [00:08<00:05, 30.73it/s]

reconst test_accuracy 81.17999999999999
combo torch.Size([15910])
reconst test_accuracy 85.42
combo torch.Size([15910])
reconst test_accuracy 81.67999999999999
combo torch.Size([15910])
reconst test_accuracy 81.22
combo torch.Size([15910])
reconst test_accuracy 80.52
combo torch.Size([15910])
reconst test_accuracy 86.41


 44%|████▍     | 132/300 [00:08<00:05, 29.88it/s]

combo torch.Size([15910])
reconst test_accuracy 82.83
combo torch.Size([15910])
reconst test_accuracy 82.22
combo torch.Size([15910])
reconst test_accuracy 85.5
combo torch.Size([15910])
reconst test_accuracy 86.5
combo torch.Size([15910])
reconst test_accuracy 80.58999999999999
combo torch.Size([15910])
reconst test_accuracy 85.15
combo torch.Size([15910])


 45%|████▌     | 136/300 [00:08<00:05, 30.21it/s]

reconst test_accuracy 80.04
combo torch.Size([15910])
reconst test_accuracy 84.97
combo torch.Size([15910])
reconst test_accuracy 85.97
combo torch.Size([15910])
reconst test_accuracy 85.38
combo torch.Size([15910])
reconst test_accuracy 81.89999999999999
combo torch.Size([15910])
reconst test_accuracy 85.76


 48%|████▊     | 143/300 [00:08<00:05, 28.43it/s]

combo torch.Size([15910])
reconst test_accuracy 80.61
combo torch.Size([15910])
reconst test_accuracy 81.87
combo torch.Size([15910])
reconst test_accuracy 85.65
combo torch.Size([15910])
reconst test_accuracy 85.32
combo torch.Size([15910])
reconst test_accuracy 80.83
combo torch.Size([15910])
reconst test_accuracy 80.69


 50%|█████     | 150/300 [00:08<00:05, 29.72it/s]

combo torch.Size([15910])
reconst test_accuracy 81.67
combo torch.Size([15910])
reconst test_accuracy 80.60000000000001
combo torch.Size([15910])
reconst test_accuracy 86.09
combo torch.Size([15910])
reconst test_accuracy 85.48
combo torch.Size([15910])
reconst test_accuracy 80.76
combo torch.Size([15910])
reconst test_accuracy 85.63
combo torch.Size([15910])


 52%|█████▏    | 157/300 [00:09<00:04, 29.68it/s]

reconst test_accuracy 80.36999999999999
combo torch.Size([15910])
reconst test_accuracy 85.34
combo torch.Size([15910])
reconst test_accuracy 86.06
combo torch.Size([15910])
reconst test_accuracy 85.92
combo torch.Size([15910])
reconst test_accuracy 80.92
combo torch.Size([15910])
reconst test_accuracy 85.83
combo torch.Size([15910])


 54%|█████▎    | 161/300 [00:09<00:04, 29.31it/s]

reconst test_accuracy 84.84
combo torch.Size([15910])
reconst test_accuracy 86.14
combo torch.Size([15910])
reconst test_accuracy 85.1
combo torch.Size([15910])
reconst test_accuracy 86.38
combo torch.Size([15910])
reconst test_accuracy 80.67999999999999
combo torch.Size([15910])
reconst test_accuracy 80.05


 56%|█████▌    | 167/300 [00:09<00:04, 28.34it/s]

combo torch.Size([15910])
reconst test_accuracy 85.99
combo torch.Size([15910])
reconst test_accuracy 86.83
combo torch.Size([15910])
reconst test_accuracy 84.84
combo torch.Size([15910])
reconst test_accuracy 86.33
combo torch.Size([15910])
reconst test_accuracy 83.11
combo torch.Size([15910])


 58%|█████▊    | 175/300 [00:09<00:04, 30.60it/s]

reconst test_accuracy 86.74
combo torch.Size([15910])
reconst test_accuracy 86.0
combo torch.Size([15910])
reconst test_accuracy 81.56
combo torch.Size([15910])
reconst test_accuracy 86.68
combo torch.Size([15910])
reconst test_accuracy 84.73
combo torch.Size([15910])
reconst test_accuracy 86.08
combo torch.Size([15910])
reconst test_accuracy 80.95
combo torch.Size([15910])


 60%|█████▉    | 179/300 [00:09<00:03, 30.36it/s]

reconst test_accuracy 80.4
combo torch.Size([15910])
reconst test_accuracy 86.1
combo torch.Size([15910])
reconst test_accuracy 85.82
combo torch.Size([15910])
reconst test_accuracy 85.24000000000001
combo torch.Size([15910])
reconst test_accuracy 80.41
combo torch.Size([15910])
reconst test_accuracy 86.44
combo torch.Size([15910])
reconst test_accuracy 85.52


 62%|██████▏   | 187/300 [00:10<00:03, 30.69it/s]

combo torch.Size([15910])
reconst test_accuracy 86.33
combo torch.Size([15910])
reconst test_accuracy 80.46
combo torch.Size([15910])
reconst test_accuracy 81.46
combo torch.Size([15910])
reconst test_accuracy 85.96000000000001
combo torch.Size([15910])
reconst test_accuracy 81.08
combo torch.Size([15910])


 64%|██████▎   | 191/300 [00:10<00:03, 30.21it/s]

reconst test_accuracy 82.07
combo torch.Size([15910])
reconst test_accuracy 80.08999999999999
combo torch.Size([15910])
reconst test_accuracy 85.85000000000001
combo torch.Size([15910])
reconst test_accuracy 84.57000000000001
combo torch.Size([15910])
reconst test_accuracy 86.08
combo torch.Size([15910])
reconst test_accuracy 80.34
combo torch.Size([15910])


 66%|██████▋   | 199/300 [00:10<00:03, 30.91it/s]

reconst test_accuracy 86.36
combo torch.Size([15910])
reconst test_accuracy 86.42
combo torch.Size([15910])
reconst test_accuracy 85.54
combo torch.Size([15910])
reconst test_accuracy 85.59
combo torch.Size([15910])
reconst test_accuracy 81.52000000000001
combo torch.Size([15910])
reconst test_accuracy 85.72
combo torch.Size([15910])
reconst test_accuracy 85.6
combo torch.Size([15910])


 68%|██████▊   | 203/300 [00:10<00:03, 29.89it/s]

reconst test_accuracy 86.32
combo torch.Size([15910])
reconst test_accuracy 80.28
combo torch.Size([15910])
reconst test_accuracy 83.46000000000001
combo torch.Size([15910])
reconst test_accuracy 86.33999999999999
combo torch.Size([15910])
reconst test_accuracy 80.2
combo torch.Size([15910])


 70%|███████   | 210/300 [00:11<00:03, 26.38it/s]

reconst test_accuracy 80.77
combo torch.Size([15910])
reconst test_accuracy 85.92
combo torch.Size([15910])
reconst test_accuracy 86.74
combo torch.Size([15910])
reconst test_accuracy 84.08
combo torch.Size([15910])
reconst test_accuracy 86.79
combo torch.Size([15910])
reconst test_accuracy 85.37
combo torch.Size([15910])


 72%|███████▏  | 217/300 [00:11<00:02, 27.93it/s]

reconst test_accuracy 85.11999999999999
combo torch.Size([15910])
reconst test_accuracy 85.76
combo torch.Size([15910])
reconst test_accuracy 86.06
combo torch.Size([15910])
reconst test_accuracy 86.9
combo torch.Size([15910])
reconst test_accuracy 80.08
combo torch.Size([15910])
reconst test_accuracy 81.72
combo torch.Size([15910])


 74%|███████▍  | 223/300 [00:11<00:02, 28.66it/s]

reconst test_accuracy 81.89
combo torch.Size([15910])
reconst test_accuracy 81.2
combo torch.Size([15910])
reconst test_accuracy 85.32
combo torch.Size([15910])
reconst test_accuracy 86.79
combo torch.Size([15910])
reconst test_accuracy 85.66
combo torch.Size([15910])
reconst test_accuracy 85.75
combo torch.Size([15910])


 77%|███████▋  | 230/300 [00:11<00:02, 28.64it/s]

reconst test_accuracy 81.35
combo torch.Size([15910])
reconst test_accuracy 86.4
combo torch.Size([15910])
reconst test_accuracy 86.29
combo torch.Size([15910])
reconst test_accuracy 86.6
combo torch.Size([15910])
reconst test_accuracy 80.65
combo torch.Size([15910])
reconst test_accuracy 84.76
combo torch.Size([15910])


 78%|███████▊  | 233/300 [00:11<00:02, 27.39it/s]

reconst test_accuracy 81.44
combo torch.Size([15910])
reconst test_accuracy 86.22999999999999
combo torch.Size([15910])
reconst test_accuracy 80.78999999999999
combo torch.Size([15910])
reconst test_accuracy 85.52
combo torch.Size([15910])
reconst test_accuracy 85.19
combo torch.Size([15910])


 80%|███████▉  | 239/300 [00:12<00:02, 25.56it/s]

reconst test_accuracy 86.24000000000001
combo torch.Size([15910])
reconst test_accuracy 86.68
combo torch.Size([15910])
reconst test_accuracy 85.95
combo torch.Size([15910])
reconst test_accuracy 86.99
combo torch.Size([15910])
reconst test_accuracy 85.63
combo torch.Size([15910])


 82%|████████▏ | 245/300 [00:12<00:02, 25.85it/s]

reconst test_accuracy 85.25
combo torch.Size([15910])
reconst test_accuracy 86.48
combo torch.Size([15910])
reconst test_accuracy 82.38
combo torch.Size([15910])
reconst test_accuracy 83.00999999999999
combo torch.Size([15910])
reconst test_accuracy 86.00999999999999
combo torch.Size([15910])
reconst test_accuracy 86.11
combo torch.Size([15910])


 84%|████████▎ | 251/300 [00:12<00:01, 26.69it/s]

reconst test_accuracy 83.19
combo torch.Size([15910])
reconst test_accuracy 86.15
combo torch.Size([15910])
reconst test_accuracy 86.04
combo torch.Size([15910])
reconst test_accuracy 85.6
combo torch.Size([15910])
reconst test_accuracy 80.84
combo torch.Size([15910])
reconst test_accuracy 80.44
combo torch.Size([15910])


 86%|████████▌ | 257/300 [00:12<00:01, 26.04it/s]

reconst test_accuracy 85.11
combo torch.Size([15910])
reconst test_accuracy 85.47
combo torch.Size([15910])
reconst test_accuracy 84.71
combo torch.Size([15910])
reconst test_accuracy 85.86
combo torch.Size([15910])
reconst test_accuracy 84.83000000000001
combo torch.Size([15910])
reconst test_accuracy 85.65
combo torch.Size([15910])


 88%|████████▊ | 263/300 [00:12<00:01, 25.92it/s]

reconst test_accuracy 81.96
combo torch.Size([15910])
reconst test_accuracy 80.02
combo torch.Size([15910])
reconst test_accuracy 86.53999999999999
combo torch.Size([15910])
reconst test_accuracy 81.84
combo torch.Size([15910])
reconst test_accuracy 80.77
combo torch.Size([15910])


 89%|████████▊ | 266/300 [00:13<00:01, 25.64it/s]

reconst test_accuracy 80.25999999999999
combo torch.Size([15910])
reconst test_accuracy 85.48
combo torch.Size([15910])
reconst test_accuracy 86.07000000000001
combo torch.Size([15910])
reconst test_accuracy 85.99
combo torch.Size([15910])
reconst test_accuracy 85.24000000000001
combo torch.Size([15910])
reconst test_accuracy 85.28


 91%|█████████ | 272/300 [00:13<00:01, 25.89it/s]

combo torch.Size([15910])
reconst test_accuracy 85.47
combo torch.Size([15910])
reconst test_accuracy 86.05000000000001
combo torch.Size([15910])
reconst test_accuracy 84.86
combo torch.Size([15910])
reconst test_accuracy 80.61
combo torch.Size([15910])


 93%|█████████▎| 278/300 [00:13<00:00, 25.26it/s]

reconst test_accuracy 81.08999999999999
combo torch.Size([15910])
reconst test_accuracy 86.58
combo torch.Size([15910])
reconst test_accuracy 80.01
combo torch.Size([15910])
reconst test_accuracy 80.71000000000001
combo torch.Size([15910])
reconst test_accuracy 81.97
combo torch.Size([15910])


 94%|█████████▎| 281/300 [00:13<00:00, 25.12it/s]

reconst test_accuracy 81.25
combo torch.Size([15910])
reconst test_accuracy 85.85000000000001
combo torch.Size([15910])
reconst test_accuracy 86.53999999999999
combo torch.Size([15910])
reconst test_accuracy 85.48
combo torch.Size([15910])
reconst test_accuracy 85.8
combo torch.Size([15910])


 96%|█████████▌| 287/300 [00:13<00:00, 25.48it/s]

reconst test_accuracy 80.12
combo torch.Size([15910])
reconst test_accuracy 81.82000000000001
combo torch.Size([15910])
reconst test_accuracy 85.6
combo torch.Size([15910])
reconst test_accuracy 80.10000000000001
combo torch.Size([15910])
reconst test_accuracy 80.47
combo torch.Size([15910])
reconst test_accuracy 86.24000000000001
combo torch.Size([15910])


 98%|█████████▊| 293/300 [00:14<00:00, 24.79it/s]

reconst test_accuracy 85.94000000000001
combo torch.Size([15910])
reconst test_accuracy 83.23
combo torch.Size([15910])
reconst test_accuracy 80.42
combo torch.Size([15910])
reconst test_accuracy 85.16
combo torch.Size([15910])
reconst test_accuracy 85.57000000000001
combo torch.Size([15910])


100%|█████████▉| 299/300 [00:14<00:00, 23.80it/s]

reconst test_accuracy 85.64
combo torch.Size([15910])
reconst test_accuracy 83.39999999999999
combo torch.Size([15910])
reconst test_accuracy 81.17999999999999
combo torch.Size([15910])
reconst test_accuracy 85.36
combo torch.Size([15910])
reconst test_accuracy 85.02
combo torch.Size([15910])


100%|██████████| 300/300 [00:14<00:00, 20.71it/s]

reconst test_accuracy 81.86
modelparams 300





<Figure size 2400x900 with 0 Axes>

In [5]:
class WeightsDataset(torch.utils.data.Dataset):
    def __init__(self, weights, transform=None, target_transform=None):
        self.weight_data = weights
        self.img_labels = []  # pd.read_csv(annotations_file)
        indices = []
        for i in range(len(weights)):
            self.img_labels.append(0)
            indices.append(i)

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        image = self.weight_data[idx].to(device)
        label = (idx//10) % 10
        return image, torch.tensor(label).to(device)



weights_data_set = WeightsDataset(model_params)

weights_data_loader = torch.utils.data.DataLoader(
    weights_data_set, batch_size=20, shuffle=True
)

In [3]:

class ImageSampler(pl.Callback):
    def __init__(self):
        super().__init__()
        self.img_size = None
        self.num_preds = 16

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        clf()
        # Z COMES FROM NORMAL(0, 1)
        rand_v = torch.randn(
            (self.num_preds, pl_module.hparams.latent_dim), device=pl_module.device)
        p = torch.distributions.Normal(
            torch.zeros_like(rand_v), torch.ones_like(rand_v))
        z = p.rsample()

        # SAMPLE IMAGES
        with torch.no_grad():
            pred = pl_module.decoder(z.to(pl_module.device)).cpu()
        print("pred", pred.shape)
        pred = pred.reshape((16, 2*43, 5*37))
        # UNDO DATA NORMALIZATION
        normalize = cifar10_normalization()
        mean, std = np.array(normalize.mean), np.array(normalize.std)
        # img = make_grid(pred).numpy() #* std + mean
        samples = [wandb.Image(img) for img in pred]
        # PLOT IMAGES
        wandb.log({"images": samples})



    # reconstruction loss


class NumberSampler(pl.Callback):
    def __init__(self):
        super().__init__()
        self.img_size = None
        self.num_preds = 16

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        clf()
        # Z COMES FROM NORMAL(0, 1)
        rand_v = torch.randn(
            (self.num_preds, pl_module.hparams.latent_dim), device=pl_module.device)
        p = torch.distributions.Normal(
            torch.zeros_like(rand_v), torch.ones_like(rand_v))
        z = p.rsample()

        # SAMPLE IMAGES
        with torch.no_grad():
            pred = pl_module.decoder(z.to(pl_module.device)).cpu()
        test_accuracy_out = []
        for combo in pred:
            print(combo.shape)
            test_accuracy = get_test_accuracy(combo)
            print("add weights test_accuracy", test_accuracy)
            test_accuracy_out.append(test_accuracy)

        
        sumnp = pred.numpy().sum()
        print("sumnp", sumnp)
        avgtest = np.average(np.array(test_accuracy_out))
        print("avg test_accuracy", avgtest)
        wandb.log({"test_accuracy": avgtest})
        wandb.log({"sumofweights": sumnp})


# For the afficianados: the (nn.Module) subclasses PyTorch's neural network superclass, which, when initialized below...
class Encode(nn.Module):
    def __init__(self, input_height):
        # <--- does a bunch of janitorial work to make our network easier to use. For example, once our network is initialized, calling IttyBittyNetwork(data) passes the data into forward.
        super().__init__()
        self.input_height = input_height
        self.encoding_model = nn.Sequential(
            nn.Linear(input_height, 400),
            nn.ReLU(),
            # Initializing the classes adds their free variables to our network's list of parameters to update via gradient descent.
            nn.Linear(400, 200),
            nn.ReLU(),
            # Initializing the classes adds their free variables to our network's list of parameters to update via gradient descent.
            nn.Linear(200, 100),
            nn.ReLU(),

            # nn.ReLU(),
            # nn.Linear(100, 50),  # Note this ends with a 10 dimensional output.
            # nn.ReLU(),
            # nn.Linear(50, 40),
        )
        # self.softmax = nn.Softmax() # Converts numbers into probabilities

    def encode(self, x):
        # print("shape", x.shape)
        # x = torch.reshape(x, (-1, self.input_height * self.input_height * 3))
        x = self.encoding_model(x)
        return x

    def forward(self, x):
        y = self.encode(x)
        return y


# For the afficianados: the (nn.Module) subclasses PyTorch's neural network superclass, which, when initialized below...
class Decode(nn.Module):
    def __init__(self, input_height):
        # <--- does a bunch of janitorial work to make our network easier to use. For example, once our network is initialized, calling IttyBittyNetwork(data) passes the data into forward.
        super().__init__()
        self.nonlin = nn.ReLU()
        # self.softmax = nn.Softmax() # Converts numbers into probabilities
        self.decoding_model = nn.Sequential(
            # nn.Linear(20, 50),
            # nn.ReLU(),
            nn.Linear(50, 100),
            nn.ReLU(),
            nn.Linear(100, 200),
            nn.ReLU(),
            nn.Linear(200, 400),
            nn.ReLU(),
            nn.Linear(400, input_height),
            nn.ReLU(),
        )
        self.sig = nn.ReLU()
        self.input_height = input_height

    def decode(self, x):
        x = self.decoding_model(x)
        # x = torch.reshape(x, (-1, 3, self.input_height, self.input_height))
        return x

    def forward(self, x):
        x = self.decode(x)
        return x



class VAE(pl.LightningModule):
    def __init__(self, enc_out_dim=100, latent_dim=50, input_height=15910):
        super().__init__()

        self.save_hyperparameters()
        device = torch.device("cpu")
        self.encoder = Encode(input_height)
        self.encoder.to(device)
        self.decoder = Decode(input_height)
        self.decoder.to(device)
        print("cuda", device)
        # encoder, decoder
        # self.encoder = resnet18_encoder(False, False)
        # self.decoder = resnet18_decoder(
        #     latent_dim=latent_dim,
        #     input_height=input_height,
        #     first_conv=False,
        #     maxpool1=False
        # )

        # distribution parameters
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_mu.to(device)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var.to(device)

        # for the gaussian likelihood
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))
        self.log_scale.to(device)
        device

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def gaussian_likelihood(self, x_hat, logscale, x):
        scale = torch.exp(logscale)
        mean = x_hat
        mean.to(device)
        scale.to(device)

        dist = torch.distributions.Normal(mean, scale)
        x.to(device)
        # measure prob of seeing image under p(x|z)
        log_pxz = dist.log_prob(x)
        # print("log_pxz", log_pxz.shape)
        return log_pxz.sum(dim=(1))

    def kl_divergence(self, z, mu, std):
        # --------------------------
        # Monte carlo KL divergence
        # --------------------------
        # 1. define the first two probabilities (in this case Normal for both)
        p = torch.distributions.Normal(
            torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)

        # 2. get the probabilities from the equation
        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)

        # kl
        kl = (log_qzx - log_pz)
        kl = kl.sum(-1)
        return kl

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x.to(device)

        # encode x to get the mu and variance parameters
        x_encoded = self.encoder(x)
        mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)

        # sample z from q
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()

        # decoded
        x_hat = self.decoder(z)
        x_hat.to(device)
        for i in range(len(batch)):
            print("get_test_accuracy", get_test_accuracy(x_hat[i]), get_test_accuracy(x[i]))
            xhatavg, xavg, xratio = (np.average(x_hat[i].cpu().detach().numpy()) ,np.average(x[i].cpu().detach().numpy()),  np.sum(x_hat[i].cpu().detach().numpy())/np.sum(x[i].cpu().detach().numpy()))
            xhatstd, xstd = (np.std(x_hat[i].cpu().detach().numpy()) ,np.std(x[i].cpu().detach().numpy()))
            
            print("sumnp",xhatavg, xavg, xratio )
            print("sumnp",xhatstd, xstd )
            wandb.log({
                'xhatavg': xhatavg,
                'xavg': xavg,
                'xratio': xratio,
                'xstd': xstd,
                'xhatstd': xhatstd,
            })
        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)
        

        # kl
        kl = self.kl_divergence(z, mu, std)

        # elbo
        elbo = (kl - recon_loss)
        elbo = elbo.mean()

        wandb.log({
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': recon_loss.mean(),
            'reconstruction': recon_loss.mean(),
            'kl': kl.mean(),
        })

        return elbo








In [12]:
sampler = ImageSampler()
numsampler = NumberSampler()

vae = VAE()
trainer = pl.Trainer(gpus=1, logger=logger,
                        max_epochs=300, callbacks=[sampler],auto_lr_find=True)
trainer.fit(vae, weights_data_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


cuda cpu


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlukas-nel[0m ([33m2084experiments[0m). Use [1m`wandb login --relogin`[0m to force relogin


Problem at: c:\Users\lukas\Desktop\Math522\FinalProject\VAE\lib\site-packages\pytorch_lightning\loggers\wandb.py 127 experiment


KeyboardInterrupt: 

In [8]:
import torch
from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D
class DiffModel(pl.LightningModule):
    def __init__(self,  input_height=15910):
        super().__init__()
        self.diffmodel = Unet1D(
            dim = 64,
            dim_mults = (1, 2, 4, 8),
            channels = 1
        )

        self.diffusion = GaussianDiffusion1D(
            self.diffmodel,
            seq_length = input_height,
            timesteps = 100,
            objective = 'pred_v'
        )
        print("loss", )

        self.save_hyperparameters()
        

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    
    def training_step(self, batch, batch_idx):
        x, _ = batch
        x.to(device)
        x  = x.reshape((-1,1,15910))
        
        loss = self.diffusion(x)
        wandb.log({
            'loss': loss,
        })
        print("loss", loss)
        return loss


# model = Unet1D(
#     dim = 64,
#     dim_mults = (1, 2, 4, 8),
#     channels = 1
# ).cuda()

# diffusion = GaussianDiffusion1D(
#     model.to(device),
#     seq_length = 15910,
#     timesteps = 10,
#     objective = 'pred_v'
# ).cuda()

# training_seq = torch.rand(8, 1, 15910) # features are normalized from 0 to 1


# after a lot of training
# for weights in weights_data_loader:
#     loss = diffusion(training_seq)
#     loss.backward()
# sampled_seq = diffusion.sample(batch_size = 4)
# sampled_seq.shape # (4, 32, 128)


# optimizer = torch.optim.Adam(diffusion.parameters(),
#                              lr=0.004,
#                              weight_decay=1e-8)

# epochs = 1
# outputs = []
# losses = []
# for epoch in range(epochs):
#     print("Epoch:", epoch)
#     diffusion.to(device)
#     diffusion.train()
#     for (batch, _) in weights_data_loader:
#         batch.to(device)
#         loss = diffusion(batch)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

# sampled_seq = diffusion.sample(batch_size = 4)
# sampled_seq.shape 

sampler = ImageSampler()
numsampler = NumberSampler()

dmmod = DiffModel()
trainer = pl.Trainer(gpus=1, logger=logger,
                        max_epochs=300, callbacks=[sampler],auto_lr_find=True)
trainer.fit(dmmod, weights_data_loader)

# loss = diffusion(training_images)
# loss.backward()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


loss



  | Name      | Type                | Params
--------------------------------------------------
0 | diffmodel | Unet1D              | 14.9 M
1 | diffusion | GaussianDiffusion1D | 14.9 M


train
train
train
Epoch 0:   0%|          | 0/15 [00:00<?, ?it/s] 

OutOfMemoryError: CUDA out of memory. Tried to allocate 78.00 MiB (GPU 0; 4.00 GiB total capacity; 3.48 GiB already allocated; 0 bytes free; 3.49 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [40]:
import torch
from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D

model = Unet1D(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    channels = 1
)

diffusion = GaussianDiffusion1D(
    model,
    seq_length = 15910,
    timesteps = 1000,
    objective = 'pred_v'
)

training_seq = torch.rand(8, 1, 15910) # features are normalized from 0 to 1
loss = diffusion(training_seq)
loss.backward()
for (batch, _) in weights_data_loader:
    print(batch.shape)
    batch = batch.reshape(20,1,15910)
    batch.to(device)
    loss = diffusion(batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# after a lot of training

sampled_seq = diffusion.sample(batch_size = 4)
sampled_seq.shape # (4, 32, 128)

KeyboardInterrupt: 