In [1]:
import os
os.environ["WANDB_API_KEY"] = "dcf9600e0485401cbb0ddbb0f7be1c70f96b32ef"
os.environ["WANDB_MODE"] = "disabled"
import argparse
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipdb
import wandb
from sigma_layer import SigmaLinear, SigmaConv, SigmaView
from utils import get_dataset, gradient_centralization, normalize_along_axis, get_activation_function, compute_SCL_loss
import datetime

In [23]:
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# Training scheme group
method_parser = parser.add_argument_group("Method")
method_parser.add_argument('--method', type=str, default='SIGMA', choices=['SIGMA', 'BP', 'FA'])
method_parser.add_argument('--architecture', type=str, default='lenet', choices=['lenet', 'vgg'])
method_parser.add_argument('--actfunc', type=str, default='elu', choices=['tanh', 'elu', 'relu'])
method_parser.add_argument('--conv_dim', type=int, default=32, choices=[32, 64])
# Dataset group
dataset_parser = parser.add_argument_group('Dataset')
dataset_parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['MNIST', 'CIFAR10'])
dataset_parser.add_argument('--batchsize', type=int, default=64)
dataset_parser.add_argument('--splitratio', type=float, default=0.2)
# Training group # LR, optimizer, weight_decay, momentum
training_parser = parser.add_argument_group('Training')
training_parser.add_argument('--epochs', type=int, default=100)
training_parser.add_argument('--lr', type=float, default=1)
training_parser.add_argument('--optimizer', type=str, default='SGD', choices=['RMSprop', 'Adam', 'SGD'])
# Seed group
seed_parser = parser.add_argument_group('Seed')
seed_parser.add_argument('--seed', type=int, default=2023)
args, _ = parser.parse_known_args()

# Set run_name
run_name = f"{args.dataset}_{args.method}_conv{args.conv_dim}-act{args.actfunc}_{args.optimizer}_{args.lr}_{args.seed}"
time_stamp = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")

# Set wandb
wandb.init(
    project="opt-sigma",
    name=run_name,
    # track hyperparameters and run metadata
    config={
    "algorithm": args.method,
    "architecture": "SimpleCNN",
    "dataset": args.dataset,
    "epochs": args.epochs,
    "lr": args.lr,
    "optimizer": args.optimizer,
    "seed": args.seed,
    "conv_dim": args.conv_dim,
    "actfunc": args.actfunc,
    }
)

print(f"Run name: {run_name}")

# Set seed
torch.manual_seed(args.seed), np.random.seed(args.seed)

# Set device
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

# Get dataset
train_loader, val_loader, test_loader = get_dataset(args)

class SigmaModel_SimpleCNN(nn.Module):
    def __init__(self, args):
        super(SigmaModel_SimpleCNN, self).__init__()
        if args.dataset == "CIFAR10":
            self.LN1 = torch.nn.LayerNorm((3,32,32), elementwise_affine=False) 
            self.conv1 = SigmaConv(args, 3, 32, 5, 3, 2)
            self.conv2 = SigmaConv(args, 32, 64, 3, 3, 2)
            self.view1 = SigmaView((64, 8, 8), 64 * 8 * 8)
            self.fc1 = SigmaLinear(args, 64 * 8 * 8, 128)
            self.fc2 = SigmaLinear(args, 128, 10)
            
        elif args.dataset == "MNIST":
            if args.architecture == "lenet":
                self.LN1 = torch.nn.LayerNorm((1,28,28), elementwise_affine=False) 
                self.conv1 = SigmaConv(args, 1, 32, 5, 2, 3)
                self.conv2 = SigmaConv(args, 32, 64, 5, 2, 3)
                self.view1 = SigmaView((64, 4, 4), 64 * 4 * 4)
                self.fc1 = SigmaLinear(args, 64 * 4 * 4, 512)
                self.fc2 = SigmaLinear(args, 512, 10)

        # orthoginal init self.fc1 and self.fc2
        init.orthogonal_(self.fc1.backward_layer.weight)
        init.sparse_(self.fc2.backward_layer.weight, sparsity=0.5)
        # init.orthogonal_(self.fc2.backward_layer.weight)
        # init.zeros_(self.fc1.forward_layer.weight)
        # init.zeros_(self.fc2.forward_layer.weight)
        # init.zeros_(self.conv2.forward_layer.weight)
                
        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.conv1, self.conv2, self.fc1, self.fc2]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def forward(self, x, detach_grad=False, return_activations=True):   
        # x = self.LN1(x)
        a1 = self.conv1(x, detach_grad)
        a2 = self.conv2(a1, detach_grad)
        a2 = self.view1(a2, detach_grad)
        a3 = self.fc1(a2, detach_grad)
        a4 = self.fc2(a3, detach_grad)
        return [a1, a2, a3, a4]
        
    def reverse(self, target, detach_grad=True, return_activations=True):
        if target.shape == torch.Size([10]): 
            target = F.one_hot(target, num_classes=10).float().to(target.device)
        b3 = self.fc2.reverse(target, detach_grad)
        b2 = self.fc1.reverse(b3, detach_grad)
        b2 = self.view1.reverse(b2, detach_grad)
        b1 = self.conv2.reverse(b2, detach_grad)
        return [b1, b2, b3, target]


class SigmaLoss(nn.Module):
    def __init__(self, args):
        super(SigmaLoss, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_SCL_loss
        self.method = args.method
        
    def forward(self, activations, signals, target, method="final"):
        if method == "local":
            loss = list()
            for act, sig in zip(activations[:-1], signals[:-1]):
                loss += [self.local_criteria(act, sig, target)]
            loss += [self.final_criteria(activations[-1], target)]
            return sum(loss), loss[-1].item()
        elif method == "final":
            loss = self.final_criteria(activations[-1], target)
            return loss, loss.item()
        
model = SigmaModel_SimpleCNN(args)
model.to(device)
if args.optimizer == "SGD": 
    forward_optimizer = optim.SGD(model.forward_params, lr=args.lr, weight_decay=0.0005)
    backward_optimizer = optim.SGD(model.backward_params, lr=args.lr*8)

elif args.optimizer == "RMSprop": forward_optimizer = optim.RMSprop(model.forward_params, lr=args.lr, weight_decay=0)
elif args.optimizer == "Adam": forward_optimizer = optim.Adam(model.forward_params, lr=args.lr)
criteria = SigmaLoss(args)
    
with torch.no_grad():
    signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), return_activations=True)
    
best_val_loss = float('inf')

print("No LN1")
# for epoch in tqdm(range(args.epochs)):
for epoch in range(args.epochs):
    
    train_loss, train_counter = 0, 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        if args.method == "SIGMA":
            activations = model(data.to(device), detach_grad=True)
            signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), detach_grad=True)
            loss, loss_item = criteria(activations, signals, target.to(device), method="local")
        elif args.method == "BP":
            activations = model(data.to(device), detach_grad=False)
            loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), backward_optimizer.zero_grad(), loss.backward()
        gradient_centralization(model)
        forward_optimizer.step(), backward_optimizer.step()
        train_loss += loss_item * len(data)
        train_counter += len(data)

    wandb.log({'train_loss': train_loss / train_counter}, step=epoch)
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        activations = model(data.to(device), detach_grad=True)
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), loss.backward(), forward_optimizer.step()

    # Validation
    val_correct, val_loss, val_counter = 0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            val_counter += len(data)
            if args.method == "SIGMA":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="local")
                
            elif args.method == "BP":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="final")
            prediction = activations[-1].detach()
            _, predicted = torch.max(prediction, 1)
            val_correct += (predicted == target.to(device)).sum().item()
            val_loss += loss_item * len(data)

    wandb.log({'val_loss': val_loss / val_counter, 
               'val_acc': val_correct / val_counter, 
               }, step=epoch)
    
    print(f"""Epoch {epoch} | train loss {train_loss / train_counter:.4f} | val loss {val_loss / val_counter:.4f} | val acc {100 * val_correct / val_counter:.4f}""")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), f'./saved_models/{run_name}-{time_stamp}.pt')
        
# Eval on Test Set by loading the best model 
model.load_state_dict(torch.load(f'./saved_models/{run_name}-{time_stamp}.pt'))
model.eval()
correct, total = 0, 0
test_loss, test_counter = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data.to(device), detach_grad=False)
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == target.to(device)).sum().item()
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        test_loss += loss_item * len(data)
        test_counter += len(data)

wandb.log({'test_loss': test_loss / test_counter,})
print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / test_counter:.4f}%')

Run name: CIFAR10_SIGMA_conv32-actelu_SGD_1_2023
Files already downloaded and verified
Files already downloaded and verified
No LN1


  1%|          | 1/100 [00:12<21:26, 12.99s/it]

Epoch 0 | train loss 2.1313 | val loss 1.7911 | val acc 35.1500


  2%|▏         | 2/100 [00:26<21:38, 13.25s/it]

Epoch 1 | train loss 1.6918 | val loss 1.5604 | val acc 44.5300


  3%|▎         | 3/100 [00:39<21:30, 13.30s/it]

Epoch 2 | train loss 1.5431 | val loss 1.4494 | val acc 48.9800


  4%|▍         | 4/100 [00:52<21:11, 13.24s/it]

Epoch 3 | train loss 1.4475 | val loss 1.3936 | val acc 50.8200


  5%|▌         | 5/100 [01:06<20:56, 13.22s/it]

Epoch 4 | train loss 1.4169 | val loss 1.3353 | val acc 53.3300


  6%|▌         | 6/100 [01:19<20:36, 13.15s/it]

Epoch 5 | train loss 1.3502 | val loss 1.2823 | val acc 55.4500


  7%|▋         | 7/100 [01:32<20:28, 13.21s/it]

Epoch 6 | train loss 1.3133 | val loss 1.2575 | val acc 57.4800


  8%|▊         | 8/100 [01:45<20:23, 13.30s/it]

Epoch 7 | train loss 1.2808 | val loss 1.2338 | val acc 58.1700


  9%|▉         | 9/100 [01:59<20:06, 13.26s/it]

Epoch 8 | train loss 1.2644 | val loss 1.2085 | val acc 59.1500


 10%|█         | 10/100 [02:12<19:51, 13.24s/it]

Epoch 9 | train loss 1.2249 | val loss 1.2015 | val acc 60.5500


 11%|█         | 11/100 [02:25<19:33, 13.19s/it]

Epoch 10 | train loss 1.2096 | val loss 1.1854 | val acc 61.0400


 12%|█▏        | 12/100 [02:38<19:22, 13.21s/it]

Epoch 11 | train loss 1.2114 | val loss 1.1818 | val acc 61.3400


 13%|█▎        | 13/100 [02:51<19:07, 13.19s/it]

Epoch 12 | train loss 1.1965 | val loss 1.1885 | val acc 62.0300


 14%|█▍        | 14/100 [03:04<18:47, 13.11s/it]

Epoch 13 | train loss 1.1937 | val loss 1.1811 | val acc 61.8700


 15%|█▌        | 15/100 [03:18<18:41, 13.20s/it]

Epoch 14 | train loss 1.2066 | val loss 1.1872 | val acc 62.3500


 16%|█▌        | 16/100 [03:31<18:22, 13.12s/it]

Epoch 15 | train loss 1.1988 | val loss 1.1774 | val acc 62.9200


 17%|█▋        | 17/100 [03:44<18:10, 13.14s/it]

Epoch 16 | train loss 1.1958 | val loss 1.1849 | val acc 61.6400


 18%|█▊        | 18/100 [03:57<17:55, 13.12s/it]

Epoch 17 | train loss 1.2015 | val loss 1.1870 | val acc 63.2400


 19%|█▉        | 19/100 [04:10<17:37, 13.06s/it]

Epoch 18 | train loss 1.1976 | val loss 1.1970 | val acc 63.2700


 20%|██        | 20/100 [04:23<17:20, 13.01s/it]

Epoch 19 | train loss 1.1864 | val loss 1.1861 | val acc 63.4800


 21%|██        | 21/100 [04:36<17:05, 12.98s/it]

Epoch 20 | train loss 1.1922 | val loss 1.2039 | val acc 62.8200


 22%|██▏       | 22/100 [04:48<16:51, 12.97s/it]

Epoch 21 | train loss 1.1905 | val loss 1.2230 | val acc 63.3000


 23%|██▎       | 23/100 [05:01<16:38, 12.96s/it]

Epoch 22 | train loss 1.2149 | val loss 1.2149 | val acc 63.1200


 24%|██▍       | 24/100 [05:14<16:23, 12.94s/it]

Epoch 23 | train loss 1.2070 | val loss 1.2391 | val acc 62.6400


 25%|██▌       | 25/100 [05:27<16:10, 12.94s/it]

Epoch 24 | train loss 1.2330 | val loss 1.2317 | val acc 64.2200


 26%|██▌       | 26/100 [05:40<16:02, 13.00s/it]

Epoch 25 | train loss 1.2264 | val loss 1.2333 | val acc 64.1500


 27%|██▋       | 27/100 [05:53<15:48, 12.99s/it]

Epoch 26 | train loss 1.2246 | val loss 1.2317 | val acc 63.4300


 28%|██▊       | 28/100 [06:06<15:33, 12.97s/it]

Epoch 27 | train loss 1.2590 | val loss 1.2561 | val acc 63.3500


 29%|██▉       | 29/100 [06:19<15:23, 13.00s/it]

Epoch 28 | train loss 1.2745 | val loss 1.2716 | val acc 64.2500


 30%|███       | 30/100 [06:32<15:11, 13.02s/it]

Epoch 29 | train loss 1.2759 | val loss 1.2868 | val acc 64.5600


 31%|███       | 31/100 [06:46<15:00, 13.05s/it]

Epoch 30 | train loss 1.2925 | val loss 1.2855 | val acc 64.8400


 32%|███▏      | 32/100 [06:59<14:54, 13.15s/it]

Epoch 31 | train loss 1.2916 | val loss 1.2984 | val acc 64.2300


 33%|███▎      | 33/100 [07:12<14:47, 13.25s/it]

Epoch 32 | train loss 1.3095 | val loss 1.3118 | val acc 63.5200


 34%|███▍      | 34/100 [07:26<14:32, 13.22s/it]

Epoch 33 | train loss 1.3250 | val loss 1.3384 | val acc 63.9700


 35%|███▌      | 35/100 [07:39<14:28, 13.36s/it]

Epoch 34 | train loss 1.3394 | val loss 1.3513 | val acc 63.0800


 36%|███▌      | 36/100 [07:53<14:16, 13.39s/it]

Epoch 35 | train loss 1.3545 | val loss 1.3703 | val acc 62.8900


 37%|███▋      | 37/100 [08:06<14:05, 13.42s/it]

Epoch 36 | train loss 1.3770 | val loss 1.3864 | val acc 63.9000


 38%|███▊      | 38/100 [08:19<13:45, 13.32s/it]

Epoch 37 | train loss 1.3772 | val loss 1.4320 | val acc 63.4900


 39%|███▉      | 39/100 [08:33<13:39, 13.44s/it]

Epoch 38 | train loss 1.4103 | val loss 1.4080 | val acc 62.8300


 40%|████      | 40/100 [08:46<13:17, 13.29s/it]

Epoch 39 | train loss 1.4175 | val loss 1.4344 | val acc 64.6900


 41%|████      | 41/100 [08:59<13:08, 13.36s/it]

Epoch 40 | train loss 1.4421 | val loss 1.4662 | val acc 63.1700


 42%|████▏     | 42/100 [09:13<12:55, 13.38s/it]

Epoch 41 | train loss 1.4446 | val loss 1.4779 | val acc 61.8800


 43%|████▎     | 43/100 [09:26<12:42, 13.38s/it]

Epoch 42 | train loss 1.4573 | val loss 1.4676 | val acc 64.2700


 44%|████▍     | 44/100 [09:40<12:42, 13.61s/it]

Epoch 43 | train loss 1.4759 | val loss 1.4922 | val acc 64.2200


 45%|████▌     | 45/100 [09:55<12:38, 13.80s/it]

Epoch 44 | train loss 1.4993 | val loss 1.5011 | val acc 62.9500


 46%|████▌     | 46/100 [10:08<12:17, 13.66s/it]

Epoch 45 | train loss 1.5094 | val loss 1.5283 | val acc 62.9100


 47%|████▋     | 47/100 [10:21<12:00, 13.59s/it]

Epoch 46 | train loss 1.5186 | val loss 1.5277 | val acc 62.8600


 48%|████▊     | 48/100 [10:35<11:45, 13.56s/it]

Epoch 47 | train loss 1.5272 | val loss 1.5512 | val acc 61.2100


 49%|████▉     | 49/100 [10:48<11:29, 13.53s/it]

Epoch 48 | train loss 1.5453 | val loss 1.5830 | val acc 63.0900


 50%|█████     | 50/100 [11:02<11:14, 13.49s/it]

Epoch 49 | train loss 1.5513 | val loss 1.5793 | val acc 62.6900


 51%|█████     | 51/100 [11:16<11:11, 13.70s/it]

Epoch 50 | train loss 1.5743 | val loss 1.5733 | val acc 63.3000


 52%|█████▏    | 52/100 [11:30<10:57, 13.69s/it]

Epoch 51 | train loss 1.5738 | val loss 1.5837 | val acc 62.5300


 53%|█████▎    | 53/100 [11:43<10:41, 13.64s/it]

Epoch 52 | train loss 1.5839 | val loss 1.6096 | val acc 62.9000


 54%|█████▍    | 54/100 [11:57<10:25, 13.60s/it]

Epoch 53 | train loss 1.5838 | val loss 1.6104 | val acc 63.1900


 55%|█████▌    | 55/100 [12:10<10:08, 13.52s/it]

Epoch 54 | train loss 1.5877 | val loss 1.6038 | val acc 62.6500


 56%|█████▌    | 56/100 [12:24<09:55, 13.53s/it]

Epoch 55 | train loss 1.6013 | val loss 1.6144 | val acc 63.0600


 57%|█████▋    | 57/100 [12:37<09:43, 13.58s/it]

Epoch 56 | train loss 1.6034 | val loss 1.6163 | val acc 60.5800


 58%|█████▊    | 58/100 [12:51<09:31, 13.60s/it]

Epoch 57 | train loss 1.6015 | val loss 1.6197 | val acc 62.8500


 59%|█████▉    | 59/100 [13:05<09:19, 13.65s/it]

Epoch 58 | train loss 1.6387 | val loss 1.6368 | val acc 63.3000


 60%|██████    | 60/100 [13:18<09:05, 13.65s/it]

Epoch 59 | train loss 1.6348 | val loss 1.6480 | val acc 61.6000


 61%|██████    | 61/100 [13:32<08:57, 13.79s/it]

Epoch 60 | train loss 1.6414 | val loss 1.6674 | val acc 62.8900


 62%|██████▏   | 62/100 [13:46<08:43, 13.78s/it]

Epoch 61 | train loss 1.6436 | val loss 1.6605 | val acc 63.0600


 63%|██████▎   | 63/100 [14:00<08:33, 13.89s/it]

Epoch 62 | train loss 1.6509 | val loss 1.6680 | val acc 63.2100


 64%|██████▍   | 64/100 [14:14<08:15, 13.77s/it]

Epoch 63 | train loss 1.6539 | val loss 1.7041 | val acc 62.9500


 65%|██████▌   | 65/100 [14:28<08:06, 13.91s/it]

Epoch 64 | train loss 1.6519 | val loss 1.6720 | val acc 61.9100


 66%|██████▌   | 66/100 [14:42<07:52, 13.91s/it]

Epoch 65 | train loss 1.6599 | val loss 1.6618 | val acc 60.9000


 67%|██████▋   | 67/100 [14:56<07:42, 14.02s/it]

Epoch 66 | train loss 1.6640 | val loss 1.6694 | val acc 63.2700


 68%|██████▊   | 68/100 [15:10<07:27, 14.00s/it]

Epoch 67 | train loss 1.6605 | val loss 1.6567 | val acc 61.9700


 69%|██████▉   | 69/100 [15:24<07:13, 13.98s/it]

Epoch 68 | train loss 1.6734 | val loss 1.6630 | val acc 63.1800


 70%|███████   | 70/100 [15:38<06:54, 13.83s/it]

Epoch 69 | train loss 1.6616 | val loss 1.6994 | val acc 60.9000


 71%|███████   | 71/100 [15:52<06:43, 13.92s/it]

Epoch 70 | train loss 1.6870 | val loss 1.7254 | val acc 62.5100


 72%|███████▏  | 72/100 [16:06<06:32, 14.03s/it]

Epoch 71 | train loss 1.6843 | val loss 1.6935 | val acc 60.0600


 73%|███████▎  | 73/100 [16:19<06:14, 13.86s/it]

Epoch 72 | train loss 1.6865 | val loss 1.6976 | val acc 62.0500


 74%|███████▍  | 74/100 [16:33<05:59, 13.82s/it]

Epoch 73 | train loss 1.6866 | val loss 1.7141 | val acc 59.7900


 75%|███████▌  | 75/100 [16:48<05:50, 14.02s/it]

Epoch 74 | train loss 1.6829 | val loss 1.7032 | val acc 63.3700


 76%|███████▌  | 76/100 [17:02<05:36, 14.01s/it]

Epoch 75 | train loss 1.6979 | val loss 1.7325 | val acc 58.3900


 77%|███████▋  | 77/100 [17:15<05:18, 13.83s/it]

Epoch 76 | train loss 1.7030 | val loss 1.7342 | val acc 62.2700


 78%|███████▊  | 78/100 [17:28<05:00, 13.65s/it]

Epoch 77 | train loss 1.6905 | val loss 1.7120 | val acc 61.5400


 79%|███████▉  | 79/100 [17:42<04:47, 13.69s/it]

Epoch 78 | train loss 1.6848 | val loss 1.6828 | val acc 62.2400


 80%|████████  | 80/100 [17:56<04:38, 13.90s/it]

Epoch 79 | train loss 1.6917 | val loss 1.7020 | val acc 63.8600


 81%|████████  | 81/100 [18:10<04:21, 13.75s/it]

Epoch 80 | train loss 1.6933 | val loss 1.6819 | val acc 61.4800


 82%|████████▏ | 82/100 [18:23<04:05, 13.65s/it]

Epoch 81 | train loss 1.6879 | val loss 1.7057 | val acc 63.0300


 83%|████████▎ | 83/100 [18:37<03:50, 13.58s/it]

Epoch 82 | train loss 1.6980 | val loss 1.6940 | val acc 62.5700


 84%|████████▍ | 84/100 [18:51<03:38, 13.65s/it]

Epoch 83 | train loss 1.6901 | val loss 1.6776 | val acc 62.9700


 85%|████████▌ | 85/100 [19:04<03:26, 13.74s/it]

Epoch 84 | train loss 1.6843 | val loss 1.7007 | val acc 63.6900


 86%|████████▌ | 86/100 [19:18<03:10, 13.59s/it]

Epoch 85 | train loss 1.7031 | val loss 1.6850 | val acc 60.8700


 87%|████████▋ | 87/100 [19:32<02:58, 13.69s/it]

Epoch 86 | train loss 1.6835 | val loss 1.6996 | val acc 63.3300


 88%|████████▊ | 88/100 [19:45<02:42, 13.57s/it]

Epoch 87 | train loss 1.6956 | val loss 1.7041 | val acc 61.6000


 89%|████████▉ | 89/100 [19:58<02:28, 13.49s/it]

Epoch 88 | train loss 1.6880 | val loss 1.7004 | val acc 61.4400


 90%|█████████ | 90/100 [20:12<02:14, 13.43s/it]

Epoch 89 | train loss 1.6977 | val loss 1.6945 | val acc 61.2100


 91%|█████████ | 91/100 [20:25<02:01, 13.48s/it]

Epoch 90 | train loss 1.6868 | val loss 1.7048 | val acc 61.1700


 92%|█████████▏| 92/100 [20:38<01:47, 13.44s/it]

Epoch 91 | train loss 1.6890 | val loss 1.7136 | val acc 61.0200


 93%|█████████▎| 93/100 [20:52<01:33, 13.38s/it]

Epoch 92 | train loss 1.6964 | val loss 1.7039 | val acc 62.0500


 94%|█████████▍| 94/100 [21:05<01:20, 13.34s/it]

Epoch 93 | train loss 1.6974 | val loss 1.7127 | val acc 61.0000


 95%|█████████▌| 95/100 [21:18<01:06, 13.34s/it]

Epoch 94 | train loss 1.7084 | val loss 1.7035 | val acc 62.9300


 96%|█████████▌| 96/100 [21:32<00:53, 13.45s/it]

Epoch 95 | train loss 1.6986 | val loss 1.7025 | val acc 59.8000


 97%|█████████▋| 97/100 [21:46<00:40, 13.48s/it]

Epoch 96 | train loss 1.6997 | val loss 1.6834 | val acc 62.0200


 98%|█████████▊| 98/100 [21:59<00:26, 13.39s/it]

Epoch 97 | train loss 1.7052 | val loss 1.7114 | val acc 62.4000


 99%|█████████▉| 99/100 [22:12<00:13, 13.32s/it]

Epoch 98 | train loss 1.6986 | val loss 1.6812 | val acc 60.2100


100%|██████████| 100/100 [22:25<00:00, 13.46s/it]

Epoch 99 | train loss 1.6986 | val loss 1.7209 | val acc 60.5600





AttributeError: 'list' object has no attribute 'data'