In [1]:
import os
os.environ["WANDB_API_KEY"] = "dcf9600e0485401cbb0ddbb0f7be1c70f96b32ef"
os.environ["WANDB_MODE"] = "disabled"
import argparse
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipdb
import wandb
from sigma_layer import SigmaLinear, SigmaConv, SigmaView
from utils import get_dataset, gradient_centralization, normalize_along_axis, get_activation_function, compute_SCL_loss
import datetime

In [2]:
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# Training scheme group
method_parser = parser.add_argument_group("Method")
method_parser.add_argument('--method', type=str, default='SIGMA', choices=['SIGMA', 'BP', 'FA'])
method_parser.add_argument('--architecture', type=str, default='lenet', choices=['lenet', 'vgg'])
method_parser.add_argument('--actfunc', type=str, default='elu', choices=['tanh', 'elu', 'relu'])
method_parser.add_argument('--conv_dim', type=int, default=32, choices=[32, 64])
# Dataset group
dataset_parser = parser.add_argument_group('Dataset')
dataset_parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['MNIST', 'CIFAR10'])
dataset_parser.add_argument('--batchsize', type=int, default=64)
dataset_parser.add_argument('--splitratio', type=float, default=0.2)
# Training group # LR, optimizer, weight_decay, momentum
training_parser = parser.add_argument_group('Training')
training_parser.add_argument('--epochs', type=int, default=100)
training_parser.add_argument('--lr', type=float, default=1)
training_parser.add_argument('--optimizer', type=str, default='SGD', choices=['RMSprop', 'Adam', 'SGD'])
# Seed group
seed_parser = parser.add_argument_group('Seed')
seed_parser.add_argument('--seed', type=int, default=2023)
args, _ = parser.parse_known_args()

# Set run_name
run_name = f"{args.dataset}_{args.method}_conv{args.conv_dim}-act{args.actfunc}_{args.optimizer}_{args.lr}_{args.seed}"
time_stamp = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")

# Set wandb
wandb.init(
    project="opt-sigma",
    name=run_name,
    # track hyperparameters and run metadata
    config={
    "algorithm": args.method,
    "architecture": "SimpleCNN",
    "dataset": args.dataset,
    "epochs": args.epochs,
    "lr": args.lr,
    "optimizer": args.optimizer,
    "seed": args.seed,
    "conv_dim": args.conv_dim,
    "actfunc": args.actfunc,
    }
)

print(f"Run name: {run_name}")

# Set seed
torch.manual_seed(args.seed), np.random.seed(args.seed)

# Set device
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

# Get dataset
train_loader, val_loader, test_loader = get_dataset(args)

class SigmaModel_SimpleCNN(nn.Module):
    def __init__(self, args):
        super(SigmaModel_SimpleCNN, self).__init__()
        if args.dataset == "CIFAR10":
            self.LN1 = torch.nn.LayerNorm((3,32,32), elementwise_affine=False) 
            self.conv1 = SigmaConv(args, 3, 32, 5, 3, 2)
            self.conv2 = SigmaConv(args, 32, 64, 3, 3, 2)
            self.view1 = SigmaView((64, 8, 8), 64 * 8 * 8)
            self.fc1 = SigmaLinear(args, 64 * 8 * 8, 128)
            self.fc2 = SigmaLinear(args, 128, 10)
            
        elif args.dataset == "MNIST":
            if args.architecture == "lenet":
                self.LN1 = torch.nn.LayerNorm((1,28,28), elementwise_affine=False) 
                self.conv1 = SigmaConv(args, 1, 32, 5, 2, 3)
                self.conv2 = SigmaConv(args, 32, 64, 5, 2, 3)
                self.view1 = SigmaView((64, 4, 4), 64 * 4 * 4)
                self.fc1 = SigmaLinear(args, 64 * 4 * 4, 512)
                self.fc2 = SigmaLinear(args, 512, 10)
                
        self.forward_params = list()
        self.backward_params = list()
        for layer in [self.conv1, self.conv2, self.fc1, self.fc2]:
            forward_params, backward_params = layer.get_parameters()
            self.forward_params += forward_params
            self.backward_params += backward_params

    def forward(self, x, detach_grad=False, return_activations=True):   
        # x = self.LN1(x)
        a1 = self.conv1(x, detach_grad)
        a2 = self.conv2(a1, detach_grad)
        a2 = self.view1(a2, detach_grad)
        a3 = self.fc1(a2, detach_grad)
        a4 = self.fc2(a3, detach_grad)
        return [a1, a2, a3, a4]
        
    def reverse(self, target, detach_grad=True, return_activations=True):
        if target.shape == torch.Size([10]): 
            target = F.one_hot(target, num_classes=10).float().to(target.device)
        b3 = self.fc2.reverse(target, detach_grad)
        b2 = self.fc1.reverse(b3, detach_grad)
        b2 = self.view1.reverse(b2, detach_grad)
        b1 = self.conv2.reverse(b2, detach_grad)
        return [b1, b2, b3, target]


class SigmaLoss(nn.Module):
    def __init__(self, args):
        super(SigmaLoss, self).__init__()
        self.args = args
        self.final_criteria = nn.CrossEntropyLoss()
        self.local_criteria = compute_SCL_loss
        self.method = args.method
        
    def forward(self, activations, signals, target, method="final"):
        if method == "local":
            loss = list()
            for act, sig in zip(activations[:-1], signals[:-1]):
                loss += [self.local_criteria(act, sig, target)]
            loss += [self.final_criteria(activations[-1], target)]
            return sum(loss), loss[-1].item()
        elif method == "final":
            loss = self.final_criteria(activations[-1], target) * 1.1 - 0.1
            return loss, loss.item()
        
model = SigmaModel_SimpleCNN(args)
model.to(device)
if args.optimizer == "SGD": 
    forward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
    backward_optimizer = optim.SGD(model.forward_params, lr=args.lr)
elif args.optimizer == "RMSprop": forward_optimizer = optim.RMSprop(model.forward_params, lr=args.lr, weight_decay=0)
elif args.optimizer == "Adam": forward_optimizer = optim.Adam(model.forward_params, lr=args.lr)
criteria = SigmaLoss(args)
    
with torch.no_grad():
    signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), return_activations=True)
    
best_val_loss = float('inf')

print("No LN1")
for epoch in tqdm(range(args.epochs)):
    
    train_loss, train_counter = 0, 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        if args.method == "SIGMA":
            activations = model(data.to(device), detach_grad=True)
            signals = model.reverse(torch.Tensor([0,1,2,3,4,5,6,7,8,9]).long().to(device), detach_grad=True)
            loss, loss_item = criteria(activations, signals, target.to(device), method="local")
        elif args.method == "BP":
            activations = model(data.to(device), detach_grad=False)
            loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), backward_optimizer.zero_grad(), loss.backward()
        gradient_centralization(model), forward_optimizer.step(), backward_optimizer.step()
        train_loss += loss_item * len(data)
        train_counter += len(data)

    wandb.log({'train_loss': train_loss / train_counter}, step=epoch)
    
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > 100: continue
        activations = model(data.to(device), detach_grad=True)
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        forward_optimizer.zero_grad(), loss.backward(), forward_optimizer.step()

    # Validation
    val_correct, val_loss, val_counter = 0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            val_counter += len(data)
            if args.method == "SIGMA":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="local")
                
            elif args.method == "BP":
                activations = model(data.to(device), detach_grad=True)
                _, loss_item = criteria(activations, signals, target.to(device), method="final")
            prediction = activations[-1].detach()
            _, predicted = torch.max(prediction, 1)
            val_correct += (predicted == target.to(device)).sum().item()
            val_loss += loss_item * len(data)

    wandb.log({'val_loss': val_loss / val_counter, 
               'val_acc': val_correct / val_counter, 
               }, step=epoch)
    
    print(f"""Epoch {epoch} | train loss {train_loss / train_counter:.4f} | val loss {val_loss / val_counter:.4f} | val acc {100 * val_correct / val_counter:.4f}""")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), f'./saved_models/{run_name}-{time_stamp}.pt')
        

# Eval on Test Set by loading the best model 
model.load_state_dict(torch.load(f'./saved_models/{run_name}-{time_stamp}.pt'))
model.eval()
correct, total = 0, 0
test_loss, test_counter = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        outputs = model(data.to(device), detach_grad=False)
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == target.to(device)).sum().item()
        loss, loss_item = criteria(activations, signals, target.to(device), method="final")
        test_loss += loss_item * len(data)
        test_counter += len(data)

wandb.log({'test_loss': test_loss / test_counter,
ｃｖ
print(f'Epoch: {epoch}, Test Accuracy: {100 * correct / test_counter:.4f}%')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Run name: CIFAR10_SIGMA_conv32-actelu_SGD_1_2023
Files already downloaded and verified
Files already downloaded and verified
No LN1


  1%|          | 1/100 [00:13<21:53, 13.27s/it]

Epoch 0 | train loss 2.2178 | val loss 2.0178 | val acc 30.6200


  2%|▏         | 2/100 [00:27<22:11, 13.59s/it]

Epoch 1 | train loss 2.0582 | val loss 1.7265 | val acc 39.2900


  3%|▎         | 3/100 [00:41<22:14, 13.75s/it]

Epoch 2 | train loss 1.8743 | val loss 1.6028 | val acc 44.0400


  4%|▍         | 4/100 [00:54<21:59, 13.75s/it]

Epoch 3 | train loss 1.8098 | val loss 1.6477 | val acc 41.6000


  5%|▌         | 5/100 [01:08<21:59, 13.89s/it]

Epoch 4 | train loss 1.7366 | val loss 1.5404 | val acc 45.0100


  6%|▌         | 6/100 [01:22<21:26, 13.69s/it]

Epoch 5 | train loss 1.6680 | val loss 1.5069 | val acc 47.3100


  7%|▋         | 7/100 [01:35<20:54, 13.49s/it]

Epoch 6 | train loss 1.5717 | val loss 1.4420 | val acc 49.9600


  8%|▊         | 8/100 [01:48<20:34, 13.42s/it]

Epoch 7 | train loss 1.5261 | val loss 1.2913 | val acc 55.5200


  9%|▉         | 9/100 [02:02<20:31, 13.53s/it]

Epoch 8 | train loss 1.3500 | val loss 1.2612 | val acc 55.7400


 10%|█         | 10/100 [02:15<20:13, 13.48s/it]

Epoch 9 | train loss 1.3497 | val loss 1.2635 | val acc 55.1100


 11%|█         | 11/100 [02:29<20:05, 13.54s/it]

Epoch 10 | train loss 1.3222 | val loss 1.2236 | val acc 57.0000


 12%|█▏        | 12/100 [02:42<19:53, 13.56s/it]

Epoch 11 | train loss 1.2760 | val loss 1.2085 | val acc 58.3700


 13%|█▎        | 13/100 [02:56<19:43, 13.60s/it]

Epoch 12 | train loss 1.2610 | val loss 1.1985 | val acc 58.0300


 14%|█▍        | 14/100 [03:10<19:39, 13.72s/it]

Epoch 13 | train loss 1.2741 | val loss 1.1896 | val acc 58.9100


 15%|█▌        | 15/100 [03:24<19:25, 13.71s/it]

Epoch 14 | train loss 1.2365 | val loss 1.1932 | val acc 58.4700


 16%|█▌        | 16/100 [03:37<18:59, 13.56s/it]

Epoch 15 | train loss 1.2088 | val loss 1.1761 | val acc 59.1300


 17%|█▋        | 17/100 [03:50<18:40, 13.50s/it]

Epoch 16 | train loss 1.1978 | val loss 1.1789 | val acc 59.2300


 18%|█▊        | 18/100 [04:04<18:16, 13.37s/it]

Epoch 17 | train loss 1.1811 | val loss 1.1508 | val acc 60.2000


 19%|█▉        | 19/100 [04:17<18:01, 13.35s/it]

Epoch 18 | train loss 1.1924 | val loss 1.1621 | val acc 59.5100


 20%|██        | 20/100 [04:30<17:44, 13.30s/it]

Epoch 19 | train loss 1.1503 | val loss 1.1522 | val acc 60.4700


 21%|██        | 21/100 [04:43<17:32, 13.32s/it]

Epoch 20 | train loss 1.1603 | val loss 1.1463 | val acc 60.1900


 22%|██▏       | 22/100 [04:57<17:26, 13.41s/it]

Epoch 21 | train loss 1.1646 | val loss 1.1492 | val acc 60.7500


 23%|██▎       | 23/100 [05:10<17:05, 13.32s/it]

Epoch 22 | train loss 1.1628 | val loss 1.1445 | val acc 61.0200


 24%|██▍       | 24/100 [05:24<16:55, 13.37s/it]

Epoch 23 | train loss 1.1409 | val loss 1.1495 | val acc 60.4100


 25%|██▌       | 25/100 [05:37<16:40, 13.34s/it]

Epoch 24 | train loss 1.1320 | val loss 1.1449 | val acc 60.6400


 26%|██▌       | 26/100 [05:50<16:22, 13.28s/it]

Epoch 25 | train loss 1.1073 | val loss 1.1573 | val acc 60.1900


 27%|██▋       | 27/100 [06:03<16:06, 13.25s/it]

Epoch 26 | train loss 1.1113 | val loss 1.1571 | val acc 59.8400


 28%|██▊       | 28/100 [06:16<15:50, 13.21s/it]

Epoch 27 | train loss 1.1111 | val loss 1.1413 | val acc 60.9400


 29%|██▉       | 29/100 [06:30<15:42, 13.27s/it]

Epoch 28 | train loss 1.0936 | val loss 1.1381 | val acc 60.8200


 30%|███       | 30/100 [06:43<15:30, 13.29s/it]

Epoch 29 | train loss 1.0836 | val loss 1.1456 | val acc 60.9100


 31%|███       | 31/100 [06:56<15:12, 13.23s/it]

Epoch 30 | train loss 1.1069 | val loss 1.1391 | val acc 61.4300


 32%|███▏      | 32/100 [07:09<14:56, 13.19s/it]

Epoch 31 | train loss 1.1202 | val loss 1.1420 | val acc 61.0500


 33%|███▎      | 33/100 [07:22<14:41, 13.16s/it]

Epoch 32 | train loss 1.0922 | val loss 1.1413 | val acc 60.7800


 34%|███▍      | 34/100 [07:35<14:29, 13.17s/it]

Epoch 33 | train loss 1.0895 | val loss 1.1296 | val acc 60.9000


 35%|███▌      | 35/100 [07:49<14:21, 13.25s/it]

Epoch 34 | train loss 1.0668 | val loss 1.1235 | val acc 62.1400


 36%|███▌      | 36/100 [08:02<14:11, 13.30s/it]

Epoch 35 | train loss 1.0702 | val loss 1.1329 | val acc 61.5500


 37%|███▋      | 37/100 [08:16<13:56, 13.27s/it]

Epoch 36 | train loss 1.0602 | val loss 1.1486 | val acc 60.5500


 38%|███▊      | 38/100 [08:29<13:50, 13.40s/it]

Epoch 37 | train loss 1.0604 | val loss 1.1280 | val acc 61.5400


 39%|███▉      | 39/100 [08:43<13:43, 13.49s/it]

Epoch 38 | train loss 1.0781 | val loss 1.1342 | val acc 61.4500


 40%|████      | 40/100 [08:57<13:31, 13.52s/it]

Epoch 39 | train loss 1.0485 | val loss 1.1252 | val acc 61.6000


 41%|████      | 41/100 [09:10<13:11, 13.41s/it]

Epoch 40 | train loss 1.0445 | val loss 1.1327 | val acc 61.1300


 42%|████▏     | 42/100 [09:23<12:53, 13.34s/it]

Epoch 41 | train loss 1.0506 | val loss 1.1432 | val acc 61.0700


 43%|████▎     | 43/100 [09:36<12:38, 13.30s/it]

Epoch 42 | train loss 1.0403 | val loss 1.1525 | val acc 60.8300


 44%|████▍     | 44/100 [09:49<12:24, 13.29s/it]

Epoch 43 | train loss 1.0327 | val loss 1.1267 | val acc 61.3600


 45%|████▌     | 45/100 [10:03<12:09, 13.27s/it]

Epoch 44 | train loss 1.0465 | val loss 1.1212 | val acc 62.1700


 46%|████▌     | 46/100 [10:16<11:57, 13.29s/it]

Epoch 45 | train loss 1.0422 | val loss 1.1249 | val acc 62.3200


 47%|████▋     | 47/100 [10:29<11:41, 13.23s/it]

Epoch 46 | train loss 1.0604 | val loss 1.1282 | val acc 61.8200


 48%|████▊     | 48/100 [10:42<11:30, 13.27s/it]

Epoch 47 | train loss 1.0462 | val loss 1.1136 | val acc 62.3100


 49%|████▉     | 49/100 [10:56<11:20, 13.34s/it]

Epoch 48 | train loss 1.0243 | val loss 1.1010 | val acc 62.5400


 50%|█████     | 50/100 [11:09<11:08, 13.38s/it]

Epoch 49 | train loss 1.0269 | val loss 1.1512 | val acc 60.8800


 51%|█████     | 51/100 [11:22<10:50, 13.29s/it]

Epoch 50 | train loss 1.0425 | val loss 1.1344 | val acc 61.7500


 52%|█████▏    | 52/100 [11:36<10:36, 13.26s/it]

Epoch 51 | train loss 1.0211 | val loss 1.1241 | val acc 61.9700


 53%|█████▎    | 53/100 [11:49<10:20, 13.21s/it]

Epoch 52 | train loss 1.0254 | val loss 1.1193 | val acc 62.2200


 54%|█████▍    | 54/100 [12:02<10:06, 13.19s/it]

Epoch 53 | train loss 1.0279 | val loss 1.1288 | val acc 61.9700


 55%|█████▌    | 55/100 [12:15<09:52, 13.17s/it]

Epoch 54 | train loss 1.0185 | val loss 1.1110 | val acc 62.2200


 56%|█████▌    | 56/100 [12:28<09:39, 13.17s/it]

Epoch 55 | train loss 1.0276 | val loss 1.1271 | val acc 61.8200


 57%|█████▋    | 57/100 [12:41<09:28, 13.21s/it]

Epoch 56 | train loss 1.0374 | val loss 1.1207 | val acc 62.3900


 58%|█████▊    | 58/100 [12:55<09:13, 13.18s/it]

Epoch 57 | train loss 1.0030 | val loss 1.1086 | val acc 62.7400


 59%|█████▉    | 59/100 [13:08<09:00, 13.17s/it]

Epoch 58 | train loss 1.0169 | val loss 1.1363 | val acc 61.9200


 60%|██████    | 60/100 [13:21<08:45, 13.14s/it]

Epoch 59 | train loss 0.9875 | val loss 1.1217 | val acc 62.3400


 61%|██████    | 61/100 [13:34<08:32, 13.13s/it]

Epoch 60 | train loss 0.9929 | val loss 1.1068 | val acc 62.6500


 62%|██████▏   | 62/100 [13:47<08:18, 13.12s/it]

Epoch 61 | train loss 1.0296 | val loss 1.1210 | val acc 61.8700


 63%|██████▎   | 63/100 [14:00<08:05, 13.12s/it]

Epoch 62 | train loss 0.9823 | val loss 1.1080 | val acc 62.8200


 64%|██████▍   | 64/100 [14:13<07:51, 13.11s/it]

Epoch 63 | train loss 0.9786 | val loss 1.1103 | val acc 62.4100


 65%|██████▌   | 65/100 [14:26<07:39, 13.13s/it]

Epoch 64 | train loss 0.9836 | val loss 1.1443 | val acc 61.2800


 66%|██████▌   | 66/100 [14:40<07:27, 13.16s/it]

Epoch 65 | train loss 1.0044 | val loss 1.1047 | val acc 62.8200


 67%|██████▋   | 67/100 [14:53<07:15, 13.20s/it]

Epoch 66 | train loss 0.9885 | val loss 1.1099 | val acc 62.9700


 68%|██████▊   | 68/100 [15:06<07:01, 13.16s/it]

Epoch 67 | train loss 0.9599 | val loss 1.1189 | val acc 62.1800


 69%|██████▉   | 69/100 [15:20<06:54, 13.36s/it]

Epoch 68 | train loss 1.0217 | val loss 1.1041 | val acc 62.4400


 70%|███████   | 70/100 [15:34<06:51, 13.72s/it]

Epoch 69 | train loss 0.9865 | val loss 1.1017 | val acc 63.1000


 71%|███████   | 71/100 [15:48<06:40, 13.82s/it]

Epoch 70 | train loss 1.0020 | val loss 1.1023 | val acc 63.1100


 72%|███████▏  | 72/100 [16:02<06:23, 13.70s/it]

Epoch 71 | train loss 1.0102 | val loss 1.1420 | val acc 61.8500


 73%|███████▎  | 73/100 [16:15<06:08, 13.67s/it]

Epoch 72 | train loss 0.9916 | val loss 1.1052 | val acc 63.0200


 74%|███████▍  | 74/100 [16:29<05:52, 13.57s/it]

Epoch 73 | train loss 0.9898 | val loss 1.1031 | val acc 62.9600


 75%|███████▌  | 75/100 [16:42<05:38, 13.53s/it]

Epoch 74 | train loss 0.9998 | val loss 1.1078 | val acc 62.6200


 76%|███████▌  | 76/100 [16:55<05:23, 13.46s/it]

Epoch 75 | train loss 0.9779 | val loss 1.1026 | val acc 63.0200


 76%|███████▌  | 76/100 [17:07<05:24, 13.52s/it]


KeyboardInterrupt: 

In [None]:
for a, s in zip(activations, signals):
    print(a.shape, s.shape)

torch.Size([64, 32, 16, 16]) torch.Size([10, 32, 32, 32])
torch.Size([64, 4096]) torch.Size([10, 64, 8, 8])
torch.Size([64, 128]) torch.Size([10, 128])
torch.Size([64, 10]) torch.Size([10, 10])


In [None]:
Epoch 0 | train loss 2.1706 | val loss 1.7251 | val acc 38.4000
Epoch 1 | train loss 1.9300 | val loss 1.5896 | val acc 44.4500
Epoch 2 | train loss 1.8217 | val loss 1.4883 | val acc 47.6600
Epoch 3 | train loss 1.7205 | val loss 1.4662 | val acc 48.2700
Epoch 4 | train loss 1.6620 | val loss 1.3580 | val acc 52.1400

Epoch 5 | train loss 1.5822 | val loss 1.3429 | val acc 53.0200
Epoch 6 | train loss 1.5124 | val loss 1.3322 | val acc 52.9600
Epoch 7 | train loss 1.4301 | val loss 1.2645 | val acc 55.5300
Epoch 8 | train loss 1.4169 | val loss 1.2650 | val acc 55.7300
Epoch 9 | train loss 1.3360 | val loss 1.2396 | val acc 56.5400

Epoch 10 | train loss 1.3369 | val loss 1.2339 | val acc 57.2100
Epoch 11 | train loss 1.3077 | val loss 1.2344 | val acc 57.2000
Epoch 12 | train loss 1.2832 | val loss 1.2064 | val acc 58.1100
Epoch 13 | train loss 1.2695 | val loss 1.2016 | val acc 58.7200
Epoch 14 | train loss 1.2429 | val loss 1.1799 | val acc 59.7700

Epoch 15 | train loss 1.2381 | val loss 1.1819 | val acc 59.5900
Epoch 16 | train loss 1.2267 | val loss 1.1791 | val acc 59.5400
Epoch 17 | train loss 1.2239 | val loss 1.1724 | val acc 59.5000
Epoch 18 | train loss 1.1911 | val loss 1.1605 | val acc 60.5000
Epoch 19 | train loss 1.1702 | val loss 1.1587 | val acc 60.6300
Epoch 20 | train loss 1.1602 | val loss 1.1574 | val acc 60.8900

In [None]:
Epoch 0 | train loss 2.1313 | val loss 1.7911 | val acc 35.1500
  2%|▏         | 2/100 [00:26<21:38, 13.25s/it]
Epoch 1 | train loss 1.6918 | val loss 1.5604 | val acc 44.5300
  3%|▎         | 3/100 [00:39<21:30, 13.30s/it]
Epoch 2 | train loss 1.5431 | val loss 1.4494 | val acc 48.9800
  4%|▍         | 4/100 [00:52<21:11, 13.24s/it]
Epoch 3 | train loss 1.4475 | val loss 1.3936 | val acc 50.8200
  5%|▌         | 5/100 [01:06<20:56, 13.22s/it]
Epoch 4 | train loss 1.4169 | val loss 1.3353 | val acc 53.3300
  6%|▌         | 6/100 [01:19<20:36, 13.15s/it]
Epoch 5 | train loss 1.3502 | val loss 1.2823 | val acc 55.4500
  7%|▋         | 7/100 [01:32<20:28, 13.21s/it]
Epoch 6 | train loss 1.3133 | val loss 1.2575 | val acc 57.4800
  8%|▊         | 8/100 [01:45<20:23, 13.30s/it]
Epoch 7 | train loss 1.2808 | val loss 1.2338 | val acc 58.1700
  9%|▉         | 9/100 [01:59<20:06, 13.26s/it]
Epoch 8 | train loss 1.2644 | val loss 1.2085 | val acc 59.1500
 10%|█         | 10/100 [02:12<19:51, 13.24s/it]
Epoch 9 | train loss 1.2249 | val loss 1.2015 | val acc 60.5500
 11%|█         | 11/100 [02:25<19:33, 13.19s/it]
Epoch 10 | train loss 1.2096 | val loss 1.1854 | val acc 61.0400
 12%|█▏        | 12/100 [02:38<19:22, 13.21s/it]
Epoch 11 | train loss 1.2114 | val loss 1.1818 | val acc 61.3400
 13%|█▎        | 13/100 [02:51<19:07, 13.19s/it]
Epoch 12 | train loss 1.1965 | val loss 1.1885 | val acc 62.0300
 14%|█▍        | 14/100 [03:04<18:47, 13.11s/it]
Epoch 13 | train loss 1.1937 | val loss 1.1811 | val acc 61.8700
 15%|█▌        | 15/100 [03:18<18:41, 13.20s/it]
Epoch 14 | train loss 1.2066 | val loss 1.1872 | val acc 62.3500
 16%|█▌        | 16/100 [03:31<18:22, 13.12s/it]
Epoch 15 | train loss 1.1988 | val loss 1.1774 | val acc 62.9200
 17%|█▋        | 17/100 [03:44<18:10, 13.14s/it]
Epoch 16 | train loss 1.1958 | val loss 1.1849 | val acc 61.6400
 18%|█▊        | 18/100 [03:57<17:55, 13.12s/it]
Epoch 17 | train loss 1.2015 | val loss 1.1870 | val acc 63.2400
 19%|█▉        | 19/100 [04:10<17:37, 13.06s/it]
Epoch 18 | train loss 1.1976 | val loss 1.1970 | val acc 63.2700
 20%|██        | 20/100 [04:23<17:20, 13.01s/it]
Epoch 19 | train loss 1.1864 | val loss 1.1861 | val acc 63.4800
 21%|██        | 21/100 [04:36<17:05, 12.98s/it]
Epoch 20 | train loss 1.1922 | val loss 1.2039 | val acc 62.8200