In [1]:
import os
os.environ["WANDB_API_KEY"] = "dcf9600e0485401cbb0ddbb0f7be1c70f96b32ef"
os.environ["WANDB_MODE"] = "disabled"
import argparse
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipdb
import wandb
from sigma_layer import SigmaLinear, SigmaConv, SigmaView
from utils import get_dataset, gradient_centralization, normalize_along_axis, get_activation_function, compute_SCL_loss
import datetime

In [2]:
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# Training scheme group
method_parser = parser.add_argument_group("Method")
method_parser.add_argument('--method', type=str, default='SIGMA', choices=['SIGMA', 'BP', 'FA'])
method_parser.add_argument('--actfunc', type=str, default='elu', choices=['tanh', 'elu', 'relu'])
method_parser.add_argument('--conv_dim', type=int, default=32, choices=[32, 64])
# Dataset group
dataset_parser = parser.add_argument_group('Dataset')
dataset_parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'CIFAR10'])
dataset_parser.add_argument('--batchsize', type=int, default=64)
dataset_parser.add_argument('--splitratio', type=float, default=0.2)
# Training group # LR, optimizer, weight_decay, momentum
training_parser = parser.add_argument_group('Training')
training_parser.add_argument('--epochs', type=int, default=100)
training_parser.add_argument('--lr', type=float, default=1)
training_parser.add_argument('--optimizer', type=str, default='SGD', choices=['RMSprop', 'Adam', 'SGD'])
# Seed group
seed_parser = parser.add_argument_group('Seed')
seed_parser.add_argument('--seed', type=int, default=2023)
args, _ = parser.parse_known_args()

# Set run_name
run_name = f"{args.dataset}_{args.method}_conv{args.conv_dim}-act{args.actfunc}_{args.optimizer}_{args.lr}_{args.seed}"
time_stamp = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")

# Set wandb
wandb.init(
    project="opt-sigma",
    name=run_name,
    # track hyperparameters and run metadata
    config={
    "algorithm": args.method,
    "architecture": "SimpleCNN",
    "dataset": args.dataset,
    "epochs": args.epochs,
    "lr": args.lr,
    "optimizer": args.optimizer,
    "seed": args.seed,
    "conv_dim": args.conv_dim,
    "actfunc": args.actfunc,
    }
)

print(f"Run name: {run_name}")

# Set seed
torch.manual_seed(args.seed), np.random.seed(args.seed)

# Set device
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

# Get dataset
train_loader, val_loader, test_loader = get_dataset(args)

class SigmaModel_SimpleCNN(nn.Module):
    def __init__(self, args):
        super(SigmaModel_SimpleCNN, self).__init__()
        self.LN1 = torch.nn.LayerNorm((1,28,28), elementwise_affine=False) 
        self.LN2 = torch.nn.LayerNorm((args.conv_dim,14,14), elementwise_affine=False) 
        self.conv1 = SigmaConv(1, args.conv_dim, 3, args)
        self.conv2 = SigmaConv(args.conv_dim, 64, 3, args)
        self.view1 = SigmaView((64, 7, 7), 64 * 7 * 7)
        self.fc1 = SigmaLinear(64 * 7 * 7, 128, args)
        self.fc2 = SigmaLinear(128, 10, args)
                
        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.conv1, self.conv2, self.fc1, self.fc2]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def forward(self, x, detach_grad=False, return_activations=True):   
        x = self.LN1(x)
        a1 = self.conv1(x, detach_grad)
        a2 = self.conv2(a1, detach_grad)
        a2 = self.view1(a2, detach_grad)
        a3 = self.fc1(a2, detach_grad)
        a4 = self.fc2(a3, detach_grad)
        return [a1, a2, a3, a4]
        
    def reverse(self, target, detach_grad=True, return_activations=True):
        if target.shape == torch.Size([10]): 
            target = F.one_hot(target, num_classes=10).float().to(target.device)
        
        b3 = self.fc2.reverse(target, detach_grad)
        b2 = self.fc1.reverse(b3, detach_grad)
        b2 = self.view1.reverse(b2, detach_grad)
        b1 = self.conv2.reverse(b2, detach_grad)
        return [b1, b2, b3, target]


class SigmaLoss(nn.Module):
    def __init__(self, args):
        super(SigmaLoss, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_SCL_loss
        self.method = args.method
        
    def forward(self, activations, signals, target, method="final"):
        """
        Activations and signals are list of activations
        target is a onehot representation of target sigals
        """
        if method == "local":
            loss = list()
            for act, sig in zip(activations[:-1], signals[:-1]):
                loss += [self.local_criteria(act, sig, target)]
            loss += [self.final_criteria(activations[-1], target)]
            return sum(loss), loss[-1].item()
        elif method == "final":
            loss = self.final_criteria(activations[-1], target)
            return loss, loss.item()
        
model = SigmaModel_SimpleCNN(args)
model.to(device)
if args.optimizer == "SGD": 
    forward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
    backward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
elif args.optimizer == "RMSprop": forward_optimizer = optim.RMSprop(model.forward_params, lr=args.lr, weight_decay=0)
elif args.optimizer == "Adam": forward_optimizer = optim.Adam(model.forward_params, lr=args.lr)
criteria = SigmaLoss(args)
    
with torch.no_grad():
    signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), return_activations=True)
    
best_val_loss = float('inf')
for epoch in tqdm(range(args.epochs)):
    
    train_loss, train_counter = 0, 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        if args.method == "SIGMA":
            activations = model(data.to(device), detach_grad=True)
            signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), detach_grad=True)
            loss, loss_item = criteria(activations, signals, target.to(device), method="local")
        elif args.method == "BP":
            activations = model(data.to(device), detach_grad=False)
            loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), backward_optimizer.zero_grad(), loss.backward()
        gradient_centralization(model), forward_optimizer.step(), backward_optimizer.step()
        train_loss += loss_item * len(data)
        train_counter += len(data)

    wandb.log({'train_loss': train_loss / train_counter}, step=epoch)
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        activations = model(data.to(device), detach_grad=True)
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), loss.backward(), forward_optimizer.step()

    # Validation
    val_correct, val_loss, val_counter = 0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            val_counter += len(data)
            if args.method == "SIGMA":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="local")
                
            elif args.method == "BP":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="final")
            prediction = activations[-1].detach()
            _, predicted = torch.max(prediction, 1)
            val_correct += (predicted == target.to(device)).sum().item()
            val_loss += loss_item * len(data)

    wandb.log({'val_loss': val_loss / val_counter, 
               'val_acc': val_correct / val_counter, 
               }, step=epoch)
    
    print(f"""Epoch {epoch} | train loss {train_loss / train_counter:.4f} | val loss {val_loss / val_counter:.4f} | val acc {100 * val_correct / val_counter:.4f}""")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), f'./saved_models/{run_name}-{time_stamp}.pt')
        

# Eval on Test Set by loading the best model 
model.load_state_dict(torch.load(f'./saved_models/{run_name}-{time_stamp}.pt'))
model.eval()
correct, total = 0, 0
test_loss, test_counter = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data.to(device), detach_grad=False)
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == target.to(device)).sum().item()
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        test_loss += loss_item * len(data)
        test_counter += len(data)

wandb.log({'test_loss': test_loss / test_counter,
           'test_acc': 100 * correct / test_counter})

print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / test_counter:.4f}%')

Run name: MNIST_SIGMA_conv32-actelu_SGD_1_2023


  1%|▍                                          | 1/100 [00:08<14:15,  8.64s/it]

Epoch 0 | train loss 3.0550 | val loss 1.5330 | val acc 39.2250


  2%|▊                                          | 2/100 [00:17<14:16,  8.74s/it]

Epoch 1 | train loss 1.3119 | val loss 0.7975 | val acc 73.1333


  3%|█▎                                         | 3/100 [00:26<14:05,  8.72s/it]

Epoch 2 | train loss 0.5279 | val loss 0.1719 | val acc 95.2500


  4%|█▋                                         | 4/100 [00:34<13:57,  8.73s/it]

Epoch 3 | train loss 0.1722 | val loss 0.1242 | val acc 96.4500


  5%|██▏                                        | 5/100 [00:44<14:02,  8.87s/it]

Epoch 4 | train loss 0.1240 | val loss 0.1079 | val acc 96.7500


  6%|██▌                                        | 6/100 [00:53<14:15,  9.10s/it]

Epoch 5 | train loss 0.1062 | val loss 0.0967 | val acc 97.1083


  7%|███                                        | 7/100 [01:03<14:25,  9.30s/it]

Epoch 6 | train loss 0.0884 | val loss 0.0889 | val acc 97.2750


  8%|███▍                                       | 8/100 [01:12<14:06,  9.20s/it]

Epoch 7 | train loss 0.0800 | val loss 0.0876 | val acc 97.2250


  9%|███▊                                       | 9/100 [01:21<13:52,  9.14s/it]

Epoch 8 | train loss 0.0714 | val loss 0.0807 | val acc 97.4417


 10%|████▏                                     | 10/100 [01:30<13:33,  9.04s/it]

Epoch 9 | train loss 0.0738 | val loss 0.0783 | val acc 97.5583


 11%|████▌                                     | 11/100 [01:38<13:19,  8.98s/it]

Epoch 10 | train loss 0.0724 | val loss 0.0751 | val acc 97.6083


 12%|█████                                     | 12/100 [01:47<13:06,  8.94s/it]

Epoch 11 | train loss 0.0794 | val loss 0.0724 | val acc 97.7333


 13%|█████▍                                    | 13/100 [01:56<12:54,  8.90s/it]

Epoch 12 | train loss 0.0660 | val loss 0.0729 | val acc 97.6667


 14%|█████▉                                    | 14/100 [02:05<12:44,  8.89s/it]

Epoch 13 | train loss 0.0719 | val loss 0.0705 | val acc 97.7333


 15%|██████▎                                   | 15/100 [02:14<12:35,  8.89s/it]

Epoch 14 | train loss 0.0567 | val loss 0.0685 | val acc 97.8167


 16%|██████▋                                   | 16/100 [02:23<12:28,  8.91s/it]

Epoch 15 | train loss 0.0729 | val loss 0.0669 | val acc 97.9083


 17%|███████▏                                  | 17/100 [02:32<12:16,  8.88s/it]

Epoch 16 | train loss 0.0626 | val loss 0.0663 | val acc 97.8833


 18%|███████▌                                  | 18/100 [02:41<12:16,  8.98s/it]

Epoch 17 | train loss 0.0679 | val loss 0.0653 | val acc 97.8917


 19%|███████▉                                  | 19/100 [02:50<12:14,  9.07s/it]

Epoch 18 | train loss 0.0610 | val loss 0.0661 | val acc 97.8417


 20%|████████▍                                 | 20/100 [03:00<12:17,  9.22s/it]

Epoch 19 | train loss 0.0667 | val loss 0.0661 | val acc 97.8750


 21%|████████▊                                 | 21/100 [03:09<12:01,  9.13s/it]

Epoch 20 | train loss 0.0574 | val loss 0.0664 | val acc 97.8750


 22%|█████████▏                                | 22/100 [03:18<11:50,  9.11s/it]

Epoch 21 | train loss 0.0620 | val loss 0.0625 | val acc 97.9917


 23%|█████████▋                                | 23/100 [03:27<11:54,  9.28s/it]

Epoch 22 | train loss 0.0479 | val loss 0.0634 | val acc 98.0000


 24%|██████████                                | 24/100 [03:36<11:38,  9.19s/it]

Epoch 23 | train loss 0.0496 | val loss 0.0646 | val acc 98.0333


 25%|██████████▌                               | 25/100 [03:45<11:27,  9.17s/it]

Epoch 24 | train loss 0.0516 | val loss 0.0642 | val acc 97.9833


 26%|██████████▉                               | 26/100 [03:55<11:26,  9.28s/it]

Epoch 25 | train loss 0.0492 | val loss 0.0626 | val acc 97.9750


 27%|███████████▎                              | 27/100 [04:04<11:17,  9.28s/it]

Epoch 26 | train loss 0.0587 | val loss 0.0632 | val acc 97.9833


 28%|███████████▊                              | 28/100 [04:13<11:02,  9.21s/it]

Epoch 27 | train loss 0.0452 | val loss 0.0645 | val acc 98.0000


 29%|████████████▏                             | 29/100 [04:22<10:52,  9.19s/it]

Epoch 28 | train loss 0.0516 | val loss 0.0619 | val acc 98.1250


 30%|████████████▌                             | 30/100 [04:32<10:46,  9.24s/it]

Epoch 29 | train loss 0.0581 | val loss 0.0639 | val acc 98.0083


 31%|█████████████                             | 31/100 [04:41<10:32,  9.17s/it]

Epoch 30 | train loss 0.0463 | val loss 0.0608 | val acc 98.1000


 32%|█████████████▍                            | 32/100 [04:50<10:20,  9.13s/it]

Epoch 31 | train loss 0.0544 | val loss 0.0630 | val acc 98.0167


 33%|█████████████▊                            | 33/100 [04:59<10:16,  9.21s/it]

Epoch 32 | train loss 0.0535 | val loss 0.0586 | val acc 98.2417


 34%|██████████████▎                           | 34/100 [05:08<10:06,  9.20s/it]

Epoch 33 | train loss 0.0464 | val loss 0.0606 | val acc 98.1583


 35%|██████████████▋                           | 35/100 [05:17<09:51,  9.10s/it]

Epoch 34 | train loss 0.0535 | val loss 0.0591 | val acc 98.1083


 36%|███████████████                           | 36/100 [05:26<09:40,  9.07s/it]

Epoch 35 | train loss 0.0480 | val loss 0.0592 | val acc 98.1500


 37%|███████████████▌                          | 37/100 [05:35<09:31,  9.08s/it]

Epoch 36 | train loss 0.0454 | val loss 0.0582 | val acc 98.2083


 38%|███████████████▉                          | 38/100 [05:44<09:22,  9.07s/it]

Epoch 37 | train loss 0.0438 | val loss 0.0596 | val acc 98.0833


 39%|████████████████▍                         | 39/100 [05:53<09:11,  9.04s/it]

Epoch 38 | train loss 0.0526 | val loss 0.0573 | val acc 98.1417


 40%|████████████████▊                         | 40/100 [06:03<09:11,  9.20s/it]

Epoch 39 | train loss 0.0506 | val loss 0.0605 | val acc 98.1500


 41%|█████████████████▏                        | 41/100 [06:12<08:59,  9.15s/it]

Epoch 40 | train loss 0.0528 | val loss 0.0571 | val acc 98.2333


 42%|█████████████████▋                        | 42/100 [06:21<08:44,  9.05s/it]

Epoch 41 | train loss 0.0500 | val loss 0.0580 | val acc 98.2417


 43%|██████████████████                        | 43/100 [06:30<08:30,  8.96s/it]

Epoch 42 | train loss 0.0439 | val loss 0.0587 | val acc 98.1750


 44%|██████████████████▍                       | 44/100 [06:38<08:18,  8.90s/it]

Epoch 43 | train loss 0.0512 | val loss 0.0606 | val acc 98.1417


 45%|██████████████████▉                       | 45/100 [06:47<08:07,  8.86s/it]

Epoch 44 | train loss 0.0467 | val loss 0.0591 | val acc 98.1083


 46%|███████████████████▎                      | 46/100 [06:56<08:00,  8.90s/it]

Epoch 45 | train loss 0.0470 | val loss 0.0595 | val acc 98.2167


 47%|███████████████████▋                      | 47/100 [07:05<07:54,  8.95s/it]

Epoch 46 | train loss 0.0317 | val loss 0.0599 | val acc 98.1917


 48%|████████████████████▏                     | 48/100 [07:14<07:44,  8.93s/it]

Epoch 47 | train loss 0.0459 | val loss 0.0589 | val acc 98.2083


 49%|████████████████████▌                     | 49/100 [07:23<07:36,  8.94s/it]

Epoch 48 | train loss 0.0438 | val loss 0.0584 | val acc 98.0583


 50%|█████████████████████                     | 50/100 [07:32<07:26,  8.93s/it]

Epoch 49 | train loss 0.0468 | val loss 0.0601 | val acc 98.1250


 51%|█████████████████████▍                    | 51/100 [07:41<07:18,  8.96s/it]

Epoch 50 | train loss 0.0422 | val loss 0.0586 | val acc 98.1917


 52%|█████████████████████▊                    | 52/100 [07:50<07:13,  9.02s/it]

Epoch 51 | train loss 0.0406 | val loss 0.0577 | val acc 98.2000


 53%|██████████████████████▎                   | 53/100 [07:59<07:05,  9.06s/it]

Epoch 52 | train loss 0.0434 | val loss 0.0570 | val acc 98.2167


 54%|██████████████████████▋                   | 54/100 [08:08<06:54,  9.01s/it]

Epoch 53 | train loss 0.0442 | val loss 0.0557 | val acc 98.2667


 55%|███████████████████████                   | 55/100 [08:17<06:45,  9.00s/it]

Epoch 54 | train loss 0.0491 | val loss 0.0564 | val acc 98.1500


 56%|███████████████████████▌                  | 56/100 [08:26<06:37,  9.04s/it]

Epoch 55 | train loss 0.0426 | val loss 0.0567 | val acc 98.1750


 57%|███████████████████████▉                  | 57/100 [08:35<06:30,  9.09s/it]

Epoch 56 | train loss 0.0408 | val loss 0.0550 | val acc 98.2583


 58%|████████████████████████▎                 | 58/100 [08:45<06:23,  9.12s/it]

Epoch 57 | train loss 0.0444 | val loss 0.0534 | val acc 98.3083


 59%|████████████████████████▊                 | 59/100 [08:54<06:14,  9.14s/it]

Epoch 58 | train loss 0.0339 | val loss 0.0551 | val acc 98.2833


 60%|█████████████████████████▏                | 60/100 [09:03<06:07,  9.18s/it]

Epoch 59 | train loss 0.0388 | val loss 0.0543 | val acc 98.2333


 61%|█████████████████████████▌                | 61/100 [09:12<05:57,  9.16s/it]

Epoch 60 | train loss 0.0435 | val loss 0.0533 | val acc 98.3250


 62%|██████████████████████████                | 62/100 [09:22<05:52,  9.27s/it]

Epoch 61 | train loss 0.0396 | val loss 0.0555 | val acc 98.2000


 63%|██████████████████████████▍               | 63/100 [09:32<05:49,  9.43s/it]

Epoch 62 | train loss 0.0359 | val loss 0.0534 | val acc 98.3250


 64%|██████████████████████████▉               | 64/100 [09:41<05:36,  9.36s/it]

Epoch 63 | train loss 0.0345 | val loss 0.0547 | val acc 98.2833


 65%|███████████████████████████▎              | 65/100 [09:50<05:27,  9.37s/it]

Epoch 64 | train loss 0.0419 | val loss 0.0535 | val acc 98.2583


 66%|███████████████████████████▋              | 66/100 [10:00<05:20,  9.42s/it]

Epoch 65 | train loss 0.0381 | val loss 0.0531 | val acc 98.3000


 67%|████████████████████████████▏             | 67/100 [10:09<05:06,  9.30s/it]

Epoch 66 | train loss 0.0360 | val loss 0.0532 | val acc 98.2417


 68%|████████████████████████████▌             | 68/100 [10:17<04:51,  9.10s/it]

Epoch 67 | train loss 0.0411 | val loss 0.0544 | val acc 98.2667


 69%|████████████████████████████▉             | 69/100 [10:26<04:39,  9.00s/it]

Epoch 68 | train loss 0.0440 | val loss 0.0525 | val acc 98.3417


 70%|█████████████████████████████▍            | 70/100 [10:36<04:35,  9.17s/it]

Epoch 69 | train loss 0.0384 | val loss 0.0521 | val acc 98.3250


 71%|█████████████████████████████▊            | 71/100 [10:45<04:27,  9.21s/it]

Epoch 70 | train loss 0.0361 | val loss 0.0530 | val acc 98.3500


 72%|██████████████████████████████▏           | 72/100 [10:54<04:19,  9.27s/it]

Epoch 71 | train loss 0.0350 | val loss 0.0538 | val acc 98.2167


 73%|██████████████████████████████▋           | 73/100 [11:04<04:10,  9.28s/it]

Epoch 72 | train loss 0.0330 | val loss 0.0524 | val acc 98.3333


 74%|███████████████████████████████           | 74/100 [11:13<04:00,  9.26s/it]

Epoch 73 | train loss 0.0459 | val loss 0.0512 | val acc 98.3083


 75%|███████████████████████████████▌          | 75/100 [11:22<03:51,  9.26s/it]

Epoch 74 | train loss 0.0397 | val loss 0.0531 | val acc 98.2667


 76%|███████████████████████████████▉          | 76/100 [11:31<03:41,  9.21s/it]

Epoch 75 | train loss 0.0373 | val loss 0.0551 | val acc 98.2417


 77%|████████████████████████████████▎         | 77/100 [11:40<03:30,  9.17s/it]

Epoch 76 | train loss 0.0345 | val loss 0.0546 | val acc 98.2333


 78%|████████████████████████████████▊         | 78/100 [11:50<03:22,  9.20s/it]

Epoch 77 | train loss 0.0365 | val loss 0.0536 | val acc 98.3000


 79%|█████████████████████████████████▏        | 79/100 [11:59<03:11,  9.14s/it]

Epoch 78 | train loss 0.0353 | val loss 0.0529 | val acc 98.3583


 80%|█████████████████████████████████▌        | 80/100 [12:07<03:00,  9.04s/it]

Epoch 79 | train loss 0.0298 | val loss 0.0526 | val acc 98.3833


 81%|██████████████████████████████████        | 81/100 [12:16<02:50,  8.97s/it]

Epoch 80 | train loss 0.0369 | val loss 0.0503 | val acc 98.4083


 82%|██████████████████████████████████▍       | 82/100 [12:26<02:43,  9.09s/it]

Epoch 81 | train loss 0.0429 | val loss 0.0530 | val acc 98.2917


 83%|██████████████████████████████████▊       | 83/100 [12:35<02:36,  9.23s/it]

Epoch 82 | train loss 0.0325 | val loss 0.0520 | val acc 98.4000


 84%|███████████████████████████████████▎      | 84/100 [12:44<02:26,  9.17s/it]

Epoch 83 | train loss 0.0391 | val loss 0.0520 | val acc 98.3583


 85%|███████████████████████████████████▋      | 85/100 [12:53<02:18,  9.21s/it]

Epoch 84 | train loss 0.0306 | val loss 0.0507 | val acc 98.4250


 86%|████████████████████████████████████      | 86/100 [13:03<02:09,  9.24s/it]

Epoch 85 | train loss 0.0385 | val loss 0.0501 | val acc 98.3667


 87%|████████████████████████████████████▌     | 87/100 [13:12<02:00,  9.25s/it]

Epoch 86 | train loss 0.0361 | val loss 0.0489 | val acc 98.4833


 88%|████████████████████████████████████▉     | 88/100 [13:21<01:51,  9.25s/it]

Epoch 87 | train loss 0.0428 | val loss 0.0485 | val acc 98.4917


 89%|█████████████████████████████████████▍    | 89/100 [13:30<01:40,  9.18s/it]

Epoch 88 | train loss 0.0303 | val loss 0.0506 | val acc 98.4667


 90%|█████████████████████████████████████▊    | 90/100 [13:40<01:31,  9.19s/it]

Epoch 89 | train loss 0.0390 | val loss 0.0502 | val acc 98.4417


 91%|██████████████████████████████████████▏   | 91/100 [13:49<01:23,  9.23s/it]

Epoch 90 | train loss 0.0335 | val loss 0.0502 | val acc 98.4333


 92%|██████████████████████████████████████▋   | 92/100 [13:58<01:13,  9.23s/it]

Epoch 91 | train loss 0.0368 | val loss 0.0490 | val acc 98.4417


 93%|███████████████████████████████████████   | 93/100 [14:07<01:03,  9.10s/it]

Epoch 92 | train loss 0.0384 | val loss 0.0490 | val acc 98.4500


 94%|███████████████████████████████████████▍  | 94/100 [14:16<00:54,  9.05s/it]

Epoch 93 | train loss 0.0398 | val loss 0.0504 | val acc 98.3167


 95%|███████████████████████████████████████▉  | 95/100 [14:25<00:44,  8.98s/it]

Epoch 94 | train loss 0.0331 | val loss 0.0502 | val acc 98.3750


 96%|████████████████████████████████████████▎ | 96/100 [14:33<00:35,  8.93s/it]

Epoch 95 | train loss 0.0341 | val loss 0.0510 | val acc 98.3917


 97%|████████████████████████████████████████▋ | 97/100 [14:42<00:26,  8.94s/it]

Epoch 96 | train loss 0.0313 | val loss 0.0498 | val acc 98.4500


 98%|█████████████████████████████████████████▏| 98/100 [14:52<00:18,  9.01s/it]

Epoch 97 | train loss 0.0295 | val loss 0.0516 | val acc 98.4167


 99%|█████████████████████████████████████████▌| 99/100 [15:01<00:09,  9.09s/it]

Epoch 98 | train loss 0.0357 | val loss 0.0493 | val acc 98.4167


100%|█████████████████████████████████████████| 100/100 [15:10<00:00,  9.10s/it]

Epoch 99 | train loss 0.0271 | val loss 0.0515 | val acc 98.4417





AttributeError: 'list' object has no attribute 'data'