In [2]:
pip install spuco --upgrade

Collecting spuco
  Downloading spuco-2.0.3-py3-none-any.whl.metadata (4.3 kB)
Downloading spuco-2.0.3-py3-none-any.whl (127 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.4/127.4 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: spuco
  Attempting uninstall: spuco
    Found existing installation: spuco 2.0.2
    Uninstalling spuco-2.0.2:
      Successfully uninstalled spuco-2.0.2
Successfully installed spuco-2.0.3


In [3]:
import os
import torch
import pandas as pd
import torchvision.transforms as transforms
from torch.optim import SGD
from wilds import get_dataset

from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
from spuco.evaluate import Evaluator
from spuco.group_inference import SpareInference
from spuco.robust_train import SpareTrain
from spuco.models import model_factory
from spuco.utils import Trainer, set_seed

In [4]:
params = {
    "gpu": 0,
    "seed": 0,
    "root_dir": "/data",
    "batch_size": 32,
    "num_epochs": 20,
    "lr": 1e-3,
    "weight_decay": 1e-2,
    "momentum": 0.9,
    "pretrained": False,
    "infer_lr": 1e-3,
    "infer_weight_decay": 1e-2,
    "infer_momentum": 0.9,
    "infer_num_epochs": 1,
    "high_sampling_power": 2,
}

In [5]:
device = torch.device(f"cuda:{params['gpu']}" if torch.cuda.is_available() else "cpu")
set_seed(params["seed"])

In [6]:
classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
difficulty = SpuriousFeatureDifficulty.MAGNITUDE_LARGE

trainset = SpuCoMNIST(
    root=f"{params['root_dir']}/mnist/",
    spurious_feature_difficulty=difficulty,
    spurious_correlation_strength=0.995,
    classes=classes,
    split="train"
)
trainset.initialize()

valset = SpuCoMNIST(
    root=f"{params['root_dir']}/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="val",
)
valset.initialize()

testset = SpuCoMNIST(
    root=f"{params['root_dir']}/mnist/",
    spurious_feature_difficulty=difficulty,
    classes=classes,
    split="test"
)
testset.initialize()


In [7]:
model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes, pretrained=params["pretrained"]).to(device)

trainer = Trainer(
    trainset=trainset,
    model=model,
    batch_size=params["batch_size"],
    optimizer=SGD(model.parameters(), lr=params["infer_lr"], weight_decay=params["infer_weight_decay"], momentum=params["infer_momentum"]),
    device=device,
    verbose=True
)

trainer.train(num_epochs=params["infer_num_epochs"])

Epoch 0:   1%|          | 10/1501 [00:01<03:05,  8.06batch/s, accuracy=25.0%, loss=1.61]  

 | Epoch 0 | Loss: 1.6115756034851074 | Accuracy: 18.75%
 | Epoch 0 | Loss: 1.6203930377960205 | Accuracy: 12.5%
 | Epoch 0 | Loss: 1.61808180809021 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.6209357976913452 | Accuracy: 25.0%
 | Epoch 0 | Loss: 1.6147788763046265 | Accuracy: 15.625%
 | Epoch 0 | Loss: 1.6116019487380981 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.6081494092941284 | Accuracy: 15.625%
 | Epoch 0 | Loss: 1.607947826385498 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.616422414779663 | Accuracy: 15.625%
 | Epoch 0 | Loss: 1.6171921491622925 | Accuracy: 15.625%
 | Epoch 0 | Loss: 1.618257761001587 | Accuracy: 15.625%
 | Epoch 0 | Loss: 1.618667483329773 | Accuracy: 12.5%
 | Epoch 0 | Loss: 1.6032153367996216 | Accuracy: 18.75%
 | Epoch 0 | Loss: 1.60996413230896 | Accuracy: 18.75%
 | Epoch 0 | Loss: 1.6192119121551514 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.6320838928222656 | Accuracy: 12.5%
 | Epoch 0 | Loss: 1.6056013107299805 | Accuracy: 15.625%
 | Epoch 0 | Loss: 1.6140894889831

Epoch 0:   3%|▎         | 40/1501 [00:01<00:36, 39.54batch/s, accuracy=43.75%, loss=1.6]

 | Epoch 0 | Loss: 1.5921324491500854 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.6107178926467896 | Accuracy: 18.75%
 | Epoch 0 | Loss: 1.608851671218872 | Accuracy: 15.625%
 | Epoch 0 | Loss: 1.6068270206451416 | Accuracy: 37.5%
 | Epoch 0 | Loss: 1.6245592832565308 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.6247491836547852 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.6223483085632324 | Accuracy: 28.125%
 | Epoch 0 | Loss: 1.5961277484893799 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5983879566192627 | Accuracy: 40.625%
 | Epoch 0 | Loss: 1.610825777053833 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.6113169193267822 | Accuracy: 18.75%
 | Epoch 0 | Loss: 1.6118828058242798 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.6016209125518799 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.6103163957595825 | Accuracy: 34.375%
 | Epoch 0 | Loss: 1.5956121683120728 | Accuracy: 37.5%
 | Epoch 0 | Loss: 1.613802433013916 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.6160727739334106 | Accuracy: 28.125%
 | Epoch 0 | Loss: 1.6188437

Epoch 0:   5%|▍         | 68/1501 [00:02<00:20, 70.50batch/s, accuracy=40.625%, loss=1.59]

 | Epoch 0 | Loss: 1.5989820957183838 | Accuracy: 34.375%
 | Epoch 0 | Loss: 1.6012864112854004 | Accuracy: 34.375%
 | Epoch 0 | Loss: 1.6103123426437378 | Accuracy: 34.375%
 | Epoch 0 | Loss: 1.6135225296020508 | Accuracy: 15.625%
 | Epoch 0 | Loss: 1.5965113639831543 | Accuracy: 40.625%
 | Epoch 0 | Loss: 1.6070175170898438 | Accuracy: 34.375%
 | Epoch 0 | Loss: 1.584710955619812 | Accuracy: 53.125%
 | Epoch 0 | Loss: 1.6017626523971558 | Accuracy: 28.125%
 | Epoch 0 | Loss: 1.6111871004104614 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.6063504219055176 | Accuracy: 18.75%
 | Epoch 0 | Loss: 1.6156306266784668 | Accuracy: 25.0%
 | Epoch 0 | Loss: 1.6080412864685059 | Accuracy: 28.125%
 | Epoch 0 | Loss: 1.593406319618225 | Accuracy: 46.875%
 | Epoch 0 | Loss: 1.598187804222107 | Accuracy: 37.5%
 | Epoch 0 | Loss: 1.6087473630905151 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.607277750968933 | Accuracy: 40.625%
 | Epoch 0 | Loss: 1.6010570526123047 | Accuracy: 34.375%
 | Epoch 0 | Loss: 1.60

Epoch 0:   6%|▋         | 94/1501 [00:02<00:15, 88.46batch/s, accuracy=25.0%, loss=1.59]

 | Epoch 0 | Loss: 1.588773488998413 | Accuracy: 40.625%
 | Epoch 0 | Loss: 1.597916603088379 | Accuracy: 28.125%
 | Epoch 0 | Loss: 1.5972654819488525 | Accuracy: 25.0%
 | Epoch 0 | Loss: 1.5973671674728394 | Accuracy: 18.75%
 | Epoch 0 | Loss: 1.6063809394836426 | Accuracy: 25.0%
 | Epoch 0 | Loss: 1.606278419494629 | Accuracy: 12.5%
 | Epoch 0 | Loss: 1.6049405336380005 | Accuracy: 18.75%
 | Epoch 0 | Loss: 1.6054733991622925 | Accuracy: 9.375%
 | Epoch 0 | Loss: 1.6035130023956299 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.585585355758667 | Accuracy: 28.125%
 | Epoch 0 | Loss: 1.5960434675216675 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.5943890810012817 | Accuracy: 25.0%
 | Epoch 0 | Loss: 1.6091092824935913 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.6001728773117065 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.5974438190460205 | Accuracy: 21.875%
 | Epoch 0 | Loss: 1.5901000499725342 | Accuracy: 37.5%
 | Epoch 0 | Loss: 1.5994906425476074 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.610674500465

Epoch 0:   8%|▊         | 118/1501 [00:02<00:14, 98.70batch/s, accuracy=43.75%, loss=1.59]

 | Epoch 0 | Loss: 1.6134130954742432 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.5980520248413086 | Accuracy: 25.0%
 | Epoch 0 | Loss: 1.6047884225845337 | Accuracy: 25.0%
 | Epoch 0 | Loss: 1.592774748802185 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.61192786693573 | Accuracy: 28.125%
 | Epoch 0 | Loss: 1.5979927778244019 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.5957067012786865 | Accuracy: 37.5%
 | Epoch 0 | Loss: 1.5911368131637573 | Accuracy: 53.125%
 | Epoch 0 | Loss: 1.5934851169586182 | Accuracy: 43.75%
 | Epoch 0 | Loss: 1.5949937105178833 | Accuracy: 37.5%
 | Epoch 0 | Loss: 1.5980761051177979 | Accuracy: 34.375%
 | Epoch 0 | Loss: 1.6082849502563477 | Accuracy: 31.25%
 | Epoch 0 | Loss: 1.5985790491104126 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.5859853029251099 | Accuracy: 40.625%
 | Epoch 0 | Loss: 1.5891332626342773 | Accuracy: 40.625%
 | Epoch 0 | Loss: 1.5914463996887207 | Accuracy: 43.75%
 | Epoch 0 | Loss: 1.589217185974121 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.591804742813110

Epoch 0:   9%|▉         | 142/1501 [00:02<00:13, 103.81batch/s, accuracy=46.875%, loss=1.59]

 | Epoch 0 | Loss: 1.5878008604049683 | Accuracy: 53.125%
 | Epoch 0 | Loss: 1.6067323684692383 | Accuracy: 43.75%
 | Epoch 0 | Loss: 1.5906286239624023 | Accuracy: 40.625%
 | Epoch 0 | Loss: 1.5863475799560547 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.592027187347412 | Accuracy: 53.125%
 | Epoch 0 | Loss: 1.5967721939086914 | Accuracy: 56.25%
 | Epoch 0 | Loss: 1.5785636901855469 | Accuracy: 71.875%
 | Epoch 0 | Loss: 1.5879312753677368 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.5901931524276733 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5917553901672363 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.5897191762924194 | Accuracy: 56.25%
 | Epoch 0 | Loss: 1.6035561561584473 | Accuracy: 46.875%
 | Epoch 0 | Loss: 1.600411057472229 | Accuracy: 46.875%
 | Epoch 0 | Loss: 1.588857650756836 | Accuracy: 46.875%
 | Epoch 0 | Loss: 1.5859441757202148 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5821863412857056 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.5866854190826416 | Accuracy: 56.25%
 | Epoch 0 | Loss: 1.58694601

Epoch 0:  11%|█         | 166/1501 [00:03<00:12, 109.22batch/s, accuracy=75.0%, loss=1.56] 

 | Epoch 0 | Loss: 1.5874364376068115 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5788695812225342 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5903457403182983 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.5887634754180908 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.5838472843170166 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.577512502670288 | Accuracy: 68.75%
 | Epoch 0 | Loss: 1.579419732093811 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5803251266479492 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.5908070802688599 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.5796573162078857 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5855544805526733 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.5736178159713745 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.5895360708236694 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.591624140739441 | Accuracy: 37.5%
 | Epoch 0 | Loss: 1.5849344730377197 | Accuracy: 56.25%
 | Epoch 0 | Loss: 1.5848298072814941 | Accuracy: 53.125%
 | Epoch 0 | Loss: 1.582426905632019 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.584951400756836

Epoch 0:  13%|█▎        | 191/1501 [00:03<00:11, 114.48batch/s, accuracy=65.625%, loss=1.57]

 | Epoch 0 | Loss: 1.5771241188049316 | Accuracy: 53.125%
 | Epoch 0 | Loss: 1.5698962211608887 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.5787192583084106 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5815234184265137 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5832970142364502 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.571047306060791 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5777735710144043 | Accuracy: 53.125%
 | Epoch 0 | Loss: 1.5781500339508057 | Accuracy: 46.875%
 | Epoch 0 | Loss: 1.5699137449264526 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5753742456436157 | Accuracy: 56.25%
 | Epoch 0 | Loss: 1.5649160146713257 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5709733963012695 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.574808120727539 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.5778309106826782 | Accuracy: 53.125%
 | Epoch 0 | Loss: 1.5690107345581055 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5681706666946411 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5649131536483765 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.5648

Epoch 0:  14%|█▍        | 217/1501 [00:03<00:10, 120.07batch/s, accuracy=65.625%, loss=1.55]

 | Epoch 0 | Loss: 1.5658234357833862 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5578656196594238 | Accuracy: 71.875%
 | Epoch 0 | Loss: 1.5733445882797241 | Accuracy: 50.0%
 | Epoch 0 | Loss: 1.5621864795684814 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.553271770477295 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.5683434009552002 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.5619651079177856 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5621269941329956 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5579731464385986 | Accuracy: 71.875%
 | Epoch 0 | Loss: 1.5494132041931152 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.559434413909912 | Accuracy: 59.375%
 | Epoch 0 | Loss: 1.5492074489593506 | Accuracy: 84.375%
 | Epoch 0 | Loss: 1.5465412139892578 | Accuracy: 90.625%
 | Epoch 0 | Loss: 1.5524706840515137 | Accuracy: 68.75%
 | Epoch 0 | Loss: 1.540819764137268 | Accuracy: 84.375%
 | Epoch 0 | Loss: 1.5498888492584229 | Accuracy: 87.5%
 | Epoch 0 | Loss: 1.5490976572036743 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.54274368

Epoch 0:  16%|█▌        | 243/1501 [00:03<00:10, 122.95batch/s, accuracy=75.0%, loss=1.49]  

 | Epoch 0 | Loss: 1.5512109994888306 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5353953838348389 | Accuracy: 71.875%
 | Epoch 0 | Loss: 1.5367742776870728 | Accuracy: 71.875%
 | Epoch 0 | Loss: 1.5477920770645142 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5391998291015625 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.542037010192871 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.5432443618774414 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.538392186164856 | Accuracy: 65.625%
 | Epoch 0 | Loss: 1.5307666063308716 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.5329158306121826 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.5362470149993896 | Accuracy: 84.375%
 | Epoch 0 | Loss: 1.5042380094528198 | Accuracy: 84.375%
 | Epoch 0 | Loss: 1.519201397895813 | Accuracy: 84.375%
 | Epoch 0 | Loss: 1.5197125673294067 | Accuracy: 90.625%
 | Epoch 0 | Loss: 1.5308271646499634 | Accuracy: 68.75%
 | Epoch 0 | Loss: 1.5263131856918335 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.5276708602905273 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.529

Epoch 0:  18%|█▊        | 271/1501 [00:03<00:09, 127.72batch/s, accuracy=78.125%, loss=1.45]

 | Epoch 0 | Loss: 1.494246482849121 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.5138485431671143 | Accuracy: 87.5%
 | Epoch 0 | Loss: 1.4906089305877686 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.5125597715377808 | Accuracy: 68.75%
 | Epoch 0 | Loss: 1.5048381090164185 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.49208402633667 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.4998717308044434 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.497685194015503 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.4830543994903564 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.4764527082443237 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.4818716049194336 | Accuracy: 100.0%
 | Epoch 0 | Loss: 1.4939981698989868 | Accuracy: 90.625%
 | Epoch 0 | Loss: 1.5039234161376953 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.471963882446289 | Accuracy: 84.375%
 | Epoch 0 | Loss: 1.4984185695648193 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.464093565940857 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.4742937088012695 | Accuracy: 84.375%
 | Epoch 0 | Loss: 1.481552958488

Epoch 0:  20%|█▉        | 299/1501 [00:04<00:09, 129.24batch/s, accuracy=75.0%, loss=1.35] 

 | Epoch 0 | Loss: 1.42391836643219 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.4376436471939087 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.4506468772888184 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.378679871559143 | Accuracy: 68.75%
 | Epoch 0 | Loss: 1.4068098068237305 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.4042621850967407 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.392410397529602 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.4527709484100342 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.3725112676620483 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.4138108491897583 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.4069403409957886 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.3812891244888306 | Accuracy: 62.5%
 | Epoch 0 | Loss: 1.3493672609329224 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.3860976696014404 | Accuracy: 71.875%
 | Epoch 0 | Loss: 1.3538377285003662 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.397043228149414 | Accuracy: 87.5%
 | Epoch 0 | Loss: 1.3487412929534912 | Accuracy: 71.875%
 | Epoch 0 | Loss: 1.3724021911621094 

Epoch 0:  22%|██▏       | 327/1501 [00:04<00:09, 123.34batch/s, accuracy=81.25%, loss=0.946]

 | Epoch 0 | Loss: 1.244280457496643 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.296396017074585 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.2009172439575195 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.184575080871582 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.2691739797592163 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.1918134689331055 | Accuracy: 87.5%
 | Epoch 0 | Loss: 1.2245111465454102 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.2190791368484497 | Accuracy: 87.5%
 | Epoch 0 | Loss: 1.1568325757980347 | Accuracy: 75.0%
 | Epoch 0 | Loss: 1.1378740072250366 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.1751667261123657 | Accuracy: 90.625%
 | Epoch 0 | Loss: 1.1242884397506714 | Accuracy: 87.5%
 | Epoch 0 | Loss: 1.121272325515747 | Accuracy: 84.375%
 | Epoch 0 | Loss: 1.1299874782562256 | Accuracy: 90.625%
 | Epoch 0 | Loss: 1.0904096364974976 | Accuracy: 93.75%
 | Epoch 0 | Loss: 1.119937539100647 | Accuracy: 81.25%
 | Epoch 0 | Loss: 1.0560507774353027 | Accuracy: 78.125%
 | Epoch 0 | Loss: 1.0735065937042236 

Epoch 0:  23%|██▎       | 340/1501 [00:04<00:09, 121.70batch/s, accuracy=100.0%, loss=0.535]

 | Epoch 0 | Loss: 0.9245226979255676 | Accuracy: 78.125%
 | Epoch 0 | Loss: 0.8876541256904602 | Accuracy: 71.875%
 | Epoch 0 | Loss: 0.9789205193519592 | Accuracy: 75.0%
 | Epoch 0 | Loss: 0.9272903800010681 | Accuracy: 68.75%
 | Epoch 0 | Loss: 0.8447571992874146 | Accuracy: 78.125%
 | Epoch 0 | Loss: 0.9438467025756836 | Accuracy: 71.875%
 | Epoch 0 | Loss: 0.9137310981750488 | Accuracy: 78.125%
 | Epoch 0 | Loss: 0.850742757320404 | Accuracy: 81.25%
 | Epoch 0 | Loss: 0.713824450969696 | Accuracy: 93.75%
 | Epoch 0 | Loss: 0.7819921374320984 | Accuracy: 84.375%
 | Epoch 0 | Loss: 0.8621842861175537 | Accuracy: 71.875%
 | Epoch 0 | Loss: 0.8750942945480347 | Accuracy: 93.75%
 | Epoch 0 | Loss: 0.7728601098060608 | Accuracy: 90.625%
 | Epoch 0 | Loss: 0.8187344074249268 | Accuracy: 93.75%
 | Epoch 0 | Loss: 0.7374910712242126 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.7110127210617065 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.6296552419662476 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.7315

Epoch 0:  24%|██▍       | 365/1501 [00:04<00:09, 115.74batch/s, accuracy=100.0%, loss=0.215]

 | Epoch 0 | Loss: 0.48394396901130676 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.5783718228340149 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.47669893503189087 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.44210898876190186 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.49390313029289246 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.49432671070098877 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.408426970243454 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.36981579661369324 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.38391363620758057 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.4082539677619934 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.36868807673454285 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.32698309421539307 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.35057884454727173 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3121744394302368 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.36249998211860657 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2712835967540741 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2590537965297699 | Accuracy: 100.0%
 | Epoch 0 | Loss: 

Epoch 0:  26%|██▌       | 389/1501 [00:04<00:10, 111.03batch/s, accuracy=100.0%, loss=0.0776]

 | Epoch 0 | Loss: 0.40772178769111633 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.3269735872745514 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.1917312741279602 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1629597246646881 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.15175040066242218 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.14239513874053955 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.15942515432834625 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.12923282384872437 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.12799301743507385 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.14384178817272186 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.11539483070373535 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.10945018380880356 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.09838155657052994 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3754342794418335 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.0951748788356781 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.10665646195411682 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.08674688637256622 | Accuracy: 100.0%
 | Epoch 0 | Los

Epoch 0:  28%|██▊       | 413/1501 [00:05<00:09, 110.12batch/s, accuracy=100.0%, loss=0.037] 

 | Epoch 0 | Loss: 0.07759734988212585 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.14255967736244202 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.06981029361486435 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.06927680224180222 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.06037714332342148 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.05930996313691139 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.049939993768930435 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.05939050018787384 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.06742853671312332 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.06426158547401428 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.05474616959691048 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.05025101453065872 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0562596321105957 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.042262256145477295 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.051551174372434616 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.048168446868658066 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0482742078602314 | Accuracy: 100.0%
 | Epoch 0 

Epoch 0:  29%|██▉       | 437/1501 [00:05<00:09, 107.90batch/s, accuracy=96.875%, loss=0.289]

 | Epoch 0 | Loss: 0.03562542796134949 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.034421734511852264 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.027295485138893127 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.028050892055034637 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.34640270471572876 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.03125879168510437 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.030109791085124016 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02808523178100586 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.03145135939121246 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3559907078742981 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.04267902672290802 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.03299331292510033 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.039655644446611404 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1675887107849121 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.03269200399518013 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.030845563858747482 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.024727873504161835 | Accuracy: 100.0%
 | Epo

Epoch 0:  31%|███       | 460/1501 [00:05<00:09, 107.33batch/s, accuracy=100.0%, loss=0.0182]


 | Epoch 0 | Loss: 0.19382308423519135 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.02526698261499405 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1552530825138092 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.02149275131523609 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02033202163875103 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.027232391759753227 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.024674471467733383 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02510741353034973 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01804507151246071 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02009313367307186 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.23635233938694 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.3375712037086487 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.03193731606006622 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.028037717565894127 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.021214507520198822 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.025803951546549797 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.021832704544067383 | Accuracy: 100.0%
 | Epoch

Epoch 0:  32%|███▏      | 484/1501 [00:05<00:09, 110.01batch/s, accuracy=100.0%, loss=0.0166]

 | Epoch 0 | Loss: 0.017766073346138 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.022786777466535568 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.31642061471939087 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.017873531207442284 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.019989201799035072 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.5641944408416748 | Accuracy: 93.75%
 | Epoch 0 | Loss: 0.017572127282619476 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.021623268723487854 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01834195666015148 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.018977105617523193 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01828288473188877 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.016382943838834763 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.019181804731488228 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01827075518667698 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.016427651047706604 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.30775272846221924 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01976027898490429 | Accuracy: 100.0%
 | Epo

Epoch 0:  34%|███▍      | 508/1501 [00:05<00:08, 111.64batch/s, accuracy=100.0%, loss=0.0158]

 | Epoch 0 | Loss: 0.023706842213869095 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02632697857916355 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.027882276102900505 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02098163776099682 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1343492865562439 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.02890138514339924 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.033485542982816696 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.027889279648661613 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02826441265642643 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.321054071187973 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.020607132464647293 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.025764934718608856 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02168629691004753 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.022098930552601814 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2264510989189148 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.016306251287460327 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.019317716360092163 | Accuracy: 100.0%
 | Epo

Epoch 0:  35%|███▌      | 532/1501 [00:06<00:08, 112.04batch/s, accuracy=100.0%, loss=0.0162]

 | Epoch 0 | Loss: 0.01582329161465168 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01770750619471073 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.017962709069252014 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.08477535098791122 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.018478747457265854 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.020024124532938004 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01868857629597187 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02304360643029213 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01821010373532772 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.019204095005989075 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.018952006474137306 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0227805282920599 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.021018600091338158 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01690371334552765 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014772752299904823 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.017791949212551117 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01311005000025034 | Accuracy: 100.0%
 | Epo

Epoch 0:  37%|███▋      | 558/1501 [00:06<00:07, 119.87batch/s, accuracy=100.0%, loss=0.014] 

 | Epoch 0 | Loss: 0.01797799952328205 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.016436053439974785 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1846659779548645 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01487478706985712 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.020543545484542847 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.017927007749676704 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.5501402616500854 | Accuracy: 93.75%
 | Epoch 0 | Loss: 0.019957544282078743 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01585463061928749 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01835649646818638 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.017378132790327072 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01847754791378975 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.016108250245451927 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015603660605847836 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014848494902253151 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013535737060010433 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014866783283650875 | Accuracy: 100.0%
 | Ep

Epoch 0:  39%|███▉      | 584/1501 [00:06<00:07, 119.33batch/s, accuracy=100.0%, loss=0.0126]

 | Epoch 0 | Loss: 0.014658558182418346 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014316759072244167 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01633443683385849 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01454425323754549 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01841806061565876 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012324837036430836 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.018896808847784996 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014041266404092312 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1717582792043686 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.012790489941835403 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012637677602469921 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.02000107429921627 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.5368663668632507 | Accuracy: 93.75%
 | Epoch 0 | Loss: 0.014521149918437004 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011706719174981117 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012639940716326237 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01323521975427866 | Accuracy: 100.0%
 | Ep

Epoch 0:  41%|████      | 610/1501 [00:06<00:07, 124.07batch/s, accuracy=100.0%, loss=0.0134]

 | Epoch 0 | Loss: 0.3472128212451935 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.40217310190200806 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013386254198849201 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.017083967104554176 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01271265558898449 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015542017295956612 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013161392882466316 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1707298755645752 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013821577653288841 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012431937269866467 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012444636784493923 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015165350399911404 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013651460409164429 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013143660500645638 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.15650396049022675 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013587314635515213 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.15025684237480164 | Accuracy: 96.875%


Epoch 0:  43%|████▎     | 639/1501 [00:06<00:06, 132.93batch/s, accuracy=100.0%, loss=0.0139]

 | Epoch 0 | Loss: 0.010993404313921928 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.184810608625412 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.015356154181063175 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3464154899120331 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.012532024644315243 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013821609318256378 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015795839950442314 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013140133582055569 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012511786073446274 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01029275357723236 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015063159167766571 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1862635463476181 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.32154256105422974 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011544662527740002 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17248040437698364 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01720547303557396 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012973028235137463 | Accuracy: 100.0%
 | 

Epoch 0:  44%|████▍     | 667/1501 [00:07<00:06, 125.91batch/s, accuracy=100.0%, loss=0.0143]

 | Epoch 0 | Loss: 0.18608619272708893 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013622465543448925 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17928853631019592 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.42961767315864563 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01804005168378353 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01699664816260338 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013728005811572075 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1604139506816864 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.015308544971048832 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01630103960633278 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015621479600667953 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1883067488670349 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.014991184696555138 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01611117273569107 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0159278754144907 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013102521188557148 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01788528449833393 | Accuracy: 100.0%
 | Epo

Epoch 0:  45%|████▌     | 680/1501 [00:07<00:06, 120.02batch/s, accuracy=100.0%, loss=0.0135]

 | Epoch 0 | Loss: 0.01742437854409218 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.016997642815113068 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.016103792935609818 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.18053491413593292 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.0149566400796175 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01762687787413597 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.019478704780340195 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01668490841984749 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01860620081424713 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.017000945284962654 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015892410650849342 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015112679451704025 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01753619685769081 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01536036841571331 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015428084880113602 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014464722014963627 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01483739074319601 | Accuracy: 100.0%
 | Epo

Epoch 0:  47%|████▋     | 705/1501 [00:07<00:07, 107.76batch/s, accuracy=100.0%, loss=0.0143]

 | Epoch 0 | Loss: 0.015609168447554111 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014731451869010925 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013475575484335423 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014886024408042431 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013875841163098812 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1682426631450653 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01295189093798399 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014591388404369354 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010380118153989315 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3085472583770752 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013786926865577698 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011309251189231873 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.19585396349430084 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.014263404533267021 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011345554143190384 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.14310908317565918 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013931015506386757 | Accuracy: 100.0%


Epoch 0:  48%|████▊     | 727/1501 [00:07<00:07, 106.49batch/s, accuracy=100.0%, loss=0.0148]

 | Epoch 0 | Loss: 0.014897141605615616 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013394501060247421 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.021936781704425812 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014645972289144993 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.14567966759204865 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.015883484855294228 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015572072938084602 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013442069292068481 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.019369129091501236 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015418601222336292 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.32628732919692993 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01343410462141037 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014969805255532265 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3502121865749359 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.17118433117866516 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011685573495924473 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011957385577261448 | Accuracy: 100.0%

Epoch 0:  50%|████▉     | 749/1501 [00:08<00:07, 106.71batch/s, accuracy=100.0%, loss=0.0143]

 | Epoch 0 | Loss: 0.014832800254225731 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014349599368870258 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01286096591502428 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01388863380998373 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01030761655420065 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17100033164024353 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.39104610681533813 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01482648216187954 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012156683951616287 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013266078196465969 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3341645896434784 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01019850093871355 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013543599285185337 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013962902128696442 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011608156375586987 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011165943928062916 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012449403293430805 | Accuracy: 100.0%
 | 

Epoch 0:  51%|█████▏    | 772/1501 [00:08<00:06, 107.04batch/s, accuracy=96.875%, loss=0.381]

 | Epoch 0 | Loss: 0.01434255950152874 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011500667780637741 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011958925984799862 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01126302219927311 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01214112900197506 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01115233264863491 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00943286158144474 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01126874890178442 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012681258842349052 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011452380567789078 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010740457102656364 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009499549865722656 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00964860524982214 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011182624846696854 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011004837229847908 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010915104299783707 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010522834956645966 | Accuracy: 100.0%
 | E

Epoch 0:  53%|█████▎    | 794/1501 [00:08<00:06, 105.32batch/s, accuracy=100.0%, loss=0.0104] 

 | Epoch 0 | Loss: 0.3805111050605774 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011689855717122555 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010194355621933937 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0104245999827981 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010135889053344727 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008618463762104511 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008704042062163353 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011509589850902557 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009140260517597198 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011559149250388145 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2173444926738739 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010091585107147694 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011075178161263466 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0100038917735219 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.18394924700260162 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008530659601092339 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010597401298582554 | Accuracy: 100.0%
 | 

Epoch 0:  54%|█████▍    | 816/1501 [00:08<00:06, 103.02batch/s, accuracy=100.0%, loss=0.0152]

 | Epoch 0 | Loss: 0.010592803359031677 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.36514532566070557 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.16310498118400574 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013590286485850811 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01718418300151825 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012898930348455906 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013486270792782307 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.016492849215865135 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011669759638607502 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009990437887609005 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01445665955543518 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3462308943271637 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01227385364472866 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01627965085208416 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.016956273466348648 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.017858903855085373 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.13644781708717346 | Accuracy: 96.875%
 |

Epoch 0:  56%|█████▌    | 838/1501 [00:08<00:06, 104.30batch/s, accuracy=100.0%, loss=0.0125]

 | Epoch 0 | Loss: 0.014460006728768349 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013870224356651306 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.12676255404949188 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01670435257256031 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012337963096797466 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01248534582555294 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01006396021693945 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010990222916007042 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.016414230689406395 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.19756712019443512 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01176657248288393 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009870229288935661 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010521775111556053 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011545306071639061 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3736415207386017 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013529853895306587 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010948899202048779 | Accuracy: 100.0%
 |

Epoch 0:  57%|█████▋    | 862/1501 [00:09<00:05, 111.76batch/s, accuracy=100.0%, loss=0.0104]

 | Epoch 0 | Loss: 0.012518414296209812 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010345295071601868 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.20912760496139526 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.16940906643867493 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.00970704760402441 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010113074444234371 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008691266179084778 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008789021521806717 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011152674444019794 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009374663233757019 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012940740212798119 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010265355929732323 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010447058826684952 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008765406906604767 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0100067974999547 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011004319414496422 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00959334522485733 | Accuracy: 100.0%
 

Epoch 0:  59%|█████▉    | 887/1501 [00:09<00:05, 112.55batch/s, accuracy=100.0%, loss=0.0128]

 | Epoch 0 | Loss: 0.1842639148235321 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01177921611815691 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0088410135358572 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009795069694519043 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009324666112661362 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008593035861849785 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2057836651802063 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.340947687625885 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.009402060881257057 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009876783937215805 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.16153989732265472 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.009432748891413212 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3792760372161865 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010863479226827621 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012523484416306019 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010080319829285145 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012698461301624775 | Accuracy: 100.0%
 | Ep

Epoch 0:  61%|██████    | 911/1501 [00:09<00:05, 113.90batch/s, accuracy=96.875%, loss=0.166] 

 | Epoch 0 | Loss: 0.36166736483573914 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.014417709782719612 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01572539657354355 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012118151411414146 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012880787253379822 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01382146030664444 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012774322181940079 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013421441428363323 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015350875444710255 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011761204339563847 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01259295828640461 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01217648759484291 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014465982094407082 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011405459605157375 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012007682584226131 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011954096145927906 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012116492725908756 | Accuracy: 100.0%
 

Epoch 0:  62%|██████▏   | 936/1501 [00:09<00:04, 116.62batch/s, accuracy=100.0%, loss=0.0119]

 | Epoch 0 | Loss: 0.16568510234355927 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011939948424696922 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012218107469379902 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011002110317349434 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.312520295381546 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011542536318302155 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01258278638124466 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013231185264885426 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17047318816184998 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01577344909310341 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010750389657914639 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01009330153465271 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012407013215124607 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1495012789964676 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011116262525320053 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010848002508282661 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014792632311582565 | Accuracy: 100.0%
 | 

Epoch 0:  63%|██████▎   | 949/1501 [00:09<00:04, 118.78batch/s, accuracy=100.0%, loss=0.0134]

 | Epoch 0 | Loss: 0.011876644566655159 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.31810081005096436 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.012277377769351006 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0109056131914258 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011300340294837952 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010851157829165459 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010228568688035011 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010148501954972744 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011427429504692554 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012517385184764862 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009784414432942867 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.4009992480278015 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01114380732178688 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009368949569761753 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.19295461475849152 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011536319740116596 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010546203702688217 | Accuracy: 100.0%
 

Epoch 0:  65%|██████▍   | 974/1501 [00:10<00:04, 116.22batch/s, accuracy=100.0%, loss=0.00843]

 | Epoch 0 | Loss: 0.013354331254959106 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012990192510187626 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011261584237217903 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011490609496831894 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012386168353259563 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012074407190084457 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014512753114104271 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011769036762416363 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012331030331552029 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011251355521380901 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009891994297504425 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010014092549681664 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01104283332824707 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01001318171620369 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00955538172274828 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013952543959021568 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010423432104289532 | Accuracy: 100.0%


Epoch 0:  66%|██████▋   | 998/1501 [00:10<00:04, 104.01batch/s, accuracy=100.0%, loss=0.00885]

 | Epoch 0 | Loss: 0.010080062784254551 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00823170505464077 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008192737586796284 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00857534445822239 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012169837020337582 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008830495178699493 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008926261216402054 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.18542250990867615 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008699235506355762 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17708583176136017 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.007993458770215511 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.20129019021987915 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01078227162361145 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00858288910239935 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008538000285625458 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008519822731614113 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00756741501390934 | Accuracy: 100.0%
 |

Epoch 0:  67%|██████▋   | 1009/1501 [00:10<00:05, 98.06batch/s, accuracy=100.0%, loss=0.0109]

 | Epoch 0 | Loss: 0.008854924701154232 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.5004823207855225 | Accuracy: 93.75%
 | Epoch 0 | Loss: 0.008146200329065323 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.15768297016620636 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010801395401358604 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011189996264874935 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009997081942856312 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009416749700903893 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00952966045588255 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010993282310664654 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.018283477053046227 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010763571597635746 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010276506654918194 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1677081286907196 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013021713122725487 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17674629390239716 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010861323215067387 | Accuracy: 100.0%


Epoch 0:  69%|██████▊   | 1029/1501 [00:10<00:05, 91.81batch/s, accuracy=96.875%, loss=0.178]

 | Epoch 0 | Loss: 0.010351932607591152 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009627451188862324 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010910946875810623 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013280991464853287 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013688664883375168 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010206733830273151 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.31785815954208374 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.009499629959464073 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012175626121461391 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00917481817305088 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3785196840763092 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.012618673965334892 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010482940822839737 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.32114794850349426 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013357487507164478 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015555859543383121 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011587261222302914 | Accuracy: 100.0%

Epoch 0:  70%|██████▉   | 1049/1501 [00:10<00:04, 92.66batch/s, accuracy=100.0%, loss=0.0111]

 | Epoch 0 | Loss: 0.17841756343841553 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011419848538935184 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011825695633888245 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1898212730884552 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.3580626845359802 | Accuracy: 93.75%
 | Epoch 0 | Loss: 0.01033059973269701 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01020248606801033 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013311902992427349 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011789094656705856 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012625092640519142 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011983541771769524 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014413739554584026 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013044056482613087 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011784432455897331 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01193759124726057 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012790913693606853 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010577531531453133 | Accuracy: 100.0%
 | 

Epoch 0:  71%|███████   | 1069/1501 [00:11<00:04, 89.71batch/s, accuracy=100.0%, loss=0.0126]

 | Epoch 0 | Loss: 0.011080905795097351 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014654325321316719 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2717898488044739 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.012267820537090302 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013722897507250309 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012735776603221893 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010365702211856842 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012378417886793613 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.18564872443675995 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011479705572128296 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.36330628395080566 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.012743860483169556 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011040423065423965 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011888710781931877 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009530513547360897 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011034289374947548 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.31529727578163147 | Accuracy: 96.875

Epoch 0:  73%|███████▎  | 1089/1501 [00:11<00:04, 90.27batch/s, accuracy=100.0%, loss=0.0117]

 | Epoch 0 | Loss: 0.012584433890879154 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012384532019495964 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012897173874080181 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010232703760266304 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010861271061003208 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012162690982222557 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01236516423523426 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011070970445871353 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3208193778991699 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011521791107952595 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012235647067427635 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17410095036029816 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.012402449734508991 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011643008328974247 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.19291222095489502 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013426998630166054 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2689395546913147 | Accuracy: 96.875%


Epoch 0:  74%|███████▍  | 1109/1501 [00:11<00:04, 91.81batch/s, accuracy=100.0%, loss=0.00963]

 | Epoch 0 | Loss: 0.013285432010889053 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01184152066707611 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011879818513989449 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013375787995755672 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015643151476979256 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014205857180058956 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011832830496132374 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.015042591840028763 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013761145994067192 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012418008409440517 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012339645996689796 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009525171481072903 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012198719196021557 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013373564928770065 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011748025193810463 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011133279651403427 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.29797741770744324 | Accuracy: 96.875

Epoch 0:  75%|███████▌  | 1129/1501 [00:11<00:04, 90.92batch/s, accuracy=100.0%, loss=0.0123]

 | Epoch 0 | Loss: 0.009633103385567665 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011202679015696049 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009943347424268723 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010991369374096394 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011582502163946629 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008999346755445004 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008871225640177727 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01256723701953888 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01245688647031784 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009621940553188324 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17211075127124786 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01019217073917389 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011904907412827015 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010444670915603638 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010359539650380611 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011152056977152824 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.18316256999969482 | Accuracy: 96.875%


Epoch 0:  77%|███████▋  | 1151/1501 [00:11<00:03, 97.15batch/s, accuracy=100.0%, loss=0.0106]

 | Epoch 0 | Loss: 0.17013011872768402 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.009455062448978424 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010103748179972172 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01164247840642929 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01086469180881977 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010706109926104546 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012150094844400883 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010451758280396461 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010938293300569057 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009646565653383732 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009217847138643265 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00898592360317707 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013883729465305805 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009448830969631672 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010187443345785141 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010067585855722427 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011122331954538822 | Accuracy: 100.0%


Epoch 0:  78%|███████▊  | 1173/1501 [00:12<00:03, 100.83batch/s, accuracy=100.0%, loss=0.0116]

 | Epoch 0 | Loss: 0.34019166231155396 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010713859461247921 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.20289461314678192 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.3067461848258972 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011902578175067902 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008695591241121292 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010821455158293247 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010388470254838467 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00903987791389227 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011393758468329906 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011492803692817688 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011626400984823704 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009215361438691616 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009525880217552185 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009491595439612865 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009781253524124622 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011080394499003887 | Accuracy: 100.0%

Epoch 0:  79%|███████▉  | 1184/1501 [00:12<00:03, 94.39batch/s, accuracy=100.0%, loss=0.0103]

 | Epoch 0 | Loss: 0.009880437515676022 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010331588797271252 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008871070109307766 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2328660935163498 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011359588243067265 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3201812207698822 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011380271054804325 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010668252594769001 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0114262904971838 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011409359984099865 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010384532622992992 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008574835024774075 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008018475957214832 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007924177683889866 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008532613515853882 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008671069517731667 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008998310193419456 | Accuracy: 100.0%
 

Epoch 0:  80%|████████  | 1204/1501 [00:12<00:03, 92.38batch/s, accuracy=100.0%, loss=0.0094]

 | Epoch 0 | Loss: 0.009089923463761806 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010579414665699005 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0106138801202178 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.013193869031965733 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.19709539413452148 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.0085951192304492 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00954881776124239 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008075599558651447 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009967230260372162 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0083076860755682 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008920806460082531 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009645169600844383 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008733700029551983 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2078837752342224 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010154154151678085 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010131516493856907 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008929328061640263 | Accuracy: 100.0%
 | Ep

Epoch 0:  82%|████████▏ | 1224/1501 [00:12<00:02, 93.67batch/s, accuracy=96.875%, loss=0.306]

 | Epoch 0 | Loss: 0.014444326981902122 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1892651617527008 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008940573781728745 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010899432934820652 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010956876911222935 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009235622361302376 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010839651338756084 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011558313854038715 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009089208208024502 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009388419799506664 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010849550366401672 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1457633227109909 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011416184715926647 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012945793569087982 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010456916876137257 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010384674184024334 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008510606363415718 | Accuracy: 100.0%

Epoch 0:  83%|████████▎ | 1244/1501 [00:12<00:02, 92.74batch/s, accuracy=100.0%, loss=0.00948]

 | Epoch 0 | Loss: 0.01064327359199524 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009250926785171032 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010275332257151604 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010813530534505844 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011092206463217735 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010026581585407257 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010096621699631214 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.42684346437454224 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008452304638922215 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.006906190887093544 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008442184887826443 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008464205078780651 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008227930404245853 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.19909575581550598 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.3345261812210083 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.00708750681951642 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010888693854212761 | Accuracy: 100.0%


Epoch 0:  84%|████████▍ | 1264/1501 [00:13<00:02, 88.30batch/s, accuracy=96.875%, loss=0.177]

 | Epoch 0 | Loss: 0.010096014477312565 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010774709284305573 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007021407596766949 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.19851045310497284 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01261818502098322 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010445253923535347 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01017578225582838 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010227477177977562 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009213346987962723 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3897586464881897 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011466002091765404 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.31205829977989197 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010243878699839115 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009824617765843868 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009835141710937023 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.25154995918273926 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.35713428258895874 | Accuracy: 96.875%


Epoch 0:  85%|████████▍ | 1273/1501 [00:13<00:02, 83.82batch/s, accuracy=100.0%, loss=0.0115]

 | Epoch 0 | Loss: 0.010341105051338673 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009616462513804436 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010859346017241478 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010178840719163418 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011861833743751049 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009370873682200909 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2227313071489334 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010172328911721706 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.14607930183410645 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01275066938251257 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010648161172866821 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011367971077561378 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011694875545799732 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3289542496204376 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.01083772536367178 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011504379101097584 | Accuracy: 100.0%

Epoch 0:  86%|████████▌ | 1291/1501 [00:13<00:02, 77.45batch/s, accuracy=100.0%, loss=0.0129]


 | Epoch 0 | Loss: 0.010587580502033234 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.13722968101501465 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.011379027739167213 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01209008414298296 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012545670382678509 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011866285465657711 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010202910751104355 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0103544220328331 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.15324215590953827 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.013551556505262852 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009655854664742947 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011588849127292633 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01267866138368845 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012852724641561508 | Accuracy: 100.0%


Epoch 0:  87%|████████▋ | 1309/1501 [00:13<00:02, 80.33batch/s, accuracy=100.0%, loss=0.0109]

 | Epoch 0 | Loss: 0.010750330053269863 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01247338391840458 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009970549494028091 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01303359866142273 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014356234110891819 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.017090853303670883 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014299770817160606 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010872330516576767 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012370619922876358 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1625228226184845 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010558297857642174 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012328902259469032 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011361292563378811 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.18789446353912354 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.010979566723108292 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010989045724272728 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010407453402876854 | Accuracy: 100.0%


Epoch 0:  89%|████████▊ | 1330/1501 [00:13<00:01, 90.28batch/s, accuracy=100.0%, loss=0.00964]

 | Epoch 0 | Loss: 0.009219875559210777 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009688549675047398 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010648656636476517 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010015202686190605 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00982722919434309 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010690735653042793 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010000443086028099 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012611021287739277 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.19225360453128815 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.009376773610711098 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.014235636219382286 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009598970413208008 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009221888147294521 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009160593152046204 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009314514696598053 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012490027584135532 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009859615936875343 | Accuracy: 100.0

Epoch 0:  90%|█████████ | 1351/1501 [00:14<00:01, 95.56batch/s, accuracy=100.0%, loss=0.00804]

 | Epoch 0 | Loss: 0.008609842509031296 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008162355050444603 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00843783002346754 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009932261891663074 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0091763436794281 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007926344871520996 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01264702808111906 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008759546093642712 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008121468126773834 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009281936101615429 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010925371199846268 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00748814269900322 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007578830700367689 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007147005759179592 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007588385604321957 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008666559122502804 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007849366404116154 | Accuracy: 100.0%
 |

Epoch 0:  91%|█████████▏| 1371/1501 [00:14<00:01, 88.88batch/s, accuracy=100.0%, loss=0.00829]

 | Epoch 0 | Loss: 0.007501427084207535 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17709262669086456 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.009277724660933018 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007941416464745998 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0077177006751298904 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.17674127221107483 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.19481661915779114 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.1914054900407791 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008342555724084377 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007391938474029303 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009597249329090118 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008028252050280571 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3496728539466858 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008234897628426552 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.16591551899909973 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008292747661471367 | Accuracy: 100.0%


Epoch 0:  93%|█████████▎| 1390/1501 [00:14<00:01, 88.57batch/s, accuracy=100.0%, loss=0.0118]

 | Epoch 0 | Loss: 0.009418773464858532 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.20077988505363464 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.007917840033769608 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008511671796441078 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010045388713479042 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009241550229489803 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011386198922991753 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01035035029053688 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012471310794353485 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011397957801818848 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011952913366258144 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00922480970621109 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012075933627784252 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.010169222950935364 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012295462191104889 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2912680506706238 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.012511758133769035 | Accuracy: 100.0%


Epoch 0:  94%|█████████▍| 1408/1501 [00:14<00:01, 86.44batch/s, accuracy=100.0%, loss=0.00718]

 | Epoch 0 | Loss: 0.00953136570751667 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011660007759928703 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.011431201361119747 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009463991969823837 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00812677014619112 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008546210825443268 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008855807594954967 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.012709825299680233 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008361635729670525 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009621999226510525 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007036574650555849 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008471060544252396 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009188683703541756 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009018328972160816 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008673258125782013 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008442564867436886 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007714963052421808 | Accuracy: 100.0%

Epoch 0:  94%|█████████▍| 1417/1501 [00:15<00:00, 84.44batch/s, accuracy=100.0%, loss=0.00756]

 | Epoch 0 | Loss: 0.2516789138317108 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.006050293333828449 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007946082390844822 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1785064935684204 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.00881908554583788 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00739721255376935 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007244996260851622 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00854747649282217 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00792637374252081 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008428464643657207 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00747727882117033 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008994529023766518 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007387986872345209 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.006344494875520468 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.1597803384065628 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008207900449633598 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009880278259515762 | Accuracy: 100.0%
 | Ep

Epoch 0:  96%|█████████▌| 1438/1501 [00:15<00:00, 92.04batch/s, accuracy=100.0%, loss=0.00676]

 | Epoch 0 | Loss: 0.007562993559986353 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008335118182003498 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0067670284770429134 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0079659940674901 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.20991191267967224 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008920935913920403 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.006536843720823526 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008994749747216702 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.20000424981117249 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008983945474028587 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0084749236702919 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007301751524209976 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007727843709290028 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.006469457410275936 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.005971809383481741 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.19413955509662628 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008668824099004269 | Accuracy: 100.0%


Epoch 0:  97%|█████████▋| 1457/1501 [00:15<00:00, 88.96batch/s, accuracy=100.0%, loss=0.00909]

 | Epoch 0 | Loss: 0.010989095084369183 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007008349522948265 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008271804079413414 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007891008630394936 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008141852915287018 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009746408089995384 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00808331836014986 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.006408326793462038 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007199782878160477 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007323737721890211 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.009644054807722569 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007094762288033962 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007050633430480957 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007311912253499031 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00788840465247631 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007448329124599695 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3416397273540497 | Accuracy: 96.875%


Epoch 0:  98%|█████████▊| 1475/1501 [00:15<00:00, 88.83batch/s, accuracy=96.875%, loss=0.338] 

 | Epoch 0 | Loss: 0.0072719440795481205 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.26035594940185547 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.007989767007529736 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.006315710488706827 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.006806556601077318 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.01055021584033966 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007032180670648813 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.006634590681642294 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.00845477543771267 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008881693705916405 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007432121317833662 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.3150797188282013 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.00736788846552372 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.20297545194625854 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.007239250466227531 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007123616989701986 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008811364881694317 | Accuracy: 100.0%


Epoch 0: 100%|█████████▉| 1494/1501 [00:15<00:00, 85.76batch/s, accuracy=96.875%, loss=0.169] 

 | Epoch 0 | Loss: 0.011725349351763725 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007039968855679035 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.2041069120168686 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.007489541079849005 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.0066990298219025135 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007772465236485004 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007507049944251776 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008330882526934147 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.18133826553821564 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.008947423659265041 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007181425113230944 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.6844380497932434 | Accuracy: 93.75%
 | Epoch 0 | Loss: 0.00842180848121643 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007551389280706644 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008337300270795822 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.007931418716907501 | Accuracy: 100.0%
 | Epoch 0 | Loss: 0.008958746679127216 | Accuracy: 100.0%


Epoch 0: 100%|█████████▉| 1494/1501 [00:15<00:00, 85.76batch/s, accuracy=100.0%, loss=0.00745]

 | Epoch 0 | Loss: 0.16851384937763214 | Accuracy: 96.875%
 | Epoch 0 | Loss: 0.007449684664607048 | Accuracy: 100.0%


Epoch 0: 100%|██████████| 1501/1501 [00:16<00:00, 92.72batch/s, accuracy=100.0%, loss=0.00745]


In [9]:
logits = trainer.get_trainset_outputs()
predictions = torch.nn.functional.softmax(logits, dim=1)

spare_infer = SpareInference(
    logits=predictions,
    class_labels=trainset.labels,
    device=device,
    max_clusters=5,
    high_sampling_power=params["high_sampling_power"],
    verbose=True
)

group_partition = spare_infer.infer_groups()
print(group_partition)
sampling_powers = spare_infer.sampling_powers

print("Sampling powers:", sampling_powers)
for key in sorted(group_partition.keys()):
    for true_key in sorted(trainset.group_partition.keys()):
        print(f"Inferred group: {key}, true group: {true_key}, size: {len([x for x in trainset.group_partition[true_key] if x in group_partition[key]])}")


Getting model outputs: 100%|██████████| 751/751 [00:06<00:00, 117.37batch/s]
Clustering class-wise:   0%|          | 0/5 [00:00<?, ?it/s]

For n_clusters = 2 The average silhouette_score is : 0.9946167
For n_clusters = 3 The average silhouette_score is : 0.9962789
For n_clusters = 4 The average silhouette_score is : 0.9974368


Clustering class-wise:  20%|██        | 1/5 [00:06<00:25,  6.41s/it]

For n_clusters = 5 The average silhouette_score is : 0.99831325
Silhouette score for class 2: 0.9983132481575012
For n_clusters = 2 The average silhouette_score is : 0.98960465
For n_clusters = 3 The average silhouette_score is : 0.9910108
For n_clusters = 4 The average silhouette_score is : 0.9923443


Clustering class-wise:  40%|████      | 2/5 [00:12<00:19,  6.47s/it]

For n_clusters = 5 The average silhouette_score is : 0.9933025
Silhouette score for class 0: 0.9933025240898132
For n_clusters = 2 The average silhouette_score is : 0.9924124
For n_clusters = 3 The average silhouette_score is : 0.9937549
For n_clusters = 4 The average silhouette_score is : 0.9950329


Clustering class-wise:  60%|██████    | 3/5 [00:17<00:11,  5.77s/it]

For n_clusters = 5 The average silhouette_score is : 0.99626327
Silhouette score for class 4: 0.9962632656097412
For n_clusters = 2 The average silhouette_score is : 0.9941025
For n_clusters = 3 The average silhouette_score is : 0.99562305
For n_clusters = 4 The average silhouette_score is : 0.99676114


Clustering class-wise:  80%|████████  | 4/5 [00:22<00:05,  5.50s/it]

For n_clusters = 5 The average silhouette_score is : 0.9978694
Silhouette score for class 1: 0.9978693723678589
For n_clusters = 2 The average silhouette_score is : 0.9919296
For n_clusters = 3 The average silhouette_score is : 0.993502
For n_clusters = 4 The average silhouette_score is : 0.9946918


Clustering class-wise: 100%|██████████| 5/5 [00:29<00:00,  5.90s/it]

For n_clusters = 5 The average silhouette_score is : 0.99562186
Silhouette score for class 3: 0.9956218600273132
{(2, 0): [0, 2, 9, 10, 17, 23, 32, 44, 50, 55, 57, 58, 77, 83, 95, 107, 111, 115, 118, 123, 133, 134, 136, 142, 144, 149, 157, 159, 176, 178, 189, 190, 192, 197, 207, 212, 215, 216, 217, 218, 220, 225, 228, 230, 232, 233, 236, 245, 248, 250, 263, 266, 267, 271, 287, 292, 295, 296, 299, 301, 302, 314, 319, 320, 326, 327, 328, 331, 334, 347, 348, 364, 366, 389, 394, 402, 403, 404, 405, 423, 428, 429, 448, 454, 462, 467, 482, 483, 486, 488, 490, 494, 505, 515, 519, 520, 526, 528, 542, 543, 544, 546, 551, 553, 558, 559, 560, 563, 572, 573, 574, 577, 581, 591, 593, 595, 596, 601, 612, 616, 619, 637, 639, 643, 644, 653, 665, 669, 670, 671, 672, 677, 680, 682, 683, 684, 685, 688, 704, 716, 719, 725, 731, 733, 734, 738, 742, 756, 758, 766, 769, 773, 776, 779, 792, 794, 795, 797, 806, 808, 823, 830, 834, 836, 846, 856, 864, 868, 875, 877, 880, 884, 885, 890, 894, 895, 896, 915, 917, 




Inferred group: (0, 0), true group: (0, 0), size: 10082
Inferred group: (0, 0), true group: (0, 1), size: 0
Inferred group: (0, 0), true group: (0, 2), size: 0
Inferred group: (0, 0), true group: (0, 3), size: 0
Inferred group: (0, 0), true group: (0, 4), size: 0
Inferred group: (0, 0), true group: (1, 0), size: 0
Inferred group: (0, 0), true group: (1, 1), size: 0
Inferred group: (0, 0), true group: (1, 2), size: 0
Inferred group: (0, 0), true group: (1, 3), size: 0
Inferred group: (0, 0), true group: (1, 4), size: 0
Inferred group: (0, 0), true group: (2, 0), size: 0
Inferred group: (0, 0), true group: (2, 1), size: 0
Inferred group: (0, 0), true group: (2, 2), size: 0
Inferred group: (0, 0), true group: (2, 3), size: 0
Inferred group: (0, 0), true group: (2, 4), size: 0
Inferred group: (0, 0), true group: (3, 0), size: 0
Inferred group: (0, 0), true group: (3, 1), size: 0
Inferred group: (0, 0), true group: (3, 2), size: 0
Inferred group: (0, 0), true group: (3, 3), size: 0
Inferred

In [10]:
valid_evaluator = Evaluator(
    testset=valset,
    group_partition=valset.group_partition,
    group_weights=valset.group_weights,
    batch_size=params["batch_size"],
    model=model,
    device=device,
    verbose=True
)

spare_train = SpareTrain(
    model=model,
    num_epochs=params["num_epochs"],
    trainset=trainset,
    group_partition=group_partition,
    sampling_powers=[20] * 5,
    batch_size=params["batch_size"],
    optimizer=SGD(model.parameters(), lr=params["lr"], weight_decay=params["weight_decay"], momentum=params["momentum"]),
    device=device,
    val_evaluator=valid_evaluator,
    verbose=True
)
spare_train.train()

Output hidden; open in https://colab.research.google.com to view.

In [12]:
evaluator = Evaluator(
    testset=testset,
    group_partition=testset.group_partition,
    group_weights=trainset.group_weights,
    batch_size=params["batch_size"],
    model=spare_train.best_model,
    device=device,
    verbose=True
)
evaluator.evaluate()

print("Final Results:")
print(f"Worst Group Accuracy: {evaluator.worst_group_accuracy[1]}")
print(f"Average Accuracy: {evaluator.average_accuracy}")

Evaluating group-wise accuracy:   4%|▍         | 1/25 [00:00<00:08,  2.85it/s]

Group (0, 0) Accuracy: 0.0


Evaluating group-wise accuracy:   8%|▊         | 2/25 [00:00<00:08,  2.82it/s]

Group (0, 1) Accuracy: 59.57446808510638


Evaluating group-wise accuracy:  12%|█▏        | 3/25 [00:01<00:07,  2.78it/s]

Group (0, 2) Accuracy: 85.1063829787234


Evaluating group-wise accuracy:  16%|█▌        | 4/25 [00:01<00:07,  2.82it/s]

Group (0, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  20%|██        | 5/25 [00:01<00:07,  2.84it/s]

Group (0, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  24%|██▍       | 6/25 [00:02<00:06,  2.84it/s]

Group (1, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  28%|██▊       | 7/25 [00:02<00:06,  2.88it/s]

Group (1, 1) Accuracy: 1.466992665036675


Evaluating group-wise accuracy:  32%|███▏      | 8/25 [00:02<00:06,  2.73it/s]

Group (1, 2) Accuracy: 77.94117647058823


Evaluating group-wise accuracy:  36%|███▌      | 9/25 [00:03<00:05,  2.80it/s]

Group (1, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  40%|████      | 10/25 [00:03<00:05,  2.82it/s]

Group (1, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  44%|████▍     | 11/25 [00:03<00:04,  2.87it/s]

Group (2, 0) Accuracy: 14.933333333333334


Evaluating group-wise accuracy:  48%|████▊     | 12/25 [00:04<00:04,  2.86it/s]

Group (2, 1) Accuracy: 0.0


Evaluating group-wise accuracy:  52%|█████▏    | 13/25 [00:04<00:04,  2.90it/s]

Group (2, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  56%|█████▌    | 14/25 [00:04<00:03,  2.93it/s]

Group (2, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  60%|██████    | 15/25 [00:05<00:03,  2.92it/s]

Group (2, 4) Accuracy: 100.0


Evaluating group-wise accuracy:  64%|██████▍   | 16/25 [00:05<00:03,  2.93it/s]

Group (3, 0) Accuracy: 100.0


Evaluating group-wise accuracy:  68%|██████▊   | 17/25 [00:06<00:03,  2.58it/s]

Group (3, 1) Accuracy: 0.5037783375314862


Evaluating group-wise accuracy:  72%|███████▏  | 18/25 [00:06<00:03,  2.20it/s]

Group (3, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  76%|███████▌  | 19/25 [00:07<00:02,  2.04it/s]

Group (3, 3) Accuracy: 0.0


Evaluating group-wise accuracy:  80%|████████  | 20/25 [00:07<00:02,  1.96it/s]

Group (3, 4) Accuracy: 0.0


Evaluating group-wise accuracy:  84%|████████▍ | 21/25 [00:08<00:02,  1.93it/s]

Group (4, 0) Accuracy: 0.0


Evaluating group-wise accuracy:  88%|████████▊ | 22/25 [00:08<00:01,  1.96it/s]

Group (4, 1) Accuracy: 87.6574307304786


Evaluating group-wise accuracy:  92%|█████████▏| 23/25 [00:09<00:00,  2.18it/s]

Group (4, 2) Accuracy: 0.0


Evaluating group-wise accuracy:  96%|█████████▌| 24/25 [00:09<00:00,  2.33it/s]

Group (4, 3) Accuracy: 99.4949494949495


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.53it/s]

Group (4, 4) Accuracy: 0.0
Final Results:
Worst Group Accuracy: 0.0
Average Accuracy: 0.42631064629556487



