# CapsNet

Source: https://github.com/jindongwang/Pytorch-CapsuleNet/blob/master/test_capsnet.py

In [1]:
import random  # stdlib RNG seeded below; imported before the sys.path tweak so the stdlib module is the one resolved
import sys
sys.path.insert(0, "helper")

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from tqdm import tqdm

# =============================================================================
# datasceyence
# =============================================================================
from helper.visualisation.feature_map import *
from helper.model.capsnet import CapsNet

from helper.data.mnist import DataMNIST
from helper.data.retinamnist import DataRetinaMNIST
from helper.data.octmnist import DataOCTMNIST

#train_kwargs["device"] = True if torch.cuda.is_available() else False
#train_kwargs["batch_size"] = 5
# train_kwargs["epochs"] = 30
#LEARNING_RATE = 0.01
#MOMENTUM = 0.9



In [2]:
seed = 1997 # was 19 before

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

print("cuda available:", torch.cuda.is_available())

debug_model = False

print('torch 2.0.0 ==', torch.__version__=='2.0.0' , '->', torch.__version__)

cuda available: True
torch 2.0.0 == False -> 1.13.1


# Settings

In [3]:
model_kwargs = { # for 1 Channel
    'in_channels' : 1,
    'n_classes': None, # filled in the dataset
    'criterion': torch.nn.CrossEntropyLoss(),# torch.nn.BCEWithLogitsLoss(),
    'optimizer': "sgd", # sgd adamw
    'base_lr': 0.001,
    'min_lr' : 0.00001,
    'momentum' : 0.9,
    'lr_update' : 100,
    
    # CNN (cnn)
    'cnn_in_channels' : 1, # duplicate ...
    'cnn_out_channels' : 256,
    'cnn_kernel_size' : 9,

    # Primary Capsule (pc)
    'pc_num_capsules' : 8,
    'pc_in_channels' : 256,
    'pc_out_channels' : 32,
    'pc_kernel_size' : 9,
    'pc_num_routes' : 32 * 6 * 6,

    # Digit Capsule (dc)
    'dc_num_capsules' : 10,
    'dc_num_routes' : 32 * 6 * 6,
    'dc_in_channels' : 8,
    'dc_out_channels' : 16,

    # Decoder
    'input_width' : 28,
    'input_height' : 28,
}


train_kwargs = {
    'result_path': "examples/example_results", # "example_results/lightning_logs", # not in use??
    'exp_name': "debug_oct_no_fc", # must include oct or retina
    'load_ckpt_file' : 'version_18/checkpoints/epoch=4-unpruned=192-val_f1=0.12.ckpt', # "version_0/checkpoints/epoch=94-unpruned=1600-val_f1=0.67.ckpt", # 'version_94/checkpoints/epoch=26-step=1080.ckpt', # change this for loading a file and using "test", if you want training, keep None
    'epochs': 5, # including the pretrain epochs - no adding up
    'img_size' : 28, #168, # keep mnist at original size, training didn't work when i increased the size ... # MNIST/MedMNIST 28 × 28 Pixel
    'batch_size': 2, # 128, # the higher the batch_size the faster the training - every iteration adds A LOT OF comp cost
    'log_every_n_steps' : 4, # lightning default: 50 # needs to be bigger than the amount of steps in an epoch (based on trainset size and batchsize)
    'device': "cuda",
    'num_workers' : 0, # 18, # 18 for computer, 0 for laptop
    'train_size' : (2 * 10), # total or percentage
    'val_size' : (2 * 10), # total or percentage
    'test_size' : 10, # total or percentage - 0 for all
}

print("train kwargs", train_kwargs)
print("model kwargs", model_kwargs)

kwargs = {'train_kwargs':train_kwargs, 'model_kwargs':model_kwargs}


train kwargs {'result_path': 'examples/example_results', 'exp_name': 'debug_oct_no_fc', 'load_ckpt_file': 'version_18/checkpoints/epoch=4-unpruned=192-val_f1=0.12.ckpt', 'epochs': 5, 'img_size': 28, 'batch_size': 2, 'log_every_n_steps': 4, 'device': 'cuda', 'num_workers': 0, 'train_size': 20, 'val_size': 20, 'test_size': 10}
model kwargs {'in_channels': 1, 'n_classes': None, 'criterion': CrossEntropyLoss(), 'optimizer': 'sgd', 'base_lr': 0.001, 'min_lr': 1e-05, 'momentum': 0.9, 'lr_update': 100, 'cnn_in_channels': 1, 'cnn_out_channels': 256, 'cnn_kernel_size': 9, 'pc_num_capsules': 8, 'pc_in_channels': 256, 'pc_out_channels': 32, 'pc_kernel_size': 9, 'pc_num_routes': 1152, 'dc_num_capsules': 10, 'dc_num_routes': 1152, 'dc_in_channels': 8, 'dc_out_channels': 16, 'input_width': 28, 'input_height': 28}


# Dataloaders

In [4]:
data = DataOCTMNIST(train_kwargs, model_kwargs)

Using downloaded and verified file: C:\Users\Prinzessin\.medmnist\octmnist.npz
Using downloaded and verified file: C:\Users\Prinzessin\.medmnist\octmnist.npz
Using downloaded and verified file: C:\Users\Prinzessin\.medmnist\octmnist.npz
python_class : OCTMNIST
description : The OCTMNIST is based on a prior dataset of 109,309 valid optical coherence tomography (OCT) images for retinal diseases. The dataset is comprised of 4 diagnosis categories, leading to a multi-class classification task. We split the source training set with a ratio of 9:1 into training and validation set, and use its source validation set as the test set. The source images are gray-scale, and their sizes are (384−1,536)×(277−512). We center-crop the images and resize them into 1×28×28.
url : https://zenodo.org/record/6496656/files/octmnist.npz?download=1
MD5 : c68d92d5b585d8d81f7112f81e2d0842
task : multi-class
label : {'0': 'choroidal neovascularization', '1': 'diabetic macular edema', '2': 'drusen', '3': 'normal'}

# Run

## Training loop (plain PyTorch for now; to be ported to Lightning later)

In [5]:
def train(model, optimizer, train_loader, epoch):
    capsule_net = model
    capsule_net.train()
    n_batch = len(list(enumerate(train_loader)))
    total_loss = 0
    for batch_id, (data, target) in enumerate(tqdm(train_loader)):
        
        # one hot
        try:
            # print(target)
            # tensor([6, 5])
            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
        except: 
            # print(target.squeeze(1))
            # tensor([6, 5])
            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target.squeeze(1))
        
        # print(target)
        # tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        # [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])
        
        data, target = Variable(data), Variable(target)

        data, target = data.cuda(), target.cuda()

        optimizer.zero_grad()
        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()
        correct = sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))
        train_loss = loss.item()
        total_loss += train_loss
        if batch_id % 100 == 0:
            tqdm.write("Epoch: [{}/{}], Batch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}".format(
                epoch,
                train_kwargs["epochs"],
                batch_id + 1,
                n_batch,
                correct / float(train_kwargs["batch_size"]),
                train_loss / float(train_kwargs["batch_size"])
                ))
    tqdm.write('Epoch: [{}/{}], train loss: {:.6f}'.format(epoch,train_kwargs["epochs"],total_loss / len(train_loader.dataset)))


def val(capsule_net, val_dataloader, epoch):
    capsule_net.eval()
    val_loss = 0
    correct = 0
    for batch_id, (data, target) in enumerate(val_dataloader):
    
        # one hot
        try:
            # print(target)
            # tensor([6, 5])
            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
        except: 
            # print(target.squeeze(1))
            # tensor([6, 5])
            target = torch.sparse.torch.eye(10).index_select(dim=0, index=target.squeeze(1))
        
        # print(target)
        # tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        # [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])
        
        data, target = Variable(data), Variable(target)

        data, target = data.cuda(), target.cuda()

        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)

        val_loss += loss.item()
        correct += sum(np.argmax(masked.data.cpu().numpy(), 1) ==
                       np.argmax(target.data.cpu().numpy(), 1))

    tqdm.write(
        "Epoch: [{}/{}], val accuracy: {:.6f}, loss: {:.6f}".format(epoch, train_kwargs["epochs"], correct / len(val_dataloader.dataset),
                                                                  val_loss / len(val_dataloader)))


## Run training

In [6]:
capsule_net = CapsNet(model_kwargs)
capsule_net = torch.nn.DataParallel(capsule_net)

capsule_net = capsule_net.cuda()
capsule_net = capsule_net.module

optimizer = torch.optim.Adam(capsule_net.parameters())

for e in range(1, train_kwargs["epochs"] + 1):
    train(capsule_net, optimizer, data.train_dataloader, e)
    val(capsule_net, data.val_dataloader, e)

in_channels 1


 10%|████████▎                                                                          | 1/10 [00:09<01:25,  9.52s/it]

Epoch: [1/5], Batch: [1/10], train accuracy: 0.000000, loss: 0.450048


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:13<00:00,  1.30s/it]


Epoch: [1/5], train loss: 0.442028
Epoch: [1/5], val accuracy: 0.350000, loss: 0.811564


 10%|████████▎                                                                          | 1/10 [00:00<00:03,  2.59it/s]

Epoch: [2/5], Batch: [1/10], train accuracy: 0.500000, loss: 0.403726


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.74it/s]


Epoch: [2/5], train loss: 0.391593
Epoch: [2/5], val accuracy: 0.350000, loss: 0.777948


 10%|████████▎                                                                          | 1/10 [00:00<00:03,  2.79it/s]

Epoch: [3/5], Batch: [1/10], train accuracy: 1.000000, loss: 0.369839


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.69it/s]


Epoch: [3/5], train loss: 0.346521
Epoch: [3/5], val accuracy: 0.350000, loss: 0.771700


 10%|████████▎                                                                          | 1/10 [00:00<00:03,  2.77it/s]

Epoch: [4/5], Batch: [1/10], train accuracy: 1.000000, loss: 0.207192


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.74it/s]


Epoch: [4/5], train loss: 0.294726
Epoch: [4/5], val accuracy: 0.350000, loss: 0.733908


 10%|████████▎                                                                          | 1/10 [00:00<00:03,  2.76it/s]

Epoch: [5/5], Batch: [1/10], train accuracy: 0.500000, loss: 0.206109


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.76it/s]


Epoch: [5/5], train loss: 0.224895
Epoch: [5/5], val accuracy: 0.350000, loss: 0.723684
