In [1]:
import torch
import pennylane as qml
from time import time
from tqdm import tqdm
import numpy as np
import matplotlib as plt
from circuit_model_II import QuantumCircuit, FullQuantumModel
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import os

In [2]:
# Circuit hyperparameters, passed positionally as (num_qubits, num_layers).
# NOTE(review): 2**8 = 256 equals the flattened 16x16 image size built below;
# presumably the model relies on that match -- confirm in circuit_model_II.
num_qubits, num_layers = 8, 5

model = FullQuantumModel(num_qubits, num_layers)

# Dataset preparation

In [3]:
# Download MNIST and prepare transforms.
# Note: the per-image min-max rescale applied further down is invariant to the
# affine shift/scale done by Normalize (up to floating-point rounding).
mnist_train = datasets.MNIST(
    root='./data', train=True, download=True,
    transform=transforms.Compose([
        transforms.Resize((16, 16)),  # Resize to 16x16
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # Normalize
    ]))

# Filter for zeros and ones.
# Select by label first so the transform pipeline only runs on the images that
# are kept, instead of transforming every training image and discarding most.
all_targets = torch.as_tensor(mnist_train.targets)
binary_mask = (all_targets == 0) | (all_targets == 1)
binary_indices = torch.nonzero(binary_mask).flatten().tolist()

# mnist_train[i] returns (image, label); squeeze() drops the channel dimension.
data = torch.stack([mnist_train[i][0].squeeze() for i in binary_indices])
targets = all_targets[binary_mask]

# Boolean masks selecting each class within the filtered data
zeros_indices = (targets == 0)
ones_indices = (targets == 1)

zeros = data[zeros_indices]
ones = data[ones_indices]

# Take a subsample of the dataset for simplicity/speed:
# at most the first 1024 zeros and the first 1024 ones
zeros = zeros[:1024]
ones = ones[:1024]

# Per-image extrema of the flattened 16x16 images.
# NOTE(review): each of these is a (values, indices) pair and none is read by
# any cell visible here -- normalize() recomputes the extrema itself. Kept in
# case a later cell uses them; remove if not.
zeros_max = torch.max(zeros.reshape(-1, 16*16), dim = 1)
zeros_min = torch.min(zeros.reshape(-1, 16*16), dim = 1)
ones_max = torch.max(ones.reshape(-1, 16*16), dim = 1)
ones_min = torch.min(ones.reshape(-1, 16*16), dim = 1)

def normalize(imgs):
  """Rescale every image of a batch to the [0, 1] range independently.

  Args:
    imgs: tensor with at least two dimensions whose first dimension indexes
      the images; all remaining dimensions belong to one image
      (e.g. shape (N, 16, 16)).

  Returns:
    Tensor of the same shape as `imgs` where each image has been mapped to
    (img - min) / (max - min) using that image's own min and max. An image
    whose values are all equal (max == min) is mapped to all zeros instead
    of producing NaN from a 0/0 division.
  """
  # One row per image, regardless of the image's own shape.
  flat = imgs.flatten(start_dim=1)
  maxes, _ = torch.max(flat, dim=1)
  mins, _ = torch.min(flat, dim=1)

  # Shape the per-image statistics as (N, 1, ..., 1) so they broadcast
  # against `imgs` whatever its rank.
  stat_shape = (imgs.shape[0],) + (1,) * (imgs.dim() - 1)
  mins = mins.reshape(stat_shape)
  maxes = maxes.reshape(stat_shape)

  # Guard against constant images: a zero range is replaced by 1 so the
  # numerator (all zeros for such an image) is returned unchanged.
  spans = maxes - mins
  spans = torch.where(spans > 0, spans, torch.ones_like(spans))

  return (imgs - mins) / spans

# Rescale each class' images with normalize()
zeros, ones = normalize(zeros), normalize(ones)

# Sanity check: over each class the smallest value is 0 and the largest is 1
# (tolerance 1e-5)
_expected_min = torch.tensor(0., dtype=torch.float32)
_expected_max = torch.tensor(1., dtype=torch.float32)
for _class_images in (zeros, ones):
    assert torch.allclose(_class_images.min(), _expected_min, atol=1e-5)
    assert torch.allclose(_class_images.max(), _expected_max, atol=1e-5)

# Flatten every image into a single feature vector
zeros = zeros.flatten(start_dim=1)
ones = ones.flatten(start_dim=1)

# Labels: 0.0 for each image in `zeros` followed by 1.0 for each image in
# `ones`, in the same order as the concatenated features below
_zero_labels = torch.zeros((zeros.shape[0], 1))
_one_labels = torch.ones((ones.shape[0], 1))
labels = torch.cat((_zero_labels, _one_labels), dim=0).squeeze()

# Pair the concatenated features with their labels
dataset = torch.utils.data.TensorDataset(torch.cat((zeros, ones), dim=0), labels)

# Shuffled mini-batches; drop_last discards a final incomplete batch
BATCH_SIZE = 32

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

In [4]:
# Display the batch size the DataLoader was constructed with (BATCH_SIZE).
dataloader.batch_size

32

# Model training

In [11]:
# Train the model on the binary (0 vs 1) MNIST dataloader for 10 epochs.
# NOTE(review): this cell's execution count (11) does not follow the previous
# cell's (4), so the recorded output may come from a model that had already
# been fitted in earlier runs -- verify with Restart Kernel -> Run All.
# NOTE(review): the recorded output is interleaved with `torch.Size([32, 256])`
# lines, presumably a leftover debug print inside FullQuantumModel -- confirm
# in circuit_model_II and remove it there.
model.fit(dataloader=dataloader, learning_rate=0.01, epochs=10)

Epoch 1/10:   3%|▎         | 2/64 [00:00<00:03, 15.99it/s, accuracy=1, loss=0.302]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:   9%|▉         | 6/64 [00:00<00:03, 16.82it/s, accuracy=1, loss=0.318]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  16%|█▌        | 10/64 [00:00<00:03, 16.43it/s, accuracy=1, loss=0.302]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  22%|██▏       | 14/64 [00:00<00:02, 17.87it/s, accuracy=0.969, loss=0.336]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  28%|██▊       | 18/64 [00:01<00:02, 18.80it/s, accuracy=1, loss=0.296]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  33%|███▎      | 21/64 [00:01<00:02, 19.03it/s, accuracy=0.969, loss=0.331]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  41%|████      | 26/64 [00:01<00:01, 19.28it/s, accuracy=0.969, loss=0.321]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  47%|████▋     | 30/64 [00:01<00:01, 19.20it/s, accuracy=1, loss=0.325]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  55%|█████▍    | 35/64 [00:01<00:01, 19.52it/s, accuracy=1, loss=0.312]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  62%|██████▎   | 40/64 [00:02<00:01, 19.86it/s, accuracy=1, loss=0.314]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  67%|██████▋   | 43/64 [00:02<00:01, 20.03it/s, accuracy=1, loss=0.325]   

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  73%|███████▎  | 47/64 [00:02<00:00, 19.72it/s, accuracy=1, loss=0.33]     

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  83%|████████▎ | 53/64 [00:02<00:00, 20.42it/s, accuracy=1, loss=0.314]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10:  92%|█████████▏| 59/64 [00:03<00:00, 20.56it/s, accuracy=1, loss=0.312]   

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 1/10: 100%|██████████| 64/64 [00:03<00:00, 19.26it/s, accuracy=1, loss=0.298]    


torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:   0%|          | 0/64 [00:00<?, ?it/s]

torch.Size([32, 256])


Epoch 2/10:   3%|▎         | 2/64 [00:00<00:03, 19.81it/s, accuracy=1, loss=0.337]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:   8%|▊         | 5/64 [00:00<00:02, 20.22it/s, accuracy=1, loss=0.329]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  12%|█▎        | 8/64 [00:00<00:02, 20.27it/s, accuracy=1, loss=0.291]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  12%|█▎        | 8/64 [00:00<00:02, 20.27it/s, accuracy=1, loss=0.318]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  17%|█▋        | 11/64 [00:00<00:02, 20.12it/s, accuracy=1, loss=0.319]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  22%|██▏       | 14/64 [00:00<00:02, 20.19it/s, accuracy=1, loss=0.311]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  27%|██▋       | 17/64 [00:00<00:02, 20.30it/s, accuracy=0.969, loss=0.321]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  27%|██▋       | 17/64 [00:00<00:02, 20.30it/s, accuracy=1, loss=0.327]    

torch.Size([32, 256])


Epoch 2/10:  36%|███▌      | 23/64 [00:01<00:02, 20.16it/s, accuracy=0.969, loss=0.314]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  36%|███▌      | 23/64 [00:01<00:02, 20.16it/s, accuracy=1, loss=0.305]    

torch.Size([32, 256])


Epoch 2/10:  41%|████      | 26/64 [00:01<00:01, 20.27it/s, accuracy=0.969, loss=0.332]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  45%|████▌     | 29/64 [00:01<00:01, 20.49it/s, accuracy=1, loss=0.298]    

torch.Size([32, 256])


Epoch 2/10:  50%|█████     | 32/64 [00:01<00:01, 20.50it/s, accuracy=1, loss=0.303]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  50%|█████     | 32/64 [00:01<00:01, 20.50it/s, accuracy=0.969, loss=0.316]

torch.Size([32, 256])


Epoch 2/10:  59%|█████▉    | 38/64 [00:01<00:01, 20.61it/s, accuracy=1, loss=0.299]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  59%|█████▉    | 38/64 [00:01<00:01, 20.61it/s, accuracy=1, loss=0.323]

torch.Size([32, 256])


Epoch 2/10:  64%|██████▍   | 41/64 [00:02<00:01, 20.57it/s, accuracy=1, loss=0.324]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  69%|██████▉   | 44/64 [00:02<00:00, 20.44it/s, accuracy=1, loss=0.311]

torch.Size([32, 256])


Epoch 2/10:  69%|██████▉   | 44/64 [00:02<00:00, 20.44it/s, accuracy=1, loss=0.315]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  73%|███████▎  | 47/64 [00:02<00:01, 16.99it/s, accuracy=1, loss=0.305]

torch.Size([32, 256])


Epoch 2/10:  78%|███████▊  | 50/64 [00:02<00:00, 17.94it/s, accuracy=1, loss=0.307]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  78%|███████▊  | 50/64 [00:02<00:00, 17.94it/s, accuracy=0.969, loss=0.326]

torch.Size([32, 256])


Epoch 2/10:  86%|████████▌ | 55/64 [00:02<00:00, 18.66it/s, accuracy=1, loss=0.306]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  86%|████████▌ | 55/64 [00:02<00:00, 18.66it/s, accuracy=1, loss=0.302]

torch.Size([32, 256])


Epoch 2/10:  94%|█████████▍| 60/64 [00:03<00:00, 19.64it/s, accuracy=1, loss=0.308]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 2/10:  94%|█████████▍| 60/64 [00:03<00:00, 19.64it/s, accuracy=1, loss=0.318]

torch.Size([32, 256])


Epoch 2/10: 100%|██████████| 64/64 [00:03<00:00, 19.68it/s, accuracy=1, loss=0.302]    


torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:   0%|          | 0/64 [00:00<?, ?it/s]

torch.Size([32, 256])


Epoch 3/10:   0%|          | 0/64 [00:00<?, ?it/s, accuracy=1, loss=0.314]

torch.Size([32, 256])


Epoch 3/10:   3%|▎         | 2/64 [00:00<00:03, 18.59it/s, accuracy=1, loss=0.303]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:   6%|▋         | 4/64 [00:00<00:03, 18.97it/s, accuracy=1, loss=0.3]  

torch.Size([32, 256])


Epoch 3/10:   6%|▋         | 4/64 [00:00<00:03, 18.97it/s, accuracy=1, loss=0.321]

torch.Size([32, 256])


Epoch 3/10:  11%|█         | 7/64 [00:00<00:02, 19.77it/s, accuracy=1, loss=0.317]   

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  11%|█         | 7/64 [00:00<00:02, 19.77it/s, accuracy=1, loss=0.276]

torch.Size([32, 256])


Epoch 3/10:  16%|█▌        | 10/64 [00:00<00:02, 20.04it/s, accuracy=1, loss=0.296]

torch.Size([32, 256])


Epoch 3/10:  19%|█▉        | 12/64 [00:00<00:02, 20.01it/s, accuracy=1, loss=0.323]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  19%|█▉        | 12/64 [00:00<00:02, 20.01it/s, accuracy=1, loss=0.314]

torch.Size([32, 256])


Epoch 3/10:  22%|██▏       | 14/64 [00:00<00:02, 19.58it/s, accuracy=1, loss=0.309]

torch.Size([32, 256])


Epoch 3/10:  25%|██▌       | 16/64 [00:00<00:02, 19.69it/s, accuracy=1, loss=0.304]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  28%|██▊       | 18/64 [00:00<00:02, 19.71it/s, accuracy=1, loss=0.314]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  33%|███▎      | 21/64 [00:01<00:02, 20.00it/s, accuracy=1, loss=0.32] 

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  36%|███▌      | 23/64 [00:01<00:02, 19.93it/s, accuracy=1, loss=0.32] 

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  41%|████      | 26/64 [00:01<00:01, 20.20it/s, accuracy=1, loss=0.327]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  45%|████▌     | 29/64 [00:01<00:01, 20.18it/s, accuracy=0.969, loss=0.309]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  50%|█████     | 32/64 [00:01<00:01, 20.08it/s, accuracy=1, loss=0.318]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  50%|█████     | 32/64 [00:01<00:01, 20.08it/s, accuracy=1, loss=0.317]

torch.Size([32, 256])


Epoch 3/10:  55%|█████▍    | 35/64 [00:01<00:01, 19.94it/s, accuracy=1, loss=0.303]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  59%|█████▉    | 38/64 [00:01<00:01, 20.30it/s, accuracy=1, loss=0.287]

torch.Size([32, 256])


Epoch 3/10:  64%|██████▍   | 41/64 [00:02<00:01, 20.14it/s, accuracy=1, loss=0.321]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  64%|██████▍   | 41/64 [00:02<00:01, 20.14it/s, accuracy=1, loss=0.305]

torch.Size([32, 256])


Epoch 3/10:  73%|███████▎  | 47/64 [00:02<00:00, 20.42it/s, accuracy=1, loss=0.34]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  73%|███████▎  | 47/64 [00:02<00:00, 20.42it/s, accuracy=0.969, loss=0.311]

torch.Size([32, 256])


Epoch 3/10:  78%|███████▊  | 50/64 [00:02<00:00, 20.30it/s, accuracy=1, loss=0.315]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  83%|████████▎ | 53/64 [00:02<00:00, 20.35it/s, accuracy=1, loss=0.316]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  88%|████████▊ | 56/64 [00:02<00:00, 20.41it/s, accuracy=1, loss=0.319]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  88%|████████▊ | 56/64 [00:02<00:00, 20.41it/s, accuracy=1, loss=0.287]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  92%|█████████▏| 59/64 [00:03<00:00, 20.50it/s, accuracy=1, loss=0.317]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10:  97%|█████████▋| 62/64 [00:03<00:00, 20.34it/s, accuracy=0.969, loss=0.317]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 3/10: 100%|██████████| 64/64 [00:03<00:00, 20.16it/s, accuracy=0.969, loss=0.326]
Epoch 4/10:   0%|          | 0/64 [00:00<?, ?it/s, accuracy=1, loss=0.313]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:   5%|▍         | 3/64 [00:00<00:02, 20.75it/s, accuracy=1, loss=0.328]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:   9%|▉         | 6/64 [00:00<00:02, 20.27it/s, accuracy=1, loss=0.292]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:   9%|▉         | 6/64 [00:00<00:02, 20.27it/s, accuracy=0.969, loss=0.333]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  14%|█▍        | 9/64 [00:00<00:02, 20.19it/s, accuracy=0.969, loss=0.315]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  19%|█▉        | 12/64 [00:00<00:02, 20.28it/s, accuracy=1, loss=0.32]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  23%|██▎       | 15/64 [00:00<00:02, 19.44it/s, accuracy=0.969, loss=0.307]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  28%|██▊       | 18/64 [00:00<00:02, 19.81it/s, accuracy=1, loss=0.304]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  28%|██▊       | 18/64 [00:00<00:02, 19.81it/s, accuracy=1, loss=0.302]  

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  33%|███▎      | 21/64 [00:01<00:02, 19.95it/s, accuracy=1, loss=0.309]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  38%|███▊      | 24/64 [00:01<00:01, 20.21it/s, accuracy=1, loss=0.305]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  42%|████▏     | 27/64 [00:01<00:01, 19.07it/s, accuracy=1, loss=0.318]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  42%|████▏     | 27/64 [00:01<00:01, 19.07it/s, accuracy=0.969, loss=0.356]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  47%|████▋     | 30/64 [00:01<00:01, 19.42it/s, accuracy=0.938, loss=0.336]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  52%|█████▏    | 33/64 [00:01<00:01, 19.63it/s, accuracy=1, loss=0.326]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  56%|█████▋    | 36/64 [00:01<00:01, 19.99it/s, accuracy=1, loss=0.291]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  61%|██████    | 39/64 [00:01<00:01, 20.05it/s, accuracy=1, loss=0.319]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  66%|██████▌   | 42/64 [00:02<00:01, 19.96it/s, accuracy=1, loss=0.295]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  66%|██████▌   | 42/64 [00:02<00:01, 19.96it/s, accuracy=1, loss=0.317]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  70%|███████   | 45/64 [00:02<00:00, 20.14it/s, accuracy=1, loss=0.306]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  75%|███████▌  | 48/64 [00:02<00:00, 20.08it/s, accuracy=1, loss=0.307]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  80%|███████▉  | 51/64 [00:02<00:00, 20.19it/s, accuracy=1, loss=0.299]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  84%|████████▍ | 54/64 [00:02<00:00, 20.30it/s, accuracy=1, loss=0.295]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  84%|████████▍ | 54/64 [00:02<00:00, 20.30it/s, accuracy=1, loss=0.299]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  89%|████████▉ | 57/64 [00:02<00:00, 20.32it/s, accuracy=0.969, loss=0.332]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10:  94%|█████████▍| 60/64 [00:03<00:00, 20.33it/s, accuracy=0.969, loss=0.327]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 4/10: 100%|██████████| 64/64 [00:03<00:00, 20.05it/s, accuracy=1, loss=0.313]    


torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:   0%|          | 0/64 [00:00<?, ?it/s]

torch.Size([32, 256])


Epoch 5/10:   3%|▎         | 2/64 [00:00<00:03, 19.90it/s, accuracy=1, loss=0.308]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:   3%|▎         | 2/64 [00:00<00:03, 19.90it/s, accuracy=0.969, loss=0.284]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:   8%|▊         | 5/64 [00:00<00:02, 19.72it/s, accuracy=1, loss=0.298]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  12%|█▎        | 8/64 [00:00<00:02, 20.07it/s, accuracy=1, loss=0.277]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  17%|█▋        | 11/64 [00:00<00:02, 20.25it/s, accuracy=0.969, loss=0.319]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  22%|██▏       | 14/64 [00:00<00:02, 20.43it/s, accuracy=1, loss=0.301]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  22%|██▏       | 14/64 [00:00<00:02, 20.43it/s, accuracy=1, loss=0.301]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  27%|██▋       | 17/64 [00:00<00:02, 20.23it/s, accuracy=1, loss=0.339]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  31%|███▏      | 20/64 [00:01<00:02, 20.30it/s, accuracy=1, loss=0.31] 

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  36%|███▌      | 23/64 [00:01<00:02, 20.31it/s, accuracy=1, loss=0.315]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  41%|████      | 26/64 [00:01<00:01, 20.22it/s, accuracy=1, loss=0.279]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  45%|████▌     | 29/64 [00:01<00:01, 20.31it/s, accuracy=1, loss=0.296]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  45%|████▌     | 29/64 [00:01<00:01, 20.31it/s, accuracy=1, loss=0.332]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  50%|█████     | 32/64 [00:01<00:01, 20.40it/s, accuracy=0.969, loss=0.317]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  55%|█████▍    | 35/64 [00:01<00:01, 20.60it/s, accuracy=0.969, loss=0.319]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  59%|█████▉    | 38/64 [00:01<00:01, 20.49it/s, accuracy=0.969, loss=0.32] 

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  64%|██████▍   | 41/64 [00:02<00:01, 20.45it/s, accuracy=0.969, loss=0.32]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  69%|██████▉   | 44/64 [00:02<00:00, 20.54it/s, accuracy=1, loss=0.309]   

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  69%|██████▉   | 44/64 [00:02<00:00, 20.54it/s, accuracy=1, loss=0.316]

torch.Size([32, 256])


Epoch 5/10:  73%|███████▎  | 47/64 [00:02<00:00, 18.41it/s, accuracy=0.969, loss=0.293]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  78%|███████▊  | 50/64 [00:02<00:00, 19.06it/s, accuracy=1, loss=0.319]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  83%|████████▎ | 53/64 [00:02<00:00, 19.49it/s, accuracy=1, loss=0.296]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  83%|████████▎ | 53/64 [00:02<00:00, 19.49it/s, accuracy=1, loss=0.302]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  91%|█████████ | 58/64 [00:02<00:00, 19.74it/s, accuracy=1, loss=0.317]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  91%|█████████ | 58/64 [00:02<00:00, 19.74it/s, accuracy=0.969, loss=0.333]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10:  95%|█████████▌| 61/64 [00:03<00:00, 19.90it/s, accuracy=1, loss=0.334]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 5/10: 100%|██████████| 64/64 [00:03<00:00, 20.02it/s, accuracy=0.969, loss=0.305]
Epoch 6/10:   0%|          | 0/64 [00:00<?, ?it/s, accuracy=1, loss=0.305]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:   5%|▍         | 3/64 [00:00<00:03, 20.29it/s, accuracy=1, loss=0.324]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:   9%|▉         | 6/64 [00:00<00:02, 20.52it/s, accuracy=1, loss=0.295]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  14%|█▍        | 9/64 [00:00<00:02, 20.17it/s, accuracy=1, loss=0.324]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  14%|█▍        | 9/64 [00:00<00:02, 20.17it/s, accuracy=1, loss=0.313]

torch.Size([32, 256])


Epoch 6/10:  19%|█▉        | 12/64 [00:00<00:02, 19.81it/s, accuracy=0.969, loss=0.306]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  23%|██▎       | 15/64 [00:00<00:02, 19.92it/s, accuracy=1, loss=0.285]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  27%|██▋       | 17/64 [00:00<00:02, 19.91it/s, accuracy=0.969, loss=0.326]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  31%|███▏      | 20/64 [00:00<00:02, 20.20it/s, accuracy=1, loss=0.334]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  36%|███▌      | 23/64 [00:01<00:02, 20.20it/s, accuracy=1, loss=0.288]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  36%|███▌      | 23/64 [00:01<00:02, 20.20it/s, accuracy=1, loss=0.289]

torch.Size([32, 256])


Epoch 6/10:  41%|████      | 26/64 [00:01<00:01, 20.18it/s, accuracy=1, loss=0.303]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  45%|████▌     | 29/64 [00:01<00:01, 20.31it/s, accuracy=1, loss=0.313]

torch.Size([32, 256])


Epoch 6/10:  50%|█████     | 32/64 [00:01<00:01, 20.39it/s, accuracy=0.969, loss=0.338]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  50%|█████     | 32/64 [00:01<00:01, 20.39it/s, accuracy=1, loss=0.316]    

torch.Size([32, 256])


Epoch 6/10:  59%|█████▉    | 38/64 [00:01<00:01, 20.33it/s, accuracy=0.938, loss=0.369]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  59%|█████▉    | 38/64 [00:01<00:01, 20.33it/s, accuracy=0.969, loss=0.332]

torch.Size([32, 256])


Epoch 6/10:  64%|██████▍   | 41/64 [00:02<00:01, 20.42it/s, accuracy=1, loss=0.297]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  69%|██████▉   | 44/64 [00:02<00:00, 20.26it/s, accuracy=1, loss=0.309]

torch.Size([32, 256])


Epoch 6/10:  73%|███████▎  | 47/64 [00:02<00:00, 20.20it/s, accuracy=1, loss=0.304]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  73%|███████▎  | 47/64 [00:02<00:00, 20.20it/s, accuracy=0.969, loss=0.338]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  78%|███████▊  | 50/64 [00:02<00:00, 20.10it/s, accuracy=0.969, loss=0.333]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  83%|████████▎ | 53/64 [00:02<00:00, 20.29it/s, accuracy=1, loss=0.316]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  88%|████████▊ | 56/64 [00:02<00:00, 20.16it/s, accuracy=1, loss=0.318]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  88%|████████▊ | 56/64 [00:02<00:00, 20.16it/s, accuracy=1, loss=0.308]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  92%|█████████▏| 59/64 [00:03<00:00, 20.17it/s, accuracy=1, loss=0.29] 

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10:  97%|█████████▋| 62/64 [00:03<00:00, 20.22it/s, accuracy=1, loss=0.312]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 6/10: 100%|██████████| 64/64 [00:03<00:00, 20.22it/s, accuracy=1, loss=0.313]
Epoch 7/10:   3%|▎         | 2/64 [00:00<00:03, 19.67it/s, accuracy=0.969, loss=0.326]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:   3%|▎         | 2/64 [00:00<00:03, 19.67it/s, accuracy=1, loss=0.314]    

torch.Size([32, 256])


Epoch 7/10:  11%|█         | 7/64 [00:00<00:02, 19.86it/s, accuracy=0.969, loss=0.321]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  11%|█         | 7/64 [00:00<00:02, 19.86it/s, accuracy=1, loss=0.33]     

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  16%|█▌        | 10/64 [00:00<00:02, 19.96it/s, accuracy=1, loss=0.289]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  19%|█▉        | 12/64 [00:00<00:02, 19.93it/s, accuracy=1, loss=0.337]

torch.Size([32, 256])


Epoch 7/10:  23%|██▎       | 15/64 [00:00<00:02, 20.24it/s, accuracy=1, loss=0.332]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  23%|██▎       | 15/64 [00:00<00:02, 20.24it/s, accuracy=1, loss=0.304]

torch.Size([32, 256])


Epoch 7/10:  33%|███▎      | 21/64 [00:01<00:02, 20.51it/s, accuracy=1, loss=0.322]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  33%|███▎      | 21/64 [00:01<00:02, 20.51it/s, accuracy=1, loss=0.286]

torch.Size([32, 256])


Epoch 7/10:  38%|███▊      | 24/64 [00:01<00:01, 20.35it/s, accuracy=1, loss=0.328]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  42%|████▏     | 27/64 [00:01<00:01, 20.30it/s, accuracy=1, loss=0.307]

torch.Size([32, 256])


Epoch 7/10:  47%|████▋     | 30/64 [00:01<00:01, 19.96it/s, accuracy=1, loss=0.316]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  47%|████▋     | 30/64 [00:01<00:01, 19.96it/s, accuracy=1, loss=0.298]

torch.Size([32, 256])


Epoch 7/10:  55%|█████▍    | 35/64 [00:01<00:01, 19.84it/s, accuracy=1, loss=0.303]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  55%|█████▍    | 35/64 [00:01<00:01, 19.84it/s, accuracy=1, loss=0.313]

torch.Size([32, 256])


Epoch 7/10:  62%|██████▎   | 40/64 [00:01<00:01, 19.99it/s, accuracy=1, loss=0.297]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  62%|██████▎   | 40/64 [00:02<00:01, 19.99it/s, accuracy=1, loss=0.328]

torch.Size([32, 256])


Epoch 7/10:  67%|██████▋   | 43/64 [00:02<00:01, 20.19it/s, accuracy=1, loss=0.314]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  72%|███████▏  | 46/64 [00:02<00:00, 20.24it/s, accuracy=1, loss=0.293]

torch.Size([32, 256])


Epoch 7/10:  77%|███████▋  | 49/64 [00:02<00:00, 20.21it/s, accuracy=0.938, loss=0.325]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  77%|███████▋  | 49/64 [00:02<00:00, 20.21it/s, accuracy=1, loss=0.305]    

torch.Size([32, 256])


Epoch 7/10:  81%|████████▏ | 52/64 [00:02<00:00, 20.25it/s, accuracy=1, loss=0.333]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  86%|████████▌ | 55/64 [00:02<00:00, 20.02it/s, accuracy=1, loss=0.313]

torch.Size([32, 256])


Epoch 7/10:  86%|████████▌ | 55/64 [00:02<00:00, 20.02it/s, accuracy=1, loss=0.347]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  91%|█████████ | 58/64 [00:02<00:00, 18.01it/s, accuracy=1, loss=0.293]

torch.Size([32, 256])


Epoch 7/10:  95%|█████████▌| 61/64 [00:03<00:00, 18.72it/s, accuracy=1, loss=0.306]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 7/10:  95%|█████████▌| 61/64 [00:03<00:00, 18.72it/s, accuracy=0.969, loss=0.309]

torch.Size([32, 256])


Epoch 7/10: 100%|██████████| 64/64 [00:03<00:00, 19.77it/s, accuracy=1, loss=0.298]    
Epoch 8/10:   5%|▍         | 3/64 [00:00<00:03, 20.07it/s, accuracy=1, loss=0.293]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:   5%|▍         | 3/64 [00:00<00:03, 20.07it/s, accuracy=1, loss=0.302]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:   9%|▉         | 6/64 [00:00<00:02, 20.66it/s, accuracy=1, loss=0.306]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  14%|█▍        | 9/64 [00:00<00:02, 20.56it/s, accuracy=1, loss=0.311]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  19%|█▉        | 12/64 [00:00<00:02, 20.28it/s, accuracy=1, loss=0.318]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  19%|█▉        | 12/64 [00:00<00:02, 20.28it/s, accuracy=0.969, loss=0.303]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  23%|██▎       | 15/64 [00:00<00:02, 20.31it/s, accuracy=1, loss=0.302]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  28%|██▊       | 18/64 [00:00<00:02, 20.51it/s, accuracy=1, loss=0.306]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  33%|███▎      | 21/64 [00:01<00:02, 20.40it/s, accuracy=1, loss=0.297]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  38%|███▊      | 24/64 [00:01<00:01, 20.37it/s, accuracy=1, loss=0.316]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  38%|███▊      | 24/64 [00:01<00:01, 20.37it/s, accuracy=1, loss=0.303]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  42%|████▏     | 27/64 [00:01<00:01, 20.39it/s, accuracy=1, loss=0.324]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  47%|████▋     | 30/64 [00:01<00:01, 20.36it/s, accuracy=1, loss=0.293]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  47%|████▋     | 30/64 [00:01<00:01, 20.36it/s, accuracy=0.969, loss=0.338]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  52%|█████▏    | 33/64 [00:01<00:01, 20.04it/s, accuracy=1, loss=0.304]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  56%|█████▋    | 36/64 [00:01<00:01, 19.94it/s, accuracy=1, loss=0.333]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  59%|█████▉    | 38/64 [00:01<00:01, 19.90it/s, accuracy=1, loss=0.327]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  64%|██████▍   | 41/64 [00:02<00:01, 20.05it/s, accuracy=1, loss=0.298]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  69%|██████▉   | 44/64 [00:02<00:00, 20.03it/s, accuracy=0.938, loss=0.343]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  69%|██████▉   | 44/64 [00:02<00:00, 20.03it/s, accuracy=1, loss=0.308]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  73%|███████▎  | 47/64 [00:02<00:00, 20.10it/s, accuracy=0.969, loss=0.335]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  78%|███████▊  | 50/64 [00:02<00:00, 20.18it/s, accuracy=1, loss=0.308]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  83%|████████▎ | 53/64 [00:02<00:00, 20.37it/s, accuracy=1, loss=0.303]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  88%|████████▊ | 56/64 [00:02<00:00, 20.18it/s, accuracy=1, loss=0.308]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  92%|█████████▏| 59/64 [00:02<00:00, 20.26it/s, accuracy=1, loss=0.336]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10:  92%|█████████▏| 59/64 [00:03<00:00, 20.26it/s, accuracy=1, loss=0.308]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 8/10: 100%|██████████| 64/64 [00:03<00:00, 20.25it/s, accuracy=1, loss=0.292]


torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:   0%|          | 0/64 [00:00<?, ?it/s]

torch.Size([32, 256])


Epoch 9/10:   3%|▎         | 2/64 [00:00<00:03, 19.83it/s, accuracy=1, loss=0.323]

torch.Size([32, 256])


Epoch 9/10:   3%|▎         | 2/64 [00:00<00:03, 19.83it/s, accuracy=0.969, loss=0.325]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:   8%|▊         | 5/64 [00:00<00:02, 19.94it/s, accuracy=1, loss=0.297]    

torch.Size([32, 256])


Epoch 9/10:   8%|▊         | 5/64 [00:00<00:02, 19.94it/s, accuracy=1, loss=0.297]

torch.Size([32, 256])


Epoch 9/10:  12%|█▎        | 8/64 [00:00<00:02, 20.12it/s, accuracy=0.969, loss=0.344]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  12%|█▎        | 8/64 [00:00<00:02, 20.12it/s, accuracy=1, loss=0.324]    

torch.Size([32, 256])


Epoch 9/10:  17%|█▋        | 11/64 [00:00<00:02, 20.20it/s, accuracy=1, loss=0.302]

torch.Size([32, 256])


Epoch 9/10:  22%|██▏       | 14/64 [00:00<00:02, 20.17it/s, accuracy=0.969, loss=0.313]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  22%|██▏       | 14/64 [00:00<00:02, 20.17it/s, accuracy=1, loss=0.302]    

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  27%|██▋       | 17/64 [00:00<00:02, 20.12it/s, accuracy=1, loss=0.318]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  31%|███▏      | 20/64 [00:01<00:02, 20.06it/s, accuracy=1, loss=0.304]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  36%|███▌      | 23/64 [00:01<00:02, 20.01it/s, accuracy=1, loss=0.303]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  41%|████      | 26/64 [00:01<00:01, 20.15it/s, accuracy=1, loss=0.326]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  41%|████      | 26/64 [00:01<00:01, 20.15it/s, accuracy=1, loss=0.311]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  45%|████▌     | 29/64 [00:01<00:01, 20.30it/s, accuracy=0.969, loss=0.336]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  50%|█████     | 32/64 [00:01<00:01, 20.20it/s, accuracy=1, loss=0.326]    

torch.Size([32, 256])


Epoch 9/10:  55%|█████▍    | 35/64 [00:01<00:01, 20.19it/s, accuracy=0.969, loss=0.337]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  55%|█████▍    | 35/64 [00:01<00:01, 20.19it/s, accuracy=1, loss=0.305]    

torch.Size([32, 256])


Epoch 9/10:  64%|██████▍   | 41/64 [00:02<00:01, 20.23it/s, accuracy=1, loss=0.345]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  64%|██████▍   | 41/64 [00:02<00:01, 20.23it/s, accuracy=1, loss=0.327]

torch.Size([32, 256])


Epoch 9/10:  69%|██████▉   | 44/64 [00:02<00:00, 20.02it/s, accuracy=1, loss=0.302]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  73%|███████▎  | 47/64 [00:02<00:00, 20.18it/s, accuracy=0.969, loss=0.319]

torch.Size([32, 256])


Epoch 9/10:  78%|███████▊  | 50/64 [00:02<00:00, 20.07it/s, accuracy=1, loss=0.306]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  88%|████████▊ | 56/64 [00:02<00:00, 20.20it/s, accuracy=1, loss=0.308]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10:  92%|█████████▏| 59/64 [00:03<00:00, 20.35it/s, accuracy=1, loss=0.306]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 9/10: 100%|██████████| 64/64 [00:03<00:00, 20.19it/s, accuracy=0.969, loss=0.314]


torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:   0%|          | 0/64 [00:00<?, ?it/s, accuracy=1, loss=0.318]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:   3%|▎         | 2/64 [00:00<00:03, 18.96it/s, accuracy=1, loss=0.294]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:   6%|▋         | 4/64 [00:00<00:03, 19.51it/s, accuracy=1, loss=0.303]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  11%|█         | 7/64 [00:00<00:02, 19.93it/s, accuracy=1, loss=0.288]

torch.Size([32, 256])


Epoch 10/10:  14%|█▍        | 9/64 [00:00<00:02, 19.06it/s, accuracy=0.969, loss=0.349]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  17%|█▋        | 11/64 [00:00<00:02, 18.71it/s, accuracy=1, loss=0.294]   

torch.Size([32, 256])


Epoch 10/10:  20%|██        | 13/64 [00:00<00:02, 18.26it/s, accuracy=0.969, loss=0.327]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  23%|██▎       | 15/64 [00:00<00:02, 16.54it/s, accuracy=1, loss=0.305]    

torch.Size([32, 256])


Epoch 10/10:  27%|██▋       | 17/64 [00:00<00:02, 17.20it/s, accuracy=1, loss=0.309]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  27%|██▋       | 17/64 [00:01<00:02, 17.20it/s, accuracy=1, loss=0.303]

torch.Size([32, 256])


Epoch 10/10:  31%|███▏      | 20/64 [00:01<00:02, 18.18it/s, accuracy=1, loss=0.323]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  36%|███▌      | 23/64 [00:01<00:02, 18.82it/s, accuracy=0.969, loss=0.324]

torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  41%|████      | 26/64 [00:01<00:01, 19.26it/s, accuracy=1, loss=0.315]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  44%|████▍     | 28/64 [00:01<00:01, 19.16it/s, accuracy=1, loss=0.325]

torch.Size([32, 256])


Epoch 10/10:  47%|████▋     | 30/64 [00:01<00:01, 19.31it/s, accuracy=1, loss=0.312]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  56%|█████▋    | 36/64 [00:01<00:01, 19.78it/s, accuracy=1, loss=0.307]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  61%|██████    | 39/64 [00:02<00:01, 20.09it/s, accuracy=1, loss=0.32] 

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  70%|███████   | 45/64 [00:02<00:01, 18.58it/s, accuracy=1, loss=0.32]     

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  78%|███████▊  | 50/64 [00:02<00:00, 19.40it/s, accuracy=0.969, loss=0.322]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  83%|████████▎ | 53/64 [00:02<00:00, 19.62it/s, accuracy=1, loss=0.303]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10:  92%|█████████▏| 59/64 [00:03<00:00, 19.98it/s, accuracy=1, loss=0.334]    

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])


Epoch 10/10: 100%|██████████| 64/64 [00:03<00:00, 19.14it/s, accuracy=0.969, loss=0.289]

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 256])
Time per epoch:  3.3440439701080322
Epoch:  9 Loss:  0.2892584502696991
Accuracy:  0.96875
--------------------------------------------------------------------------





(3.222240614891052,
 [0.29758214950561523,
  0.30170902609825134,
  0.32587286829948425,
  0.31278491020202637,
  0.3051925599575043,
  0.3128332197666168,
  0.29793015122413635,
  0.2923735976219177,
  0.31433677673339844,
  0.2892584502696991])

In [12]:
# Show only the first parameter tensor of the model as a quick sanity check.
# Nothing is printed when the model exposes no parameters.
first_param = next(iter(model.parameters()), None)
if first_param is not None:
    print(first_param)

Parameter containing:
tensor([[ 1.4929,  0.9107,  1.0181],
        [ 0.9554, -0.5808, -0.2439],
        [ 0.5648, -0.2558,  1.7541],
        [ 1.4368,  1.1390, -0.0140],
        [ 0.8342, -0.1917,  1.8964],
        [ 0.5762,  0.1116,  1.4495],
        [ 0.7821,  0.0411,  0.5196],
        [ 0.5933,  0.8271,  1.3459]], requires_grad=True)


In [10]:
# Unfreeze layers 1 and 2, then show the first parameter tensor again.
# NOTE(review): presumably `unfreeze_layers` re-enables training for the given
# layer indices — confirm the semantics and index base in circuit_model_II.
layers_to_unfreeze = [1, 2]
model.unfreeze_layers(layers_to_unfreeze)

# Nothing is printed when the model exposes no parameters.
first_param = next(iter(model.parameters()), None)
if first_param is not None:
    print(first_param)

Parameter containing:
tensor([[ 1.3629,  1.0201,  0.8152],
        [ 0.8423, -0.4657, -0.1720],
        [ 0.8119, -0.0925,  1.7360],
        [ 1.4267,  0.9895,  0.0350],
        [ 0.8196, -0.1364,  1.5823],
        [ 0.4226,  0.2861,  1.4179],
        [ 0.9527,  0.1969,  0.5144],
        [ 0.3481,  0.5851,  1.3831]], requires_grad=True)


NameError: name 'norm' is not defined