In [1]:
import os
os.environ["WANDB_API_KEY"] = "dcf9600e0485401cbb0ddbb0f7be1c70f96b32ef"
os.environ["WANDB_MODE"] = "disabled"
import argparse
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipdb
import wandb
from sigma_layer import SigmaLinear, SigmaConv, SigmaView
from utils import get_dataset, gradient_centralization, normalize_along_axis, get_activation_function, compute_SCL_loss
import datetime

In [4]:
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# Training scheme group
method_parser = parser.add_argument_group("Method")
method_parser.add_argument('--method', type=str, default='BP', choices=['SIGMA', 'BP', 'FA'])
method_parser.add_argument('--architecture', type=str, default='lenet', choices=['lenet', 'vgg'])
method_parser.add_argument('--actfunc', type=str, default='elu', choices=['tanh', 'elu', 'relu'])
method_parser.add_argument('--conv_dim', type=int, default=32, choices=[32, 64])
# Dataset group
dataset_parser = parser.add_argument_group('Dataset')
dataset_parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['MNIST', 'CIFAR10'])
dataset_parser.add_argument('--batchsize', type=int, default=64)
dataset_parser.add_argument('--splitratio', type=float, default=0.2)
# Training group # LR, optimizer, weight_decay, momentum
training_parser = parser.add_argument_group('Training')
training_parser.add_argument('--epochs', type=int, default=100)
training_parser.add_argument('--lr', type=float, default=0.1)
training_parser.add_argument('--optimizer', type=str, default='SGD', choices=['RMSprop', 'Adam', 'SGD'])
# Seed group
seed_parser = parser.add_argument_group('Seed')
seed_parser.add_argument('--seed', type=int, default=2023)
args, _ = parser.parse_known_args()

# Set run_name
run_name = f"{args.dataset}_{args.method}_conv{args.conv_dim}-act{args.actfunc}_{args.optimizer}_{args.lr}_{args.seed}"
time_stamp = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")

# Set wandb
wandb.init(
    project="opt-sigma",
    name=run_name,
    # track hyperparameters and run metadata
    config={
    "algorithm": args.method,
    "architecture": "SimpleCNN",
    "dataset": args.dataset,
    "epochs": args.epochs,
    "lr": args.lr,
    "optimizer": args.optimizer,
    "seed": args.seed,
    "conv_dim": args.conv_dim,
    "actfunc": args.actfunc,
    }
)

print(f"Run name: {run_name}")

# Set seed
torch.manual_seed(args.seed), np.random.seed(args.seed)

# Set device
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

# Get dataset
train_loader, val_loader, test_loader = get_dataset(args)

class SigmaModel_SimpleCNN(nn.Module):
    def __init__(self, args):
        super(SigmaModel_SimpleCNN, self).__init__()
        if args.dataset == "CIFAR10":
            self.LN1 = torch.nn.LayerNorm((3,32,32), elementwise_affine=False) 
            self.conv1 = SigmaConv(args, 3, 32, 5, 3, 2)
            self.conv2 = SigmaConv(args, 32, 64, 3, 3, 2)
            self.view1 = SigmaView((64, 8, 8), 64 * 8 * 8)
            self.fc1 = SigmaLinear(args, 64 * 8 * 8, 128)
            self.fc2 = SigmaLinear(args, 128, 10)
            
        elif args.dataset == "MNIST":
            if args.architecture == "lenet":
                self.LN1 = torch.nn.LayerNorm((1,28,28), elementwise_affine=False) 
                self.conv1 = SigmaConv(args, 1, 32, 5, 2, 3)
                self.conv2 = SigmaConv(args, 32, 64, 5, 2, 3)
                self.view1 = SigmaView((64, 4, 4), 64 * 4 * 4)
                self.fc1 = SigmaLinear(args, 64 * 4 * 4, 512)
                self.fc2 = SigmaLinear(args, 512, 10)
                
        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.conv1, self.conv2, self.fc1, self.fc2]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def forward(self, x, detach_grad=False, return_activations=True):   
        # x = self.LN1(x)
        a1 = self.conv1(x, detach_grad)
        a2 = self.conv2(a1, detach_grad)
        a2 = self.view1(a2, detach_grad)
        a3 = self.fc1(a2, detach_grad)
        a4 = self.fc2(a3, detach_grad)
        return [a1, a2, a3, a4]
        
    def reverse(self, target, detach_grad=True, return_activations=True):
        if target.shape == torch.Size([10]): 
            target = F.one_hot(target, num_classes=10).float().to(target.device)
        b3 = self.fc2.reverse(target, detach_grad)
        b2 = self.fc1.reverse(b3, detach_grad)
        b2 = self.view1.reverse(b2, detach_grad)
        b1 = self.conv2.reverse(b2, detach_grad)
        return [b1, b2, b3, target]


class SigmaLoss(nn.Module):
    def __init__(self, args):
        super(SigmaLoss, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_SCL_loss
        self.method = args.method
        
    def forward(self, activations, signals, target, method="final"):
        if method == "local":
            loss = list()
            for act, sig in zip(activations[:-1], signals[:-1]):
                loss += [self.local_criteria(act, sig, target)]
            loss += [self.final_criteria(activations[-1], target)]
            return sum(loss), loss[-1].item()
        elif method == "final":
            loss = self.final_criteria(activations[-1], target) * 1.1 - 0.1
            return loss, loss.item()
        
model = SigmaModel_SimpleCNN(args)
model.to(device)
if args.optimizer == "SGD": 
    forward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
    backward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
elif args.optimizer == "RMSprop": forward_optimizer = optim.RMSprop(model.forward_params, lr=args.lr, weight_decay=0)
elif args.optimizer == "Adam": forward_optimizer = optim.Adam(model.forward_params, lr=args.lr)
criteria = SigmaLoss(args)
    
with torch.no_grad():
    signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), return_activations=True)
    
best_val_loss = float('inf')

print("No LN1")
for epoch in tqdm(range(args.epochs)):
    
    train_loss, train_counter = 0, 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        if args.method == "SIGMA":
            activations = model(data.to(device), detach_grad=True)
            signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), detach_grad=True)
            loss, loss_item = criteria(activations, signals, target.to(device), method="local")
        elif args.method == "BP":
            activations = model(data.to(device), detach_grad=False)
            loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), backward_optimizer.zero_grad(), loss.backward()
        gradient_centralization(model), forward_optimizer.step(), backward_optimizer.step()
        train_loss += loss_item * len(data)
        train_counter += len(data)

    wandb.log({'train_loss': train_loss / train_counter}, step=epoch)
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        activations = model(data.to(device), detach_grad=True)
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), loss.backward(), forward_optimizer.step()

    # Validation
    val_correct, val_loss, val_counter = 0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            val_counter += len(data)
            if args.method == "SIGMA":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="local")
                
            elif args.method == "BP":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="final")
            prediction = activations[-1].detach()
            _, predicted = torch.max(prediction, 1)
            val_correct += (predicted == target.to(device)).sum().item()
            val_loss += loss_item * len(data)

    wandb.log({'val_loss': val_loss / val_counter, 
               'val_acc': val_correct / val_counter, 
               }, step=epoch)
    
    print(f"""Epoch {epoch} | train loss {train_loss / train_counter:.4f} | val loss {val_loss / val_counter:.4f} | val acc {100 * val_correct / val_counter:.4f}""")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), f'./saved_models/{run_name}-{time_stamp}.pt')
        

# Eval on Test Set by loading the best model 
model.load_state_dict(torch.load(f'./saved_models/{run_name}-{time_stamp}.pt'))
model.eval()
correct, total = 0, 0
test_loss, test_counter = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data.to(device), detach_grad=False)
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == target.to(device)).sum().item()
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        test_loss += loss_item * len(data)
        test_counter += len(data)

wandb.log({'test_loss': test_loss / test_counter,
           'test_acc': 100 * correct / test_counter})

print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / test_counter:.4f}%')

Run name: CIFAR10_BP_conv32-actelu_SGD_0.1_2023
Files already downloaded and verified
Files already downloaded and verified
No LN1


  1%|          | 1/100 [00:12<20:54, 12.67s/it]

Epoch 0 | train loss 1.8898 | val loss 1.5310 | val acc 48.4800


  2%|▏         | 2/100 [00:25<20:54, 12.80s/it]

Epoch 1 | train loss 1.5271 | val loss 1.3117 | val acc 54.6800


  3%|▎         | 3/100 [00:38<20:47, 12.86s/it]

Epoch 2 | train loss 1.4092 | val loss 1.1607 | val acc 60.3000


  4%|▍         | 4/100 [00:51<20:37, 12.89s/it]

Epoch 3 | train loss 1.2614 | val loss 1.0909 | val acc 62.8100


  5%|▌         | 5/100 [01:04<20:29, 12.94s/it]

Epoch 4 | train loss 1.2005 | val loss 1.0791 | val acc 63.2200


  6%|▌         | 6/100 [01:18<20:38, 13.17s/it]

Epoch 5 | train loss 1.1346 | val loss 0.9725 | val acc 66.0200


  7%|▋         | 7/100 [01:31<20:37, 13.31s/it]

Epoch 6 | train loss 1.0556 | val loss 0.9569 | val acc 67.2600


  8%|▊         | 8/100 [01:44<20:11, 13.17s/it]

Epoch 7 | train loss 0.9892 | val loss 0.9566 | val acc 67.1800


  9%|▉         | 9/100 [01:58<20:14, 13.35s/it]

Epoch 8 | train loss 0.9044 | val loss 0.8862 | val acc 69.7000


 10%|█         | 10/100 [02:14<21:08, 14.10s/it]

Epoch 9 | train loss 0.9090 | val loss 0.8890 | val acc 69.4200


 11%|█         | 11/100 [02:28<20:56, 14.12s/it]

Epoch 10 | train loss 0.8721 | val loss 0.8481 | val acc 70.8100


 12%|█▏        | 12/100 [02:42<20:57, 14.29s/it]

Epoch 11 | train loss 0.8437 | val loss 0.8415 | val acc 70.6700


 13%|█▎        | 13/100 [02:57<20:52, 14.40s/it]

Epoch 12 | train loss 0.8228 | val loss 0.8547 | val acc 70.2200


 14%|█▍        | 14/100 [03:11<20:38, 14.40s/it]

Epoch 13 | train loss 0.7769 | val loss 0.8380 | val acc 71.4100


 15%|█▌        | 15/100 [03:26<20:22, 14.38s/it]

Epoch 14 | train loss 0.7387 | val loss 0.8323 | val acc 71.5800


 16%|█▌        | 16/100 [03:40<20:01, 14.31s/it]

Epoch 15 | train loss 0.7245 | val loss 0.8205 | val acc 72.9400


 17%|█▋        | 17/100 [03:54<19:48, 14.32s/it]

Epoch 16 | train loss 0.6812 | val loss 0.8220 | val acc 72.3900


 18%|█▊        | 18/100 [04:08<19:25, 14.21s/it]

Epoch 17 | train loss 0.6721 | val loss 0.8646 | val acc 71.4900


 19%|█▉        | 19/100 [04:23<19:14, 14.25s/it]

Epoch 18 | train loss 0.6454 | val loss 0.8433 | val acc 72.8900


 20%|██        | 20/100 [04:37<19:10, 14.38s/it]

Epoch 19 | train loss 0.6111 | val loss 0.8839 | val acc 71.9300


 21%|██        | 21/100 [04:52<18:54, 14.36s/it]

Epoch 20 | train loss 0.6514 | val loss 0.8576 | val acc 72.3800


 22%|██▏       | 22/100 [05:06<18:35, 14.31s/it]

Epoch 21 | train loss 0.6229 | val loss 0.8486 | val acc 72.3500


 23%|██▎       | 23/100 [05:20<18:20, 14.29s/it]

Epoch 22 | train loss 0.5798 | val loss 0.8551 | val acc 73.0600


 24%|██▍       | 24/100 [05:34<18:07, 14.31s/it]

Epoch 23 | train loss 0.5475 | val loss 0.8895 | val acc 72.7600


 25%|██▌       | 25/100 [05:49<18:06, 14.48s/it]

Epoch 24 | train loss 0.5465 | val loss 0.8885 | val acc 72.8900


 26%|██▌       | 26/100 [06:04<17:47, 14.43s/it]

Epoch 25 | train loss 0.5365 | val loss 0.8912 | val acc 72.7900


 27%|██▋       | 27/100 [06:18<17:35, 14.45s/it]

Epoch 26 | train loss 0.4974 | val loss 0.9245 | val acc 72.5200


 28%|██▊       | 28/100 [06:32<17:14, 14.36s/it]

Epoch 27 | train loss 0.5165 | val loss 0.9265 | val acc 72.9900


 29%|██▉       | 29/100 [06:47<16:59, 14.36s/it]

Epoch 28 | train loss 0.4736 | val loss 0.9336 | val acc 72.1900


 30%|███       | 30/100 [07:01<16:52, 14.46s/it]

Epoch 29 | train loss 0.4761 | val loss 0.9684 | val acc 72.5900


 31%|███       | 31/100 [07:16<16:36, 14.44s/it]

Epoch 30 | train loss 0.4795 | val loss 0.9327 | val acc 72.2000


 32%|███▏      | 32/100 [07:30<16:23, 14.47s/it]

Epoch 31 | train loss 0.4386 | val loss 0.9535 | val acc 73.4900


 33%|███▎      | 33/100 [07:45<16:07, 14.44s/it]

Epoch 32 | train loss 0.4182 | val loss 0.9632 | val acc 72.7000


 34%|███▍      | 34/100 [07:59<15:57, 14.52s/it]

Epoch 33 | train loss 0.4141 | val loss 1.0401 | val acc 72.1500


 35%|███▌      | 35/100 [08:14<15:42, 14.49s/it]

Epoch 34 | train loss 0.4231 | val loss 1.0023 | val acc 72.2600


 36%|███▌      | 36/100 [08:28<15:30, 14.54s/it]

Epoch 35 | train loss 0.3863 | val loss 0.9870 | val acc 73.2100


 37%|███▋      | 37/100 [08:43<15:10, 14.45s/it]

Epoch 36 | train loss 0.3605 | val loss 1.0002 | val acc 72.8000


 38%|███▊      | 38/100 [08:57<14:57, 14.47s/it]

Epoch 37 | train loss 0.3785 | val loss 1.0499 | val acc 72.3400


 39%|███▉      | 39/100 [09:11<14:38, 14.40s/it]

Epoch 38 | train loss 0.3984 | val loss 1.0936 | val acc 73.1900


 40%|████      | 40/100 [09:26<14:21, 14.36s/it]

Epoch 39 | train loss 0.4028 | val loss 1.0612 | val acc 72.3700


 41%|████      | 41/100 [09:40<14:04, 14.31s/it]

Epoch 40 | train loss 0.3734 | val loss 1.1078 | val acc 71.7300


 42%|████▏     | 42/100 [09:54<13:56, 14.42s/it]

Epoch 41 | train loss 0.3899 | val loss 1.1234 | val acc 72.3100


 43%|████▎     | 43/100 [10:09<13:35, 14.31s/it]

Epoch 42 | train loss 0.3677 | val loss 1.1055 | val acc 72.1500


 44%|████▍     | 44/100 [10:23<13:28, 14.44s/it]

Epoch 43 | train loss 0.3674 | val loss 1.1164 | val acc 72.6900


 45%|████▌     | 45/100 [10:38<13:10, 14.38s/it]

Epoch 44 | train loss 0.3705 | val loss 1.1481 | val acc 72.0700


 46%|████▌     | 46/100 [10:53<13:09, 14.61s/it]

Epoch 45 | train loss 0.3484 | val loss 1.1976 | val acc 73.1800


 47%|████▋     | 47/100 [11:09<13:21, 15.12s/it]

Epoch 46 | train loss 0.3878 | val loss 1.1826 | val acc 72.5600


 48%|████▊     | 48/100 [11:24<13:12, 15.24s/it]

Epoch 47 | train loss 0.3639 | val loss 1.1480 | val acc 72.7500


 49%|████▉     | 49/100 [11:39<12:53, 15.16s/it]

Epoch 48 | train loss 0.3583 | val loss 1.1440 | val acc 70.9100


 50%|█████     | 50/100 [11:54<12:21, 14.84s/it]

Epoch 49 | train loss 0.3363 | val loss 1.1799 | val acc 71.5900


 51%|█████     | 51/100 [12:08<11:57, 14.65s/it]

Epoch 50 | train loss 0.3429 | val loss 1.2244 | val acc 71.7800


 52%|█████▏    | 52/100 [12:22<11:34, 14.47s/it]

Epoch 51 | train loss 0.3060 | val loss 1.2128 | val acc 72.4300


 53%|█████▎    | 53/100 [12:36<11:19, 14.46s/it]

Epoch 52 | train loss 0.3098 | val loss 1.3544 | val acc 71.7400


 54%|█████▍    | 54/100 [12:51<11:08, 14.53s/it]

Epoch 53 | train loss 0.3739 | val loss 1.1932 | val acc 72.2300


 55%|█████▌    | 55/100 [13:06<11:03, 14.75s/it]

Epoch 54 | train loss 0.3032 | val loss 1.2948 | val acc 71.4700


 56%|█████▌    | 56/100 [13:21<10:54, 14.89s/it]

Epoch 55 | train loss 0.3111 | val loss 1.3468 | val acc 71.6600


 57%|█████▋    | 57/100 [13:36<10:42, 14.94s/it]

Epoch 56 | train loss 0.3809 | val loss 1.2590 | val acc 71.2500


 58%|█████▊    | 58/100 [13:51<10:25, 14.89s/it]

Epoch 57 | train loss 0.3513 | val loss 1.3611 | val acc 70.8500


 59%|█████▉    | 59/100 [14:06<10:05, 14.78s/it]

Epoch 58 | train loss 0.3781 | val loss 1.3824 | val acc 69.4000


 60%|██████    | 60/100 [14:20<09:47, 14.70s/it]

Epoch 59 | train loss 0.3925 | val loss 1.3514 | val acc 71.2000


 61%|██████    | 61/100 [14:35<09:29, 14.61s/it]

Epoch 60 | train loss 0.3551 | val loss 1.4081 | val acc 72.0000


 62%|██████▏   | 62/100 [14:50<09:19, 14.72s/it]

Epoch 61 | train loss 0.3775 | val loss 1.3295 | val acc 71.3200


 63%|██████▎   | 63/100 [15:04<09:05, 14.75s/it]

Epoch 62 | train loss 0.3255 | val loss 1.4525 | val acc 71.7400


 64%|██████▍   | 64/100 [15:19<08:49, 14.71s/it]

Epoch 63 | train loss 0.3432 | val loss 1.4079 | val acc 70.2300


 65%|██████▌   | 65/100 [15:34<08:35, 14.73s/it]

Epoch 64 | train loss 0.2876 | val loss 1.4095 | val acc 70.2300


 66%|██████▌   | 66/100 [15:48<08:18, 14.66s/it]

Epoch 65 | train loss 0.3291 | val loss 1.4338 | val acc 69.3200


 67%|██████▋   | 67/100 [16:03<08:02, 14.63s/it]

Epoch 66 | train loss 0.3286 | val loss 1.4377 | val acc 70.5100


 68%|██████▊   | 68/100 [16:18<07:50, 14.71s/it]

Epoch 67 | train loss 0.3532 | val loss 1.4916 | val acc 71.2200


 69%|██████▉   | 69/100 [16:33<07:42, 14.92s/it]

Epoch 68 | train loss 0.3620 | val loss 1.3555 | val acc 71.0200


 70%|███████   | 70/100 [16:49<07:32, 15.10s/it]

Epoch 69 | train loss 0.3269 | val loss 1.3963 | val acc 69.2900


 71%|███████   | 71/100 [17:04<07:17, 15.08s/it]

Epoch 70 | train loss 0.3638 | val loss 1.3363 | val acc 70.4400


 72%|███████▏  | 72/100 [17:19<07:00, 15.00s/it]

Epoch 71 | train loss 0.3617 | val loss 1.5237 | val acc 69.5800


 73%|███████▎  | 73/100 [17:34<06:45, 15.02s/it]

Epoch 72 | train loss 0.3229 | val loss 1.5514 | val acc 71.5100


 74%|███████▍  | 74/100 [17:49<06:31, 15.06s/it]

Epoch 73 | train loss 0.3923 | val loss 1.5439 | val acc 69.5700


 75%|███████▌  | 75/100 [18:04<06:15, 15.02s/it]

Epoch 74 | train loss 0.3984 | val loss 1.5132 | val acc 70.7900


 76%|███████▌  | 76/100 [18:19<05:59, 14.97s/it]

Epoch 75 | train loss 0.4439 | val loss 1.6208 | val acc 70.1600


 77%|███████▋  | 77/100 [18:33<05:42, 14.87s/it]

Epoch 76 | train loss 0.3768 | val loss 1.6575 | val acc 68.9200


 78%|███████▊  | 78/100 [18:48<05:26, 14.83s/it]

Epoch 77 | train loss 0.3617 | val loss 1.5366 | val acc 69.8400


 79%|███████▉  | 79/100 [19:03<05:11, 14.83s/it]

Epoch 78 | train loss 0.3483 | val loss 1.6930 | val acc 69.8700


 80%|████████  | 80/100 [19:17<04:54, 14.74s/it]

Epoch 79 | train loss 0.3686 | val loss 1.6116 | val acc 69.3100


 81%|████████  | 81/100 [19:32<04:40, 14.78s/it]

Epoch 80 | train loss 0.3276 | val loss 1.6313 | val acc 70.4600


 82%|████████▏ | 82/100 [19:48<04:29, 14.98s/it]

Epoch 81 | train loss 0.4572 | val loss 1.7381 | val acc 70.7400


 83%|████████▎ | 83/100 [20:03<04:15, 15.05s/it]

Epoch 82 | train loss 0.4274 | val loss 1.6243 | val acc 68.7100


 84%|████████▍ | 84/100 [20:18<04:00, 15.02s/it]

Epoch 83 | train loss 0.3452 | val loss 1.6536 | val acc 69.6900


 85%|████████▌ | 85/100 [20:33<03:44, 14.96s/it]

Epoch 84 | train loss 0.3872 | val loss 1.6216 | val acc 69.5600


 86%|████████▌ | 86/100 [20:48<03:30, 15.04s/it]

Epoch 85 | train loss 0.3893 | val loss 1.6706 | val acc 70.0900


 87%|████████▋ | 87/100 [21:04<03:18, 15.24s/it]

Epoch 86 | train loss 0.3876 | val loss 1.7275 | val acc 69.4500


 88%|████████▊ | 88/100 [21:19<03:02, 15.20s/it]

Epoch 87 | train loss 0.3364 | val loss 1.6980 | val acc 70.5700


 89%|████████▉ | 89/100 [21:33<02:45, 15.05s/it]

Epoch 88 | train loss 0.3802 | val loss 1.5858 | val acc 68.5000


 90%|█████████ | 90/100 [21:49<02:30, 15.09s/it]

Epoch 89 | train loss 0.3998 | val loss 1.7555 | val acc 70.1000


 91%|█████████ | 91/100 [22:03<02:15, 15.02s/it]

Epoch 90 | train loss 0.4324 | val loss 1.7687 | val acc 70.6900


 92%|█████████▏| 92/100 [22:18<01:59, 14.93s/it]

Epoch 91 | train loss 0.4337 | val loss 1.7741 | val acc 70.5500


 93%|█████████▎| 93/100 [22:33<01:44, 14.92s/it]

Epoch 92 | train loss 0.4292 | val loss 1.7107 | val acc 69.9100


 94%|█████████▍| 94/100 [22:48<01:29, 14.89s/it]

Epoch 93 | train loss 0.3374 | val loss 1.7526 | val acc 68.0500


 95%|█████████▌| 95/100 [23:03<01:14, 14.85s/it]

Epoch 94 | train loss 0.4699 | val loss 1.7082 | val acc 68.0000


 96%|█████████▌| 96/100 [23:18<00:59, 14.92s/it]

Epoch 95 | train loss 0.4489 | val loss 1.8331 | val acc 69.2300


 97%|█████████▋| 97/100 [23:33<00:44, 14.94s/it]

Epoch 96 | train loss 0.4274 | val loss 1.8883 | val acc 69.3100


 98%|█████████▊| 98/100 [23:47<00:29, 14.89s/it]

Epoch 97 | train loss 0.4777 | val loss 1.6849 | val acc 68.6700


 99%|█████████▉| 99/100 [24:02<00:14, 14.84s/it]

Epoch 98 | train loss 0.4140 | val loss 1.9685 | val acc 69.8800


100%|██████████| 100/100 [24:17<00:00, 14.58s/it]

Epoch 99 | train loss 0.3895 | val loss 1.7476 | val acc 70.6200





AttributeError: 'list' object has no attribute 'data'