# CapsNet

Source: https://github.com/jindongwang/Pytorch-CapsuleNet/blob/master/test_capsnet.py

In [1]:
import sys 
sys.path.insert(0, "helper")

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
import torchvision
from tqdm import tqdm

# =============================================================================
# datasceyence
# =============================================================================
from helper.visualisation.feature_map import *
from helper.model.capsnet import CapsNet

from helper.data.mnist import DataMNIST
from helper.data.retinamnist import DataRetinaMNIST
from helper.data.octmnist import DataOCTMNIST

#train_kwargs["device"] = True if torch.cuda.is_available() else False
#train_kwargs["batch_size"] = 5
# train_kwargs["epochs"] = 30
#LEARNING_RATE = 0.01
#MOMENTUM = 0.9



In [2]:
seed = 1997 # was 19 before

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

print("cuda available:", torch.cuda.is_available())

debug_model = False

print('torch 2.0.0 ==', torch.__version__=='2.0.0' , '->', torch.__version__)

cuda available: True
torch 2.0.0 == False -> 1.13.1


# Settings

In [3]:
model_kwargs = { # for 1 Channel
    'in_channels' : 1,
    'n_classes': None, # filled in the dataset
    'criterion': torch.nn.CrossEntropyLoss(),# torch.nn.BCEWithLogitsLoss(),
    'optimizer': "sgd", # sgd adamw
    'base_lr': 0.001,
    'min_lr' : 0.00001,
    'momentum' : 0.9,
    'lr_update' : 100,
    
    # CNN (cnn)
    'cnn_in_channels' : 1, # duplicate ...
    'cnn_out_channels' : 256,
    'cnn_kernel_size' : 9,

    # Primary Capsule (pc)
    'pc_num_capsules' : 8,
    'pc_in_channels' : 256,
    'pc_out_channels' : 32,
    'pc_kernel_size' : 9,
    'pc_num_routes' : 32 * 6 * 6,

    # Digit Capsule (dc)
    'dc_num_capsules' : 10,
    'dc_num_routes' : 32 * 6 * 6,
    'dc_in_channels' : 8,
    'dc_out_channels' : 16,

    # Decoder
    'input_width' : 28,
    'input_height' : 28,
}


train_kwargs = {
    'result_path': "examples/example_results", # "example_results/lightning_logs", # not in use??
    'exp_name': "debug_mnist", # must include oct or retina
    'load_ckpt_file' : 'version_18/checkpoints/epoch=4-unpruned=192-val_f1=0.12.ckpt', # "version_0/checkpoints/epoch=94-unpruned=1600-val_f1=0.67.ckpt", # 'version_94/checkpoints/epoch=26-step=1080.ckpt', # change this for loading a file and using "test", if you want training, keep None
    'epochs': 50, # including the pretrain epochs - no adding up
    'img_size' : 28, #168, # keep mnist at original size, training didn't work when i increased the size ... # MNIST/MedMNIST 28 × 28 Pixel
    'batch_size': 2, # 128, # the higher the batch_size the faster the training - every iteration adds A LOT OF comp cost
    'log_every_n_steps' : 4, # lightning default: 50 # needs to be bigger than the amount of steps in an epoch (based on trainset size and batchsize)
    'device': "cuda:0" if torch.cuda.is_available() else "cpu",
    'num_workers' : 0, # 18, # 18 for computer, 0 for laptop
    'train_size' : (2 * 50), # total or percentage
    'val_size' : (2 * 50), # total or percentage
    'test_size' : 10, # total or percentage - 0 for all
}

print("train kwargs", train_kwargs)
print("model kwargs", model_kwargs)

kwargs = {'train_kwargs':train_kwargs, 'model_kwargs':model_kwargs}


train kwargs {'result_path': 'examples/example_results', 'exp_name': 'debug_mnist', 'load_ckpt_file': 'version_18/checkpoints/epoch=4-unpruned=192-val_f1=0.12.ckpt', 'epochs': 50, 'img_size': 28, 'batch_size': 2, 'log_every_n_steps': 4, 'device': 'cuda:0', 'num_workers': 0, 'train_size': 100, 'val_size': 100, 'test_size': 10}
model kwargs {'in_channels': 1, 'n_classes': None, 'criterion': CrossEntropyLoss(), 'optimizer': 'sgd', 'base_lr': 0.001, 'min_lr': 1e-05, 'momentum': 0.9, 'lr_update': 100, 'cnn_in_channels': 1, 'cnn_out_channels': 256, 'cnn_kernel_size': 9, 'pc_num_capsules': 8, 'pc_in_channels': 256, 'pc_out_channels': 32, 'pc_kernel_size': 9, 'pc_num_routes': 1152, 'dc_num_capsules': 10, 'dc_num_routes': 1152, 'dc_in_channels': 8, 'dc_out_channels': 16, 'input_width': 28, 'input_height': 28}


# Data

In [4]:
if 'oct' in train_kwargs['exp_name']:
    # OCTMINST
    data = DataOCTMNIST(train_kwargs, model_kwargs)
        
elif 'retina' in train_kwargs['exp_name']:
    # RetinaMNIST
    data = DataRetinaMNIST(train_kwargs, model_kwargs)
    
else:
    # MNIST
    data = DataMNIST(train_kwargs, model_kwargs)
    
print(model_kwargs['n_classes'])

MNIST classes: ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
10


# Run

## Lightning - should become lightning later on

In [5]:
def train(model, optimizer, train_loader, epoch):
    capsule_net = model
    capsule_net.train()
    n_batch = len(list(enumerate(train_loader)))
    total_loss = 0
    for batch_id, (data, target) in enumerate(tqdm(train_loader)):
        
        # one hot
        try:
            # print(target)
            # tensor([6, 5])
            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
        except: 
            # print(target.squeeze(1))
            # tensor([6, 5])
            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target.squeeze(1))
        
        # print(target)
        # tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        # [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])
        
        data, target = Variable(data), Variable(target)

        data, target = data.to(train_kwargs["device"]), target.to(train_kwargs["device"])

        optimizer.zero_grad()
        output, reconstructions, masked = capsule_net(data)
        
        path = f"examples/example_results/capsnet/rec_train_ep{epoch}_cl{np.argmax(target.data.cpu().numpy(), 1)[0]}.png"
        if (epoch % 5) == 0 and not(os.path.isfile(path)):
            torchvision.utils.save_image(reconstructions[0], path)
        elif epoch == 1:
            path = f"examples/example_results/capsnet/rec_train_ep{epoch}_id{batch_id}_cl{np.argmax(target.data.cpu().numpy(), 1)[0]}.png"
            torchvision.utils.save_image(reconstructions[0], path)
            
        
        loss = capsule_net.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()
        correct = sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))
        train_loss = loss.item()
        total_loss += train_loss
        if batch_id % 100 == 0:
            tqdm.write("Epoch: [{}/{}], Batch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}".format(
                epoch,
                train_kwargs["epochs"],
                batch_id + 1,
                n_batch,
                correct / float(train_kwargs["batch_size"]),
                train_loss / float(train_kwargs["batch_size"])
                ))
    tqdm.write('Epoch: [{}/{}], train loss: {:.6f}'.format(epoch,train_kwargs["epochs"],total_loss / len(train_loader.dataset)))


def val(capsule_net, val_dataloader, epoch):
    capsule_net.eval()
    val_loss = 0
    correct = 0
    for batch_id, (data, target) in enumerate(val_dataloader):
    
        # one hot
        try:
            # print(target)
            # tensor([6, 5])
            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
        except: 
            # print(target.squeeze(1))
            # tensor([6, 5])
            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target.squeeze(1))
        
        # print(target)
        # tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        # [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])
        
        data, target = Variable(data), Variable(target)

        data, target = data.to(train_kwargs["device"]), target.to(train_kwargs["device"])

        output, reconstructions, masked = capsule_net(data)
        
        """
        2 = batch size
        1 = channel
        28x28 = img size
        10 = classes
        16 = digit output channels 
        torch.Size([2, 10, 16, 1]) = output
        torch.Size([2, 1, 28, 28]) = reconstruction
        torch.Size([2, 10]) = masked - the thing we want for classification
        """
        
        path = f"examples/example_results/capsnet/rec_val_ep{epoch}_cl{np.argmax(target.data.cpu().numpy(), 1)[0]}.png"
        if (epoch % 5) == 0 and not(os.path.isfile(path)):
            torchvision.utils.save_image(reconstructions[0], path)
        
        #print(output.shape)
        #print(reconstructions.shape)
        #print(masked.shape)
        
        #print(output)
        #print(reconstructions)
        #print(masked)
        
        loss = capsule_net.loss(data, output, target, reconstructions)

        val_loss += loss.item()
        correct += sum(np.argmax(masked.data.cpu().numpy(), 1) ==
                       np.argmax(target.data.cpu().numpy(), 1))

    tqdm.write(
        "Epoch: [{}/{}], val accuracy: {:.6f}, loss: {:.6f}".format(epoch, train_kwargs["epochs"], correct / len(val_dataloader.dataset),
                                                                  val_loss / len(val_dataloader)))


## run training

In [6]:
capsule_net = CapsNet(model_kwargs)
capsule_net = torch.nn.DataParallel(capsule_net)

capsule_net = capsule_net.to(train_kwargs["device"])
capsule_net = capsule_net.module

optimizer = torch.optim.Adam(capsule_net.parameters())

for e in range(1, train_kwargs["epochs"] + 1):
    train(capsule_net, optimizer, data.train_dataloader, e)
    val(capsule_net, data.val_dataloader, e)

in_channels 1


  2%|█▋                                                                                 | 1/50 [00:04<03:35,  4.41s/it]

Epoch: [1/50], Batch: [1/50], train accuracy: 0.000000, loss: 0.449780


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:22<00:00,  2.18it/s]


Epoch: [1/50], train loss: 0.417020
Epoch: [1/50], val accuracy: 0.500000, loss: 0.761273


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.84it/s]

Epoch: [2/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.242838


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.76it/s]


Epoch: [2/50], train loss: 0.326513
Epoch: [2/50], val accuracy: 0.540000, loss: 0.679470


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.82it/s]

Epoch: [3/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.313528


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.81it/s]


Epoch: [3/50], train loss: 0.232758
Epoch: [3/50], val accuracy: 0.570000, loss: 0.633437


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [4/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.136800


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [4/50], train loss: 0.150281
Epoch: [4/50], val accuracy: 0.620000, loss: 0.593660


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [5/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.065787


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.74it/s]


Epoch: [5/50], train loss: 0.084486
Epoch: [5/50], val accuracy: 0.590000, loss: 0.559380


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.74it/s]

Epoch: [6/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.039936


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.75it/s]


Epoch: [6/50], train loss: 0.051477
Epoch: [6/50], val accuracy: 0.600000, loss: 0.552253


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.82it/s]

Epoch: [7/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.039036


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [7/50], train loss: 0.029501
Epoch: [7/50], val accuracy: 0.620000, loss: 0.540436


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.82it/s]

Epoch: [8/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.019381


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [8/50], train loss: 0.018175
Epoch: [8/50], val accuracy: 0.670000, loss: 0.530858


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [9/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000225


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.81it/s]


Epoch: [9/50], train loss: 0.010293
Epoch: [9/50], val accuracy: 0.620000, loss: 0.547327


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [10/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.011777


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.76it/s]


Epoch: [10/50], train loss: 0.007750
Epoch: [10/50], val accuracy: 0.640000, loss: 0.518179


  2%|█▋                                                                                 | 1/50 [00:00<00:18,  2.67it/s]

Epoch: [11/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.001508


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.74it/s]


Epoch: [11/50], train loss: 0.006882
Epoch: [11/50], val accuracy: 0.620000, loss: 0.524268


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.74it/s]

Epoch: [12/50], Batch: [1/50], train accuracy: 0.000000, loss: 0.000209


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.76it/s]


Epoch: [12/50], train loss: 0.005208
Epoch: [12/50], val accuracy: 0.650000, loss: 0.553735


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [13/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000683


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.81it/s]


Epoch: [13/50], train loss: 0.002482
Epoch: [13/50], val accuracy: 0.680000, loss: 0.540387


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.82it/s]

Epoch: [14/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.011073


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [14/50], train loss: 0.003756
Epoch: [14/50], val accuracy: 0.670000, loss: 0.519422


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.80it/s]

Epoch: [15/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000197


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.73it/s]


Epoch: [15/50], train loss: 0.002070
Epoch: [15/50], val accuracy: 0.640000, loss: 0.528255


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.76it/s]

Epoch: [16/50], Batch: [1/50], train accuracy: 0.500000, loss: 0.000175


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.81it/s]


Epoch: [16/50], train loss: 0.002243
Epoch: [16/50], val accuracy: 0.650000, loss: 0.548680


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [17/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000166


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.81it/s]


Epoch: [17/50], train loss: 0.002290
Epoch: [17/50], val accuracy: 0.640000, loss: 0.527791


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.82it/s]

Epoch: [18/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000237


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [18/50], train loss: 0.004297
Epoch: [18/50], val accuracy: 0.640000, loss: 0.536498


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.80it/s]

Epoch: [19/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.006943


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [19/50], train loss: 0.004044
Epoch: [19/50], val accuracy: 0.680000, loss: 0.529609


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.82it/s]

Epoch: [20/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000234


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.77it/s]


Epoch: [20/50], train loss: 0.001268
Epoch: [20/50], val accuracy: 0.640000, loss: 0.523622


  2%|█▋                                                                                 | 1/50 [00:00<00:18,  2.68it/s]

Epoch: [21/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.003315


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.77it/s]


Epoch: [21/50], train loss: 0.002465
Epoch: [21/50], val accuracy: 0.660000, loss: 0.514838


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.83it/s]

Epoch: [22/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.001955


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [22/50], train loss: 0.002757
Epoch: [22/50], val accuracy: 0.670000, loss: 0.533114


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.87it/s]

Epoch: [23/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000173


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [23/50], train loss: 0.001824
Epoch: [23/50], val accuracy: 0.680000, loss: 0.517482


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.80it/s]

Epoch: [24/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000212


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [24/50], train loss: 0.002166
Epoch: [24/50], val accuracy: 0.630000, loss: 0.551674


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.74it/s]

Epoch: [25/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000240


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.73it/s]


Epoch: [25/50], train loss: 0.001715
Epoch: [25/50], val accuracy: 0.650000, loss: 0.539761


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.75it/s]

Epoch: [26/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000176


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.81it/s]


Epoch: [26/50], train loss: 0.001627
Epoch: [26/50], val accuracy: 0.690000, loss: 0.527488


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [27/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000226


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.81it/s]


Epoch: [27/50], train loss: 0.001220
Epoch: [27/50], val accuracy: 0.650000, loss: 0.529452


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.83it/s]

Epoch: [28/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000214


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [28/50], train loss: 0.001085
Epoch: [28/50], val accuracy: 0.660000, loss: 0.534564


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [29/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.003471


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [21:41<00:00, 26.02s/it]


Epoch: [29/50], train loss: 0.001000
Epoch: [29/50], val accuracy: 0.640000, loss: 0.536816


  2%|█▋                                                                                 | 1/50 [00:00<00:23,  2.09it/s]

Epoch: [30/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000178


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.65it/s]


Epoch: [30/50], train loss: 0.001977
Epoch: [30/50], val accuracy: 0.640000, loss: 0.555683


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.75it/s]

Epoch: [31/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000169


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.78it/s]


Epoch: [31/50], train loss: 0.000479
Epoch: [31/50], val accuracy: 0.650000, loss: 0.515539


  2%|█▋                                                                                 | 1/50 [00:00<00:20,  2.34it/s]

Epoch: [32/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.019824


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.70it/s]


Epoch: [32/50], train loss: 0.000731
Epoch: [32/50], val accuracy: 0.680000, loss: 0.518116


  2%|█▋                                                                                 | 1/50 [00:00<00:19,  2.49it/s]

Epoch: [33/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000191


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.68it/s]


Epoch: [33/50], train loss: 0.000898
Epoch: [33/50], val accuracy: 0.680000, loss: 0.509952


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.74it/s]

Epoch: [34/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000162


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.75it/s]


Epoch: [34/50], train loss: 0.001079
Epoch: [34/50], val accuracy: 0.650000, loss: 0.528951


  2%|█▋                                                                                 | 1/50 [00:00<00:18,  2.61it/s]

Epoch: [35/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000221


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.74it/s]


Epoch: [35/50], train loss: 0.001097
Epoch: [35/50], val accuracy: 0.640000, loss: 0.534211


  2%|█▋                                                                                 | 1/50 [00:00<00:18,  2.67it/s]

Epoch: [36/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000187


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.79it/s]


Epoch: [36/50], train loss: 0.001277
Epoch: [36/50], val accuracy: 0.670000, loss: 0.511347


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [37/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000141


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.79it/s]


Epoch: [37/50], train loss: 0.002579
Epoch: [37/50], val accuracy: 0.620000, loss: 0.550007


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [38/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000156


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.80it/s]


Epoch: [38/50], train loss: 0.001052
Epoch: [38/50], val accuracy: 0.650000, loss: 0.538540


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [39/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.008193


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.80it/s]


Epoch: [39/50], train loss: 0.001301
Epoch: [39/50], val accuracy: 0.670000, loss: 0.531057


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.75it/s]

Epoch: [40/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000157


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.74it/s]


Epoch: [40/50], train loss: 0.000919
Epoch: [40/50], val accuracy: 0.640000, loss: 0.525094


  2%|█▋                                                                                 | 1/50 [00:00<00:18,  2.71it/s]

Epoch: [41/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000163


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.79it/s]


Epoch: [41/50], train loss: 0.001231
Epoch: [41/50], val accuracy: 0.640000, loss: 0.549130


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.74it/s]

Epoch: [42/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000159


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.80it/s]


Epoch: [42/50], train loss: 0.001821
Epoch: [42/50], val accuracy: 0.670000, loss: 0.506703


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.86it/s]

Epoch: [43/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000131


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00,  2.77it/s]


Epoch: [43/50], train loss: 0.000230
Epoch: [43/50], val accuracy: 0.680000, loss: 0.487465


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.75it/s]

Epoch: [44/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000111


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.79it/s]


Epoch: [44/50], train loss: 0.000703
Epoch: [44/50], val accuracy: 0.680000, loss: 0.515782


  2%|█▋                                                                                 | 1/50 [00:00<00:25,  1.91it/s]

Epoch: [45/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000232


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:19<00:00,  2.60it/s]


Epoch: [45/50], train loss: 0.002645
Epoch: [45/50], val accuracy: 0.710000, loss: 0.538119


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.79it/s]

Epoch: [46/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.018793


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.79it/s]


Epoch: [46/50], train loss: 0.000698
Epoch: [46/50], val accuracy: 0.680000, loss: 0.509403


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.77it/s]

Epoch: [47/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000192


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.80it/s]


Epoch: [47/50], train loss: 0.000894
Epoch: [47/50], val accuracy: 0.710000, loss: 0.521909


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.86it/s]

Epoch: [48/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000129


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.78it/s]


Epoch: [48/50], train loss: 0.000721
Epoch: [48/50], val accuracy: 0.670000, loss: 0.528994


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.81it/s]

Epoch: [49/50], Batch: [1/50], train accuracy: 0.500000, loss: 0.000220


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.82it/s]


Epoch: [49/50], train loss: 0.001196
Epoch: [49/50], val accuracy: 0.620000, loss: 0.501107


  2%|█▋                                                                                 | 1/50 [00:00<00:17,  2.82it/s]

Epoch: [50/50], Batch: [1/50], train accuracy: 1.000000, loss: 0.000139


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [00:19<00:00,  2.55it/s]


Epoch: [50/50], train loss: 0.001031
Epoch: [50/50], val accuracy: 0.670000, loss: 0.523572
