In [2]:
%load_ext autoreload
%autoreload 2

%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [3]:
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.models as models
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Check for CUDA support
print(f"Is CUDA available? {torch.cuda.is_available()}")

# Set the device
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")


PyTorch version: 2.1.0
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Is CUDA available? False
Using device: mps


In [5]:
resnet50 = models.resnet50(pretrained=True)
in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features, 10)

# # freeze all layers
# for param in resnet50.parameters():
#     param.requires_grad = False



In [6]:
from EarlyExitModel import EarlyExitModel

model = EarlyExitModel(resnet50, 10, device)
model

EarlyExitModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
     

In [7]:
model.clear_exits()
exit_layers = [model.add_exit(layer) for layer in ('layer1', 'layer2', 'layer3')]
model.to(device)  # Move the model to the selected device
model

EarlyExitModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): OptionalExitModule(
      (module): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(i

In [8]:
from DataLoader import CustomDataset

hf_dataset = load_dataset("frgfm/imagenette", '320px')
hf_dataset = concatenate_datasets(hf_dataset.values())

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


torch_dataset = CustomDataset(hf_dataset, transform=transform)

batch_size = 32

test_size = 0.2
test_volume = int(test_size * len(torch_dataset))
train_volume = len(torch_dataset) - test_volume

train_dataset, test_dataset = random_split(torch_dataset, [train_volume, test_volume])
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False, 
    num_workers=4
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)

## Early Exit Model Training

In [9]:
from EarlyExitTrainer import ModelTrainer

trainer = ModelTrainer(model, device)

trainer.train(train_dataloader, epoch_count=1, validation_loader=test_dataloader)


Training early exit layer 1
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 0.12589962091018905
Epoch 0 Accuracy 0.0016791044776119403




Epoch 0 Validation Loss 3.9433146033968245
Epoch 0 Validation Accuracy 0.10443722943722944
Training early exit layer 2
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 0.16908268145660857
Epoch 0 Accuracy 0.0018656716417910447




Epoch 0 Validation Loss 11.77011867931911
Epoch 0 Validation Accuracy 0.09949945887445887
Training early exit layer 3
Beginning epoch 0


                                                                                                    

Epoch 0 Loss 0.1675518220929957
Epoch 0 Accuracy 0.0026119402985074628




Epoch 0 Validation Loss 9.445267569451104
Epoch 0 Validation Accuracy 0.10190070346320347
Beginning epoch 0 on final classifier head


                                                                                                    

Epoch 0 Loss 0.03291384853533844
Epoch 0 Accuracy 0.008488805970149253




Epoch 0 Validation Loss 1.0574619734571093
Epoch 0 Validation Accuracy 0.8047889610389611
Beginning epoch 0 with no forced exits


Epoch 0:   0%|                                                              | 0/335 [00:00<?, ?it/s]

[tensor([ 2,  3,  8, 10, 11, 12, 13, 16, 17, 18, 20, 21, 22, 25, 27, 29, 30, 31],
       device='mps:0'), tensor([ 0,  5,  6,  7,  9, 14, 19, 28], device='mps:0'), tensor([23, 26], device='mps:0'), tensor([ 1,  4, 15, 24], device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   0%|                                      | 1/335 [00:10<56:34, 10.16s/it, Accuracy=0.219]

[None, None, tensor([ 5,  7, 13, 22, 24, 25], device='mps:0'), tensor([ 0,  1,  2,  3,  4,  6,  8,  9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20,
        21, 23, 26, 27, 28, 29, 30, 31], device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   1%|▍                                         | 3/335 [00:15<22:02,  3.98s/it, Accuracy=1]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   1%|▌                                         | 4/335 [00:15<13:54,  2.52s/it, Accuracy=1]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   1%|▋                                         | 5/335 [00:15<09:19,  1.69s/it, Accuracy=1]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   2%|▊                                         | 6/335 [00:15<06:35,  1.20s/it, Accuracy=1]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   2%|▊                                     | 7/335 [00:16<04:51,  1.13it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   2%|▉                                     | 8/335 [00:16<03:45,  1.45it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   3%|█                                     | 9/335 [00:16<03:06,  1.75it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   3%|█                                    | 10/335 [00:16<02:39,  2.04it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   3%|█▏                                   | 11/335 [00:17<02:19,  2.32it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   4%|█▎                                   | 12/335 [00:17<02:05,  2.57it/s, Accuracy=0.688]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   4%|█▍                                   | 13/335 [00:17<01:55,  2.79it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   4%|█▌                                   | 14/335 [00:18<01:49,  2.93it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   4%|█▋                                   | 15/335 [00:18<01:46,  3.01it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   5%|█▊                                    | 16/335 [00:18<01:43,  3.08it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   5%|█▉                                   | 17/335 [00:19<01:41,  3.14it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   5%|█▉                                   | 18/335 [00:19<01:39,  3.20it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   6%|██                                   | 19/335 [00:19<01:36,  3.26it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   6%|██▏                                  | 20/335 [00:19<01:35,  3.31it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   6%|██▎                                  | 21/335 [00:20<01:33,  3.35it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   7%|██▍                                  | 22/335 [00:20<01:32,  3.40it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   7%|██▌                                  | 23/335 [00:20<01:31,  3.42it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   7%|██▋                                  | 24/335 [00:21<01:31,  3.40it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   7%|██▊                                   | 25/335 [00:21<01:34,  3.29it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   8%|██▊                                  | 26/335 [00:21<01:33,  3.30it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   8%|██▉                                  | 27/335 [00:22<01:32,  3.34it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   8%|███                                  | 28/335 [00:22<01:31,  3.36it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   9%|███▏                                 | 29/335 [00:22<01:31,  3.34it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   9%|███▎                                 | 30/335 [00:22<01:32,  3.31it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:   9%|███▍                                 | 31/335 [00:23<01:35,  3.19it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  10%|███▋                                  | 32/335 [00:23<01:41,  2.97it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  10%|███▋                                 | 33/335 [00:23<01:40,  3.02it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  10%|███▊                                 | 34/335 [00:24<01:38,  3.04it/s, Accuracy=0.594]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  10%|███▊                                 | 35/335 [00:24<01:38,  3.06it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  11%|███▉                                 | 36/335 [00:24<01:38,  3.04it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  11%|████                                 | 37/335 [00:25<01:36,  3.09it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  11%|████▏                                | 38/335 [00:25<01:34,  3.16it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  12%|████▎                                | 39/335 [00:25<01:31,  3.24it/s, Accuracy=0.969]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  12%|████▍                                | 40/335 [00:26<01:29,  3.29it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  12%|████▌                                | 41/335 [00:26<01:27,  3.35it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  13%|████▋                                | 42/335 [00:26<01:26,  3.38it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  13%|████▋                                | 43/335 [00:27<01:27,  3.33it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  13%|████▊                                | 44/335 [00:27<01:29,  3.26it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  13%|████▉                                | 45/335 [00:27<01:30,  3.22it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  14%|█████                                | 46/335 [00:28<01:30,  3.20it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  14%|█████▏                               | 47/335 [00:28<01:29,  3.21it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  14%|█████▎                               | 48/335 [00:28<01:30,  3.18it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  15%|█████▍                               | 49/335 [00:28<01:30,  3.16it/s, Accuracy=0.969]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  15%|█████▋                                | 50/335 [00:29<01:29,  3.17it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  15%|█████▊                                | 51/335 [00:29<01:28,  3.19it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  16%|█████▋                               | 52/335 [00:29<01:28,  3.18it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  16%|██████                                | 53/335 [00:30<01:28,  3.20it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  16%|█████▉                               | 54/335 [00:30<01:27,  3.20it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  16%|██████                               | 55/335 [00:30<01:27,  3.19it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  17%|██████▏                              | 56/335 [00:31<01:27,  3.20it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  17%|██████▎                              | 57/335 [00:31<01:26,  3.21it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  17%|██████▍                              | 58/335 [00:31<01:26,  3.20it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  18%|██████▌                              | 59/335 [00:32<01:26,  3.20it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  18%|██████▋                              | 60/335 [00:32<01:27,  3.13it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  18%|██████▋                              | 61/335 [00:32<01:28,  3.09it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  19%|██████▊                              | 62/335 [00:33<01:27,  3.12it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  19%|██████▉                              | 63/335 [00:33<01:27,  3.12it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  19%|███████                              | 64/335 [00:33<01:26,  3.12it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  19%|███████▏                             | 65/335 [00:34<01:25,  3.15it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  20%|███████▎                             | 66/335 [00:34<01:24,  3.17it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  20%|███████▍                             | 67/335 [00:34<01:23,  3.22it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  20%|███████▌                             | 68/335 [00:34<01:22,  3.24it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  21%|███████▌                             | 69/335 [00:35<01:22,  3.23it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  21%|███████▋                             | 70/335 [00:35<01:22,  3.21it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  21%|███████▊                             | 71/335 [00:35<01:22,  3.21it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  21%|███████▉                             | 72/335 [00:36<01:22,  3.21it/s, Accuracy=0.688]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  22%|████████                             | 73/335 [00:36<01:22,  3.19it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  22%|████████▍                             | 74/335 [00:36<01:24,  3.08it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  22%|████████▌                             | 75/335 [00:37<01:24,  3.08it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  23%|████████▍                            | 76/335 [00:37<01:23,  3.12it/s, Accuracy=0.656]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  23%|████████▌                            | 77/335 [00:37<01:21,  3.16it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  23%|████████▌                            | 78/335 [00:38<01:20,  3.19it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  24%|████████▋                            | 79/335 [00:38<01:19,  3.20it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  24%|████████▊                            | 80/335 [00:38<01:19,  3.20it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  24%|████████▉                            | 81/335 [00:39<01:19,  3.21it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  24%|█████████                            | 82/335 [00:39<01:19,  3.20it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  25%|█████████▏                           | 83/335 [00:39<01:18,  3.21it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  25%|█████████▎                           | 84/335 [00:39<01:17,  3.25it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  25%|█████████▍                           | 85/335 [00:40<01:16,  3.28it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  26%|█████████▍                           | 86/335 [00:40<01:15,  3.29it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  26%|█████████▌                           | 87/335 [00:40<01:15,  3.27it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  26%|█████████▋                           | 88/335 [00:41<01:15,  3.26it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  27%|█████████▊                           | 89/335 [00:41<01:16,  3.22it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  27%|█████████▉                           | 90/335 [00:41<01:15,  3.23it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  27%|██████████                           | 91/335 [00:42<01:15,  3.24it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  27%|██████████▏                          | 92/335 [00:42<01:14,  3.25it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  28%|██████████▎                          | 93/335 [00:42<01:14,  3.26it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  28%|██████████▍                          | 94/335 [00:43<01:14,  3.26it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  28%|██████████▍                          | 95/335 [00:43<01:13,  3.25it/s, Accuracy=0.969]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  29%|██████████▌                          | 96/335 [00:43<01:13,  3.26it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  29%|██████████▋                          | 97/335 [00:43<01:12,  3.26it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  29%|███████████▉                             | 98/335 [00:44<01:12,  3.28it/s, Accuracy=1]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  30%|██████████▉                          | 99/335 [00:44<01:12,  3.28it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  30%|███████████                          | 100/335 [00:44<01:11,  3.28it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  30%|██████████▊                         | 101/335 [00:45<01:11,  3.26it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  30%|██████████▉                         | 102/335 [00:45<01:11,  3.27it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  31%|███████████                         | 103/335 [00:45<01:11,  3.26it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  31%|███████████▏                        | 104/335 [00:46<01:11,  3.23it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  31%|███████████▎                        | 105/335 [00:46<01:10,  3.24it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  32%|███████████▍                        | 106/335 [00:46<01:10,  3.25it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  32%|███████████▍                        | 107/335 [00:47<01:10,  3.24it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  32%|███████████▉                         | 108/335 [00:47<01:09,  3.26it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  33%|███████████▋                        | 109/335 [00:47<01:09,  3.27it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  33%|███████████▊                        | 110/335 [00:47<01:09,  3.25it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  33%|███████████▉                        | 111/335 [00:48<01:08,  3.26it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  33%|████████████                        | 112/335 [00:48<01:08,  3.25it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  34%|████████████▏                       | 113/335 [00:48<01:09,  3.21it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  34%|████████████▌                        | 114/335 [00:49<01:08,  3.21it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  34%|█████████████▋                          | 115/335 [00:49<01:08,  3.20it/s, Accuracy=1]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  35%|████████████▊                        | 116/335 [00:49<01:09,  3.16it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  35%|████████████▌                       | 117/335 [00:50<01:09,  3.14it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  35%|████████████▋                       | 118/335 [00:50<01:08,  3.16it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  36%|████████████▊                       | 119/335 [00:50<01:07,  3.19it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  36%|████████████▉                       | 120/335 [00:51<01:08,  3.15it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  36%|█████████████                       | 121/335 [00:51<01:08,  3.14it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  36%|█████████████                       | 122/335 [00:51<01:06,  3.19it/s, Accuracy=0.656]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  37%|█████████████▏                      | 123/335 [00:52<01:05,  3.22it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  37%|█████████████▎                      | 124/335 [00:52<01:05,  3.21it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  37%|█████████████▍                      | 125/335 [00:52<01:05,  3.21it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  38%|█████████████▌                      | 126/335 [00:52<01:05,  3.18it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  38%|█████████████▋                      | 127/335 [00:53<01:06,  3.13it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  38%|█████████████▊                      | 128/335 [00:53<01:05,  3.16it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  39%|█████████████▊                      | 129/335 [00:53<01:05,  3.14it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  39%|█████████████▉                      | 130/335 [00:54<01:06,  3.10it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  39%|██████████████                      | 131/335 [00:54<01:05,  3.13it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  39%|██████████████▏                     | 132/335 [00:54<01:05,  3.09it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  40%|██████████████▎                     | 133/335 [00:55<01:08,  2.97it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  40%|██████████████▍                     | 134/335 [00:55<01:06,  3.01it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  40%|██████████████▌                     | 135/335 [00:55<01:05,  3.04it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  41%|██████████████▌                     | 136/335 [00:56<01:04,  3.10it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  41%|██████████████▋                     | 137/335 [00:56<01:03,  3.14it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  41%|██████████████▊                     | 138/335 [00:56<01:02,  3.16it/s, Accuracy=0.969]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  41%|██████████████▉                     | 139/335 [00:57<01:01,  3.17it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  42%|███████████████                     | 140/335 [00:57<01:01,  3.19it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  42%|███████████████▏                    | 141/335 [00:57<01:01,  3.14it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  42%|███████████████▎                    | 142/335 [00:58<01:01,  3.13it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  43%|███████████████▎                    | 143/335 [00:58<01:01,  3.13it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  43%|███████████████▍                    | 144/335 [00:58<01:00,  3.15it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  43%|███████████████▌                    | 145/335 [00:59<01:00,  3.15it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  44%|███████████████▋                    | 146/335 [00:59<00:59,  3.17it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  44%|███████████████▊                    | 147/335 [00:59<00:59,  3.15it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  44%|███████████████▉                    | 148/335 [01:00<00:59,  3.16it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  44%|████████████████                    | 149/335 [01:00<00:58,  3.20it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  45%|████████████████▌                    | 150/335 [01:00<00:57,  3.22it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  45%|████████████████▏                   | 151/335 [01:00<00:57,  3.21it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  45%|████████████████▎                   | 152/335 [01:01<00:57,  3.19it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  46%|████████████████▍                   | 153/335 [01:01<00:57,  3.16it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  46%|████████████████▌                   | 154/335 [01:01<00:57,  3.16it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  46%|████████████████▋                   | 155/335 [01:02<00:56,  3.16it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  47%|████████████████▊                   | 156/335 [01:02<00:56,  3.16it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  47%|████████████████▊                   | 157/335 [01:02<00:56,  3.16it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  47%|████████████████▉                   | 158/335 [01:03<00:55,  3.17it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  47%|█████████████████                   | 159/335 [01:03<00:55,  3.19it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  48%|█████████████████▏                  | 160/335 [01:03<00:55,  3.14it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  48%|█████████████████▎                  | 161/335 [01:04<00:56,  3.11it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  48%|█████████████████▍                  | 162/335 [01:04<00:55,  3.10it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  49%|██████████████████                   | 163/335 [01:04<00:54,  3.15it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  49%|█████████████████▌                  | 164/335 [01:05<00:53,  3.19it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  49%|█████████████████▋                  | 165/335 [01:05<00:52,  3.22it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  50%|█████████████████▊                  | 166/335 [01:05<00:52,  3.20it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  50%|██████████████████▍                  | 167/335 [01:06<00:53,  3.14it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  50%|██████████████████                  | 168/335 [01:06<00:53,  3.11it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  50%|██████████████████▏                 | 169/335 [01:06<00:53,  3.07it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  51%|██████████████████▎                 | 170/335 [01:06<00:53,  3.07it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  51%|██████████████████▉                  | 171/335 [01:07<00:52,  3.13it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  51%|██████████████████▍                 | 172/335 [01:07<00:52,  3.09it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  52%|██████████████████▌                 | 173/335 [01:07<00:54,  2.99it/s, Accuracy=0.688]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  52%|██████████████████▋                 | 174/335 [01:08<00:52,  3.05it/s, Accuracy=0.688]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  52%|██████████████████▊                 | 175/335 [01:08<00:52,  3.04it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  53%|██████████████████▉                 | 176/335 [01:08<00:52,  3.04it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  53%|███████████████████                 | 177/335 [01:09<00:51,  3.08it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  53%|███████████████████▏                | 178/335 [01:09<00:49,  3.16it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  53%|███████████████████▏                | 179/335 [01:09<00:48,  3.20it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  54%|███████████████████▎                | 180/335 [01:10<00:47,  3.24it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  54%|███████████████████▍                | 181/335 [01:10<00:47,  3.25it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  54%|███████████████████▌                | 182/335 [01:10<00:47,  3.22it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  55%|███████████████████▋                | 183/335 [01:11<00:47,  3.22it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  55%|████████████████████▎                | 184/335 [01:11<00:47,  3.20it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  55%|███████████████████▉                | 185/335 [01:11<00:47,  3.14it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  56%|███████████████████▉                | 186/335 [01:12<00:48,  3.07it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  56%|████████████████████                | 187/335 [01:12<00:48,  3.07it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  56%|████████████████████▏               | 188/335 [01:12<00:47,  3.09it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  56%|████████████████████▎               | 189/335 [01:13<00:46,  3.14it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  57%|████████████████████▍               | 190/335 [01:13<00:45,  3.19it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  57%|████████████████████▌               | 191/335 [01:13<00:44,  3.20it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  57%|████████████████████▋               | 192/335 [01:14<00:46,  3.11it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  58%|████████████████████▋               | 193/335 [01:14<00:46,  3.03it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  58%|████████████████████▊               | 194/335 [01:14<00:46,  3.05it/s, Accuracy=0.625]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  58%|████████████████████▉               | 195/335 [01:14<00:45,  3.11it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  59%|█████████████████████               | 196/335 [01:15<00:44,  3.16it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  59%|█████████████████████▏              | 197/335 [01:15<00:43,  3.16it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  59%|█████████████████████▎              | 198/335 [01:15<00:43,  3.13it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  59%|█████████████████████▉               | 199/335 [01:16<00:43,  3.14it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  60%|█████████████████████▍              | 200/335 [01:16<00:42,  3.18it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  60%|█████████████████████▌              | 201/335 [01:16<00:41,  3.19it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  60%|█████████████████████▋              | 202/335 [01:17<00:41,  3.23it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  61%|█████████████████████▊              | 203/335 [01:17<00:40,  3.23it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  61%|█████████████████████▉              | 204/335 [01:17<00:41,  3.19it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  61%|██████████████████████              | 205/335 [01:18<00:40,  3.21it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  61%|██████████████████████▏             | 206/335 [01:18<00:39,  3.23it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  62%|██████████████████████▏             | 207/335 [01:18<00:39,  3.25it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  62%|██████████████████████▉              | 208/335 [01:19<00:39,  3.22it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  62%|██████████████████████▍             | 209/335 [01:19<00:38,  3.25it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  63%|███████████████████████▏             | 210/335 [01:19<00:39,  3.13it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  63%|███████████████████████▎             | 211/335 [01:20<00:40,  3.09it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  63%|██████████████████████▊             | 212/335 [01:20<00:40,  3.03it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  64%|██████████████████████▉             | 213/335 [01:20<00:40,  2.99it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  64%|██████████████████████▉             | 214/335 [01:21<00:40,  2.99it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  64%|███████████████████████             | 215/335 [01:21<00:39,  3.04it/s, Accuracy=0.969]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  64%|███████████████████████▏            | 216/335 [01:21<00:38,  3.09it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  65%|███████████████████████▎            | 217/335 [01:21<00:37,  3.17it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  65%|███████████████████████▍            | 218/335 [01:22<00:36,  3.18it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  65%|███████████████████████▌            | 219/335 [01:22<00:35,  3.24it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  66%|████████████████████████▎            | 220/335 [01:22<00:35,  3.27it/s, Accuracy=0.75]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  66%|███████████████████████▋            | 221/335 [01:23<00:34,  3.32it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  66%|███████████████████████▊            | 222/335 [01:23<00:33,  3.35it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  67%|███████████████████████▉            | 223/335 [01:23<00:33,  3.36it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  67%|████████████████████████            | 224/335 [01:24<00:35,  3.16it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  67%|████████████████████████▏           | 225/335 [01:24<00:34,  3.21it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  67%|████████████████████████▎           | 226/335 [01:24<00:34,  3.20it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  68%|████████████████████████▍           | 227/335 [01:25<00:33,  3.19it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  68%|████████████████████████▌           | 228/335 [01:25<00:33,  3.15it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  68%|████████████████████████▌           | 229/335 [01:25<00:33,  3.18it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  69%|████████████████████████▋           | 230/335 [01:25<00:32,  3.19it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  69%|████████████████████████▊           | 231/335 [01:26<00:32,  3.24it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  69%|████████████████████████▉           | 232/335 [01:26<00:31,  3.29it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  70%|█████████████████████████           | 233/335 [01:26<00:30,  3.35it/s, Accuracy=0.656]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  70%|█████████████████████████▏          | 234/335 [01:27<00:30,  3.36it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  70%|█████████████████████████▎          | 235/335 [01:27<00:29,  3.34it/s, Accuracy=0.688]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  70%|█████████████████████████▎          | 236/335 [01:27<00:29,  3.34it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  71%|█████████████████████████▍          | 237/335 [01:28<00:29,  3.36it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  71%|█████████████████████████▌          | 238/335 [01:28<00:28,  3.40it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  71%|█████████████████████████▋          | 239/335 [01:28<00:27,  3.44it/s, Accuracy=0.719]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  72%|█████████████████████████▊          | 240/335 [01:28<00:27,  3.47it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  72%|█████████████████████████▉          | 241/335 [01:29<00:27,  3.45it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  72%|██████████████████████████          | 242/335 [01:29<00:27,  3.41it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  73%|██████████████████████████          | 243/335 [01:29<00:27,  3.37it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  73%|██████████████████████████▏         | 244/335 [01:30<00:27,  3.34it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  73%|██████████████████████████▎         | 245/335 [01:30<00:27,  3.30it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  73%|██████████████████████████▍         | 246/335 [01:30<00:27,  3.28it/s, Accuracy=0.938]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  74%|██████████████████████████▌         | 247/335 [01:31<00:27,  3.23it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  74%|██████████████████████████▋         | 248/335 [01:31<00:27,  3.16it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  74%|██████████████████████████▊         | 249/335 [01:31<00:26,  3.19it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  75%|██████████████████████████▊         | 250/335 [01:31<00:26,  3.23it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  75%|██████████████████████████▉         | 251/335 [01:32<00:26,  3.19it/s, Accuracy=0.625]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  75%|███████████████████████████         | 252/335 [01:32<00:26,  3.16it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  76%|███████████████████████████▏        | 253/335 [01:32<00:26,  3.14it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  76%|███████████████████████████▎        | 254/335 [01:33<00:26,  3.08it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  76%|███████████████████████████▍        | 255/335 [01:33<00:26,  3.04it/s, Accuracy=0.844]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  76%|███████████████████████████▌        | 256/335 [01:33<00:26,  3.03it/s, Accuracy=0.688]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  77%|███████████████████████████▌        | 257/335 [01:34<00:25,  3.08it/s, Accuracy=0.906]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  77%|███████████████████████████▋        | 258/335 [01:34<00:24,  3.19it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  77%|███████████████████████████▊        | 259/335 [01:34<00:23,  3.24it/s, Accuracy=0.781]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  78%|███████████████████████████▉        | 260/335 [01:35<00:23,  3.20it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  78%|████████████████████████████        | 261/335 [01:35<00:22,  3.23it/s, Accuracy=0.812]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  78%|████████████████████████████▏       | 262/335 [01:35<00:22,  3.25it/s, Accuracy=0.688]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  79%|████████████████████████████▎       | 263/335 [01:36<00:22,  3.25it/s, Accuracy=0.875]

[None, None, None, tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='mps:0')]
Optimizing Gate Layer 0
Optimizing Gate Layer 1
Optimizing Gate Layer 2


Epoch 0:  79%|████████████████████████████▎       | 264/335 [01:36<00:21,  3.26it/s, Accuracy=0.688]

In [None]:
# import time

# # TODO: replace this with a transfer learned resnet model

# original_model = models.resnet50(pretrained=True)

# nonEarlyExitModel = EarlyExitModel(original_model, 1000, device)
# nonEarlyExitModel.to(device)
# nonModelTrainer = ModelTrainer(nonEarlyExitModel, device)

# # validate the model
# print("Validating original Resnet model")
# start = time.time()
# loss, acc, exits = nonModelTrainer.validate(test_dataloader)
# end = time.time()
# print(f"Validation Loss: {loss}, Validation Accuracy: {acc}")
# print(f"Validation time: {end - start}")
# print("=====================================================")

# # validate the new early exit model
# print("Validating new ResnetEE model")
# start = time.time()
# loss, acc = trainer.validate(test_dataloader)
# end = time.time()
# print(f"Validation Loss: {loss}, Validation Accuracy: {acc}")
# print(f"Average exit index: {exits}")
# print(f"Validation time: {end - start}")

Validating original Resnet model


KeyboardInterrupt: 