In [2]:
import os
os.environ["WANDB_API_KEY"] = "dcf9600e0485401cbb0ddbb0f7be1c70f96b32ef"
os.environ["WANDB_MODE"] = "disabled"
import argparse
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipdb
import wandb
from sigma_layer import SigmaLinear, SigmaConv, SigmaView
from utils import get_dataset, gradient_centralization, normalize_along_axis, get_activation_function, compute_SCL_loss
import datetime

In [3]:
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# Training scheme group
method_parser = parser.add_argument_group("Method")
method_parser.add_argument('--method', type=str, default='BP', choices=['SIGMA', 'BP', 'FA'])
method_parser.add_argument('--actfunc', type=str, default='elu', choices=['tanh', 'elu', 'relu'])
method_parser.add_argument('--conv_dim', type=int, default=32, choices=[32, 64])
# Dataset group
dataset_parser = parser.add_argument_group('Dataset')
dataset_parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'CIFAR10'])
dataset_parser.add_argument('--batchsize', type=int, default=64)
dataset_parser.add_argument('--splitratio', type=float, default=0.2)
# Training group # LR, optimizer, weight_decay, momentum
training_parser = parser.add_argument_group('Training')
training_parser.add_argument('--epochs', type=int, default=100)
training_parser.add_argument('--lr', type=float, default=0.1)
training_parser.add_argument('--optimizer', type=str, default='SGD', choices=['RMSprop', 'Adam', 'SGD'])
# Seed group
seed_parser = parser.add_argument_group('Seed')
seed_parser.add_argument('--seed', type=int, default=2023)
args, _ = parser.parse_known_args()

# Set run_name
run_name = f"{args.dataset}_{args.method}_conv{args.conv_dim}-act{args.actfunc}_{args.optimizer}_{args.lr}_{args.seed}"
time_stamp = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")

# Set wandb
wandb.init(
    project="opt-sigma",
    name=run_name,
    # track hyperparameters and run metadata
    config={
    "algorithm": args.method,
    "architecture": "SimpleCNN",
    "dataset": args.dataset,
    "epochs": args.epochs,
    "lr": args.lr,
    "optimizer": args.optimizer,
    "seed": args.seed,
    "conv_dim": args.conv_dim,
    "actfunc": args.actfunc,
    }
)

print(f"Run name: {run_name}")

# Set seed
torch.manual_seed(args.seed), np.random.seed(args.seed)

# Set device
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

# Get dataset
train_loader, val_loader, test_loader = get_dataset(args)

class SigmaModel_SimpleCNN(nn.Module):
    def __init__(self, args):
        super(SigmaModel_SimpleCNN, self).__init__()
        self.LN1 = torch.nn.LayerNorm((1,28,28), elementwise_affine=False) 
        self.LN2 = torch.nn.LayerNorm((args.conv_dim,14,14), elementwise_affine=False) 
        self.conv1 = SigmaConv(1, args.conv_dim, 3, args)
        self.conv2 = SigmaConv(args.conv_dim, 64, 3, args)
        self.view1 = SigmaView((64, 7, 7), 64 * 7 * 7)
        self.fc1 = SigmaLinear(64 * 7 * 7, 128, args)
        self.fc2 = SigmaLinear(128, 10, args)
                
        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.conv1, self.conv2, self.fc1, self.fc2]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def forward(self, x, detach_grad=False, return_activations=True):   
        x = self.LN1(x)
        a1 = self.conv1(x, detach_grad)
        a2 = self.conv2(a1, detach_grad)
        a2 = self.view1(a2, detach_grad)
        a3 = self.fc1(a2, detach_grad)
        a4 = self.fc2(a3, detach_grad)
        return [a1, a2, a3, a4]
        
    def reverse(self, target, detach_grad=True, return_activations=True):
        if target.shape == torch.Size([10]): 
            target = F.one_hot(target, num_classes=10).float().to(target.device)
        
        b3 = self.fc2.reverse(target, detach_grad)
        b2 = self.fc1.reverse(b3, detach_grad)
        b2 = self.view1.reverse(b2, detach_grad)
        b1 = self.conv2.reverse(b2, detach_grad)
        return [b1, b2, b3, target]


class SigmaLoss(nn.Module):
    def __init__(self, args):
        super(SigmaLoss, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_SCL_loss
        self.method = args.method
        
    def forward(self, activations, signals, target, method="final"):
        """
        Activations and signals are list of activations
        target is a onehot representation of target sigals
        """
        if method == "local":
            loss = list()
            for act, sig in zip(activations[:-1], signals[:-1]):
                loss += [self.local_criteria(act, sig, target)]
            loss += [self.final_criteria(activations[-1], target)]
            return sum(loss), loss[-1].item()
        elif method == "final":
            loss = self.final_criteria(activations[-1], target)
            return loss, loss.item()
        
model = SigmaModel_SimpleCNN(args)
model.to(device)
if args.optimizer == "SGD": 
    forward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
    backward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
elif args.optimizer == "RMSprop": forward_optimizer = optim.RMSprop(model.forward_params, lr=args.lr, weight_decay=0)
elif args.optimizer == "Adam": forward_optimizer = optim.Adam(model.forward_params, lr=args.lr)
criteria = SigmaLoss(args)
    
with torch.no_grad():
    signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), return_activations=True)
    
best_val_loss = float('inf')
for epoch in tqdm(range(args.epochs)):
    
    train_loss, train_counter = 0, 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        if args.method == "SIGMA":
            activations = model(data.to(device), detach_grad=True)
            signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), detach_grad=True)
            loss, loss_item = criteria(activations, signals, target.to(device), method="local")
        elif args.method == "BP":
            activations = model(data.to(device), detach_grad=False)
            loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), backward_optimizer.zero_grad(), loss.backward()
        gradient_centralization(model), forward_optimizer.step(), backward_optimizer.step()
        train_loss += loss_item * len(data)
        train_counter += len(data)

    wandb.log({'train_loss': train_loss / train_counter}, step=epoch)
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        activations = model(data.to(device), detach_grad=True)
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), loss.backward(), forward_optimizer.step()

    # Validation
    val_correct, val_loss, val_counter = 0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            val_counter += len(data)
            if args.method == "SIGMA":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="local")
                
            elif args.method == "BP":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="final")
            prediction = activations[-1].detach()
            _, predicted = torch.max(prediction, 1)
            val_correct += (predicted == target.to(device)).sum().item()
            val_loss += loss_item * len(data)

    wandb.log({'val_loss': val_loss / val_counter, 
               'val_acc': val_correct / val_counter, 
               }, step=epoch)
    
    print(f"""Epoch {epoch} | train loss {train_loss / train_counter:.4f} | val loss {val_loss / val_counter:.4f} | val acc {100 * val_correct / val_counter:.4f}""")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), f'./saved_models/{run_name}-{time_stamp}.pt')
        

# Eval on Test Set by loading the best model 
model.load_state_dict(torch.load(f'./saved_models/{run_name}-{time_stamp}.pt'))
model.eval()
correct, total = 0, 0
test_loss, test_counter = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data.to(device), detach_grad=False)
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == target.to(device)).sum().item()
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        test_loss += loss_item * len(data)
        test_counter += len(data)

wandb.log({'test_loss': test_loss / test_counter,
           'test_acc': 100 * correct / test_counter})

print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / test_counter:.4f}%')

Run name: MNIST_BP_conv32-actelu_SGD_0.1_2023


  1%|▍                                          | 1/100 [00:09<15:16,  9.26s/it]

Epoch 0 | train loss 0.4976 | val loss 0.1279 | val acc 95.9417


  2%|▊                                          | 2/100 [00:18<15:08,  9.27s/it]

Epoch 1 | train loss 0.1507 | val loss 0.0895 | val acc 97.3417


  3%|█▎                                         | 3/100 [00:27<14:52,  9.20s/it]

Epoch 2 | train loss 0.0968 | val loss 0.0731 | val acc 97.6750


  4%|█▋                                         | 4/100 [00:36<14:36,  9.13s/it]

Epoch 3 | train loss 0.0941 | val loss 0.0599 | val acc 98.1833


  5%|██▏                                        | 5/100 [00:45<14:21,  9.07s/it]

Epoch 4 | train loss 0.0714 | val loss 0.0599 | val acc 98.1500


  6%|██▌                                        | 6/100 [00:54<14:14,  9.09s/it]

Epoch 5 | train loss 0.0694 | val loss 0.0524 | val acc 98.4167


  7%|███                                        | 7/100 [01:03<14:01,  9.04s/it]

Epoch 6 | train loss 0.0533 | val loss 0.0509 | val acc 98.5333


  8%|███▍                                       | 8/100 [01:12<13:43,  8.96s/it]

Epoch 7 | train loss 0.0425 | val loss 0.0491 | val acc 98.4833


  9%|███▊                                       | 9/100 [01:21<13:28,  8.88s/it]

Epoch 8 | train loss 0.0411 | val loss 0.0489 | val acc 98.5167


 10%|████▏                                     | 10/100 [01:29<13:14,  8.82s/it]

Epoch 9 | train loss 0.0490 | val loss 0.0452 | val acc 98.6000


 11%|████▌                                     | 11/100 [01:38<12:59,  8.76s/it]

Epoch 10 | train loss 0.0437 | val loss 0.0451 | val acc 98.5583


 12%|█████                                     | 12/100 [01:47<12:55,  8.81s/it]

Epoch 11 | train loss 0.0397 | val loss 0.0450 | val acc 98.5750


 13%|█████▍                                    | 13/100 [01:56<12:54,  8.91s/it]

Epoch 12 | train loss 0.0368 | val loss 0.0441 | val acc 98.6583


 14%|█████▉                                    | 14/100 [02:05<12:48,  8.93s/it]

Epoch 13 | train loss 0.0378 | val loss 0.0450 | val acc 98.6083


 15%|██████▎                                   | 15/100 [02:14<12:36,  8.90s/it]

Epoch 14 | train loss 0.0248 | val loss 0.0424 | val acc 98.7083


 16%|██████▋                                   | 16/100 [02:23<12:25,  8.87s/it]

Epoch 15 | train loss 0.0377 | val loss 0.0399 | val acc 98.7833


 17%|███████▏                                  | 17/100 [02:32<12:16,  8.88s/it]

Epoch 16 | train loss 0.0289 | val loss 0.0407 | val acc 98.8750


 18%|███████▌                                  | 18/100 [02:40<12:00,  8.79s/it]

Epoch 17 | train loss 0.0286 | val loss 0.0390 | val acc 98.7833


 19%|███████▉                                  | 19/100 [02:49<11:48,  8.75s/it]

Epoch 18 | train loss 0.0258 | val loss 0.0393 | val acc 98.8583


 20%|████████▍                                 | 20/100 [02:57<11:34,  8.69s/it]

Epoch 19 | train loss 0.0270 | val loss 0.0421 | val acc 98.7917


 21%|████████▊                                 | 21/100 [03:06<11:24,  8.67s/it]

Epoch 20 | train loss 0.0224 | val loss 0.0377 | val acc 98.9917


 22%|█████████▏                                | 22/100 [03:15<11:18,  8.69s/it]

Epoch 21 | train loss 0.0265 | val loss 0.0402 | val acc 98.8333


 23%|█████████▋                                | 23/100 [03:24<11:15,  8.77s/it]

Epoch 22 | train loss 0.0171 | val loss 0.0406 | val acc 98.9667


 24%|██████████                                | 24/100 [03:33<11:08,  8.79s/it]

Epoch 23 | train loss 0.0192 | val loss 0.0392 | val acc 98.9417


 25%|██████████▌                               | 25/100 [03:41<10:58,  8.79s/it]

Epoch 24 | train loss 0.0178 | val loss 0.0372 | val acc 98.9917


 26%|██████████▉                               | 26/100 [03:50<10:52,  8.82s/it]

Epoch 25 | train loss 0.0228 | val loss 0.0406 | val acc 98.8667


 27%|███████████▎                              | 27/100 [03:59<10:52,  8.94s/it]

Epoch 26 | train loss 0.0218 | val loss 0.0435 | val acc 98.7750


 28%|███████████▊                              | 28/100 [04:09<10:51,  9.04s/it]

Epoch 27 | train loss 0.0142 | val loss 0.0400 | val acc 98.9417


 29%|████████████▏                             | 29/100 [04:18<10:50,  9.16s/it]

Epoch 28 | train loss 0.0191 | val loss 0.0406 | val acc 98.9000


 30%|████████████▌                             | 30/100 [04:27<10:39,  9.14s/it]

Epoch 29 | train loss 0.0195 | val loss 0.0399 | val acc 98.9000


 31%|█████████████                             | 31/100 [04:36<10:29,  9.12s/it]

Epoch 30 | train loss 0.0143 | val loss 0.0466 | val acc 98.7417


 32%|█████████████▍                            | 32/100 [04:45<10:15,  9.06s/it]

Epoch 31 | train loss 0.0168 | val loss 0.0403 | val acc 98.9250


 33%|█████████████▊                            | 33/100 [04:54<10:00,  8.97s/it]

Epoch 32 | train loss 0.0126 | val loss 0.0386 | val acc 99.0500


 34%|██████████████▎                           | 34/100 [05:03<09:45,  8.87s/it]

Epoch 33 | train loss 0.0148 | val loss 0.0389 | val acc 99.0083


 35%|██████████████▋                           | 35/100 [05:11<09:34,  8.84s/it]

Epoch 34 | train loss 0.0196 | val loss 0.0392 | val acc 98.9667


 36%|███████████████                           | 36/100 [05:20<09:24,  8.82s/it]

Epoch 35 | train loss 0.0155 | val loss 0.0377 | val acc 99.0083


 37%|███████████████▌                          | 37/100 [05:29<09:16,  8.84s/it]

Epoch 36 | train loss 0.0077 | val loss 0.0394 | val acc 98.9500


 38%|███████████████▉                          | 38/100 [05:38<09:06,  8.81s/it]

Epoch 37 | train loss 0.0128 | val loss 0.0393 | val acc 99.0000


 39%|████████████████▍                         | 39/100 [05:46<08:53,  8.75s/it]

Epoch 38 | train loss 0.0186 | val loss 0.0392 | val acc 98.9417


 40%|████████████████▊                         | 40/100 [05:55<08:44,  8.74s/it]

Epoch 39 | train loss 0.0093 | val loss 0.0383 | val acc 98.9750


 41%|█████████████████▏                        | 41/100 [06:04<08:34,  8.72s/it]

Epoch 40 | train loss 0.0092 | val loss 0.0415 | val acc 98.9750


 42%|█████████████████▋                        | 42/100 [06:12<08:25,  8.72s/it]

Epoch 41 | train loss 0.0112 | val loss 0.0394 | val acc 99.0333


 43%|██████████████████                        | 43/100 [06:22<08:23,  8.83s/it]

Epoch 42 | train loss 0.0091 | val loss 0.0374 | val acc 99.0000


 44%|██████████████████▍                       | 44/100 [06:31<08:17,  8.89s/it]

Epoch 43 | train loss 0.0097 | val loss 0.0390 | val acc 99.0000


 45%|██████████████████▉                       | 45/100 [06:40<08:11,  8.93s/it]

Epoch 44 | train loss 0.0108 | val loss 0.0413 | val acc 98.9750


 46%|███████████████████▎                      | 46/100 [06:48<07:59,  8.88s/it]

Epoch 45 | train loss 0.0102 | val loss 0.0430 | val acc 98.9250


 47%|███████████████████▋                      | 47/100 [06:57<07:51,  8.90s/it]

Epoch 46 | train loss 0.0033 | val loss 0.0373 | val acc 99.0750


 48%|████████████████████▏                     | 48/100 [07:06<07:42,  8.90s/it]

Epoch 47 | train loss 0.0066 | val loss 0.0377 | val acc 99.0833


 49%|████████████████████▌                     | 49/100 [07:15<07:32,  8.88s/it]

Epoch 48 | train loss 0.0033 | val loss 0.0406 | val acc 98.9500


 50%|█████████████████████                     | 50/100 [07:24<07:21,  8.84s/it]

Epoch 49 | train loss 0.0060 | val loss 0.0392 | val acc 99.0000


 51%|█████████████████████▍                    | 51/100 [07:32<07:10,  8.78s/it]

Epoch 50 | train loss 0.0083 | val loss 0.0385 | val acc 98.9917


 52%|█████████████████████▊                    | 52/100 [07:41<07:00,  8.76s/it]

Epoch 51 | train loss 0.0142 | val loss 0.0388 | val acc 98.9667


 53%|██████████████████████▎                   | 53/100 [07:50<06:50,  8.72s/it]

Epoch 52 | train loss 0.0079 | val loss 0.0391 | val acc 98.9750


 54%|██████████████████████▋                   | 54/100 [07:59<06:41,  8.72s/it]

Epoch 53 | train loss 0.0074 | val loss 0.0372 | val acc 99.0833


 55%|███████████████████████                   | 55/100 [08:07<06:32,  8.73s/it]

Epoch 54 | train loss 0.0091 | val loss 0.0438 | val acc 98.9583


 56%|███████████████████████▌                  | 56/100 [08:16<06:23,  8.71s/it]

Epoch 55 | train loss 0.0214 | val loss 0.0558 | val acc 98.7333


 57%|███████████████████████▉                  | 57/100 [08:25<06:17,  8.77s/it]

Epoch 56 | train loss 0.0154 | val loss 0.0491 | val acc 98.9167


 58%|████████████████████████▎                 | 58/100 [08:35<06:19,  9.04s/it]

Epoch 57 | train loss 0.0040 | val loss 0.0411 | val acc 99.0750


 59%|████████████████████████▊                 | 59/100 [08:43<06:09,  9.02s/it]

Epoch 58 | train loss 0.0048 | val loss 0.0473 | val acc 98.8667


 60%|█████████████████████████▏                | 60/100 [08:53<06:05,  9.14s/it]

Epoch 59 | train loss 0.0070 | val loss 0.0438 | val acc 99.0000


 61%|█████████████████████████▌                | 61/100 [09:02<05:56,  9.15s/it]

Epoch 60 | train loss 0.0022 | val loss 0.0427 | val acc 99.0833


 62%|██████████████████████████                | 62/100 [09:11<05:44,  9.08s/it]

Epoch 61 | train loss 0.0045 | val loss 0.0418 | val acc 99.0583


 63%|██████████████████████████▍               | 63/100 [09:20<05:35,  9.06s/it]

Epoch 62 | train loss 0.0048 | val loss 0.0466 | val acc 99.0167


 64%|██████████████████████████▉               | 64/100 [09:29<05:23,  8.99s/it]

Epoch 63 | train loss 0.0107 | val loss 0.0488 | val acc 98.8583


 65%|███████████████████████████▎              | 65/100 [09:38<05:15,  9.01s/it]

Epoch 64 | train loss 0.0049 | val loss 0.0435 | val acc 98.9250


 66%|███████████████████████████▋              | 66/100 [09:47<05:05,  9.00s/it]

Epoch 65 | train loss 0.0071 | val loss 0.0451 | val acc 98.9667


 67%|████████████████████████████▏             | 67/100 [09:56<04:57,  9.01s/it]

Epoch 66 | train loss 0.0062 | val loss 0.0494 | val acc 98.9333


 68%|████████████████████████████▌             | 68/100 [10:05<04:50,  9.08s/it]

Epoch 67 | train loss 0.0056 | val loss 0.0440 | val acc 98.9667


 69%|████████████████████████████▉             | 69/100 [10:14<04:43,  9.15s/it]

Epoch 68 | train loss 0.0069 | val loss 0.0445 | val acc 99.0000


 70%|█████████████████████████████▍            | 70/100 [10:23<04:32,  9.08s/it]

Epoch 69 | train loss 0.0029 | val loss 0.0416 | val acc 99.0333


 71%|█████████████████████████████▊            | 71/100 [10:32<04:23,  9.09s/it]

Epoch 70 | train loss 0.0045 | val loss 0.0464 | val acc 99.0417


 72%|██████████████████████████████▏           | 72/100 [10:42<04:14,  9.11s/it]

Epoch 71 | train loss 0.0039 | val loss 0.0472 | val acc 99.0500


 73%|██████████████████████████████▋           | 73/100 [10:51<04:08,  9.20s/it]

Epoch 72 | train loss 0.0041 | val loss 0.0507 | val acc 98.9000


 74%|███████████████████████████████           | 74/100 [11:00<03:58,  9.18s/it]

Epoch 73 | train loss 0.0092 | val loss 0.0492 | val acc 98.9417


 75%|███████████████████████████████▌          | 75/100 [11:09<03:50,  9.20s/it]

Epoch 74 | train loss 0.0044 | val loss 0.0475 | val acc 99.0000


 76%|███████████████████████████████▉          | 76/100 [11:18<03:39,  9.15s/it]

Epoch 75 | train loss 0.0057 | val loss 0.0518 | val acc 98.8750


 77%|████████████████████████████████▎         | 77/100 [11:27<03:27,  9.03s/it]

Epoch 76 | train loss 0.0018 | val loss 0.0480 | val acc 98.9917


 78%|████████████████████████████████▊         | 78/100 [11:36<03:17,  8.96s/it]

Epoch 77 | train loss 0.0017 | val loss 0.0468 | val acc 99.0833


 79%|█████████████████████████████████▏        | 79/100 [11:45<03:06,  8.90s/it]

Epoch 78 | train loss 0.0037 | val loss 0.0487 | val acc 98.9833


 80%|█████████████████████████████████▌        | 80/100 [11:54<02:57,  8.86s/it]

Epoch 79 | train loss 0.0007 | val loss 0.0485 | val acc 99.0500


 81%|██████████████████████████████████        | 81/100 [12:02<02:47,  8.83s/it]

Epoch 80 | train loss 0.0008 | val loss 0.0486 | val acc 99.0667


 82%|██████████████████████████████████▍       | 82/100 [12:11<02:38,  8.83s/it]

Epoch 81 | train loss 0.0021 | val loss 0.0496 | val acc 99.0333


 83%|██████████████████████████████████▊       | 83/100 [12:20<02:29,  8.78s/it]

Epoch 82 | train loss 0.0024 | val loss 0.0472 | val acc 99.0333


 84%|███████████████████████████████████▎      | 84/100 [12:29<02:20,  8.77s/it]

Epoch 83 | train loss 0.0035 | val loss 0.0501 | val acc 99.0750


 85%|███████████████████████████████████▋      | 85/100 [12:38<02:12,  8.86s/it]

Epoch 84 | train loss 0.0021 | val loss 0.0448 | val acc 99.0917


 86%|████████████████████████████████████      | 86/100 [12:47<02:06,  9.07s/it]

Epoch 85 | train loss 0.0016 | val loss 0.0463 | val acc 99.1000


 87%|████████████████████████████████████▌     | 87/100 [12:57<01:59,  9.18s/it]

Epoch 86 | train loss 0.0017 | val loss 0.0464 | val acc 99.0833


 88%|████████████████████████████████████▉     | 88/100 [13:06<01:50,  9.22s/it]

Epoch 87 | train loss 0.0026 | val loss 0.0452 | val acc 99.1167


 89%|█████████████████████████████████████▍    | 89/100 [13:15<01:41,  9.26s/it]

Epoch 88 | train loss 0.0010 | val loss 0.0451 | val acc 99.1583


 90%|█████████████████████████████████████▊    | 90/100 [13:25<01:34,  9.44s/it]

Epoch 89 | train loss 0.0011 | val loss 0.0456 | val acc 99.0917


 91%|██████████████████████████████████████▏   | 91/100 [13:34<01:23,  9.27s/it]

Epoch 90 | train loss 0.0006 | val loss 0.0449 | val acc 99.0917


 92%|██████████████████████████████████████▋   | 92/100 [13:44<01:15,  9.40s/it]

Epoch 91 | train loss 0.0003 | val loss 0.0438 | val acc 99.1083


 93%|███████████████████████████████████████   | 93/100 [13:53<01:05,  9.35s/it]

Epoch 92 | train loss 0.0004 | val loss 0.0436 | val acc 99.1333


 94%|███████████████████████████████████████▍  | 94/100 [14:02<00:55,  9.32s/it]

Epoch 93 | train loss 0.0003 | val loss 0.0442 | val acc 99.1167


 95%|███████████████████████████████████████▉  | 95/100 [14:11<00:46,  9.30s/it]

Epoch 94 | train loss 0.0006 | val loss 0.0451 | val acc 99.0833


 96%|████████████████████████████████████████▎ | 96/100 [14:21<00:37,  9.45s/it]

Epoch 95 | train loss 0.0001 | val loss 0.0450 | val acc 99.1000


 97%|████████████████████████████████████████▋ | 97/100 [14:31<00:28,  9.59s/it]

Epoch 96 | train loss 0.0003 | val loss 0.0438 | val acc 99.1167


 98%|█████████████████████████████████████████▏| 98/100 [14:40<00:18,  9.49s/it]

Epoch 97 | train loss 0.0002 | val loss 0.0449 | val acc 99.1250


 99%|█████████████████████████████████████████▌| 99/100 [14:50<00:09,  9.43s/it]

Epoch 98 | train loss 0.0002 | val loss 0.0446 | val acc 99.1167


100%|█████████████████████████████████████████| 100/100 [14:59<00:00,  8.99s/it]

Epoch 99 | train loss 0.0001 | val loss 0.0446 | val acc 99.1250





AttributeError: 'list' object has no attribute 'data'