In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
!pip install einops

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m575.4 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [3]:
from einops import rearrange, repeat,reduce
from einops.layers.torch import Rearrange
from torch import linalg as LA

In [4]:
import torch
import time
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F

from torch import linalg as LA
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# Wall-clock time (HHMMSS) captured once, when this cell is executed.
loc_time = time.strftime("%H%M%S", time.localtime()) 
# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Channel divisor read by ConvUnit: its output channels are in_channels // ratio.
ratio = 8
# c_outs= 128
# Channel-reduction factor that Conv_CBAM forwards to ChannelAttention.
reduction = 1 # default:1
# Default schedule position for squash(); Primary_Caps and Digits_Caps copy
# this value into their own `n_iter` attribute when they are constructed.
n_iter = 0

def squash(inputs,ep_iter=n_iter,total=100, eps=1e-12):
    """Squash capsule vectors along dim 2 so their length lies in [0, 1).

    Computes ``(|s|^2 / (beta + |s|^2)) * (s / |s|)`` where ``|s|`` is the
    L2 norm over ``dim=2`` and ``beta = 1.45 - 0.22 * (ep_iter / total)``.

    Args:
        inputs: tensor with at least three dimensions; the norm is taken
            over dimension 2.
        ep_iter: schedule position used to compute ``beta``.
        total: schedule length used to compute ``beta``.
        eps: small constant added under the square root so that an
            all-zero capsule yields zeros (and finite gradients) instead
            of the NaN produced by ``0 / 0``.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    beta = 1.45 - 0.22 * (ep_iter / total)
    mag_sq = torch.sum(inputs ** 2, dim=2, keepdim=True)
    # sqrt(0) has an infinite derivative and inputs / 0 is NaN; eps avoids both.
    mag = torch.sqrt(mag_sq + eps)
    return (mag_sq / (beta + mag_sq)) * (inputs / mag)


class CapsNet(nn.Module):
    """Capsule network: conv stem -> adaptive pool -> CBAM -> primary caps -> class caps.

    Args:
        conv_inputs: number of channels of the input image.
        num_classes: number of output (class) capsules.
        init_weights: when True, run ``_initialize_weights`` after construction.
        conv_outputs: channel count produced by the stem convolution.
        primary_units: number of primary capsule units.
        primary_unit_size: flattened size of one primary unit's output; it is
            passed to ``Digits_Caps`` as ``in_channels``.
        output_unit_size: length of each class capsule before it is reshaped.
    """

    def __init__(self,conv_inputs,
                 num_classes=7,
                 init_weights=False,
                 conv_outputs = 128,
                 primary_units = 8,#8,
                 primary_unit_size = 576,# 16 * 6 * 6,
                 output_unit_size = 16,):
        super().__init__()

        # Stem: 21x21 stride-2 convolution followed by BatchNorm and ReLU.
        self.Convolution = nn.Sequential(nn.Conv2d(conv_inputs, conv_outputs, 21,stride=2),
                                        nn.BatchNorm2d(conv_outputs),
                                        nn.ReLU(inplace=True),)

        # self.Pool = nn.FractionalMaxPool2d(3, output_size=(20))
        # Fix the spatial size to 20x20 regardless of the input resolution.
        self.Pool = nn.AdaptiveMaxPool2d(20)
        #Attention
        self.CBAM = Conv_CBAM(conv_outputs,conv_outputs)
        #Capsule
        self.primary = Primary_Caps(in_channels=conv_outputs,#128
                                    caps_units=primary_units,#8
                                    )

        self.digits = Digits_Caps(in_units=primary_units,#8
                                   in_channels=primary_unit_size,#16*6*6=576
                                   num_units=num_classes,#classification_num
                                   unit_size=output_unit_size,#16
                                   )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run the full pipeline and return the class capsules from ``Digits_Caps``."""
        x = self.Convolution(x)
        x = self.Pool(x)
        x = self.CBAM(x)
        out = self.digits(self.primary(x))
        return out

    def _initialize_weights(self):
        """Kaiming-normal init for every Conv2d, N(0, 0.01) for every Linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):# or isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 0, 0.01)
                # nn.init.constant_(m.bias, 0)

    #margin_loss
    def loss(self,img_input, target, epoch=1,epoch_total=100, size_average=True):
        """Margin loss on the spectral norm of each class capsule.

        Args:
            img_input: network output; the matrix 2-norm is taken over
                dimensions (2, 3), so it must have four dimensions.
            target: one-hot labels broadcastable to ``(batch, num_classes)``.
            epoch, epoch_total: accepted for interface compatibility; the
                schedule that used them is commented out, so they are unused.
            size_average: return the batch mean when True, else one value
                per sample.

        Returns:
            Scalar tensor (``size_average=True``) or a ``(batch,)`` tensor.
        """
        batch_size = img_input.size(0)

        # ord=2 over a pair of dims is the largest singular value.
        v_mag = LA.norm(img_input, ord=2, dim=(2, 3), keepdim=True)
        m_plus = 0.9
        m_minus = 0.1
        # clamp(min=0) is max(., 0) without allocating a zero tensor on a
        # globally chosen device, so the loss follows the inputs' device.
        max_l = torch.clamp(m_plus - v_mag, min=0).view(batch_size, -1) ** 2
        max_r = torch.clamp(v_mag - m_minus, min=0).view(batch_size, -1) ** 2

        loss_lambda = 0.5
        T_c = target
        L_c = T_c * max_l + loss_lambda * (1.0 - T_c) * max_r
        L_c = torch.sum(L_c, 1)

        if size_average:
            L_c = torch.mean(L_c)

        return L_c

    def update_n_iter(self, ep_iter):
        """Advance the squash schedule of both capsule layers.

        Only acts when ``ep_iter`` is above 100 and a multiple of 10; the
        stored value is ``ep_iter - 100``. Prints the resulting beta.
        """
        if ep_iter > 100 and ep_iter % 10 == 0:
            ep_iter -=100
            self.primary.n_iter = ep_iter
            self.digits.n_iter = ep_iter
            beta = 1.45 - 0.22* (ep_iter/100)
            print(f"beta:{beta}")
        
        
class Primary_Caps(nn.Module):
    """Primary capsule layer: ``caps_units`` parallel ConvUnit branches, no routing.

    Each branch is registered as a submodule named ``Caps_<index>`` and is
    also kept in the plain list ``self.units`` for iteration in ``forward``.
    """

    def __init__(self, in_channels, caps_units):
        super().__init__()
        self.n_iter = n_iter
        self.in_channels = in_channels
        self.caps_units = caps_units

        self.units = []
        for unit_idx in range(self.caps_units):
            branch = ConvUnit(in_channels=in_channels)
            # Registration under an explicit name keeps the parameters
            # visible to the optimizer and to state_dict().
            self.add_module("Caps_" + str(unit_idx), branch)
            self.units.append(branch)

    #no_routing
    def forward(self, x):
        """Apply every branch to ``x``, flatten each result and squash.

        Returns a tensor of shape (batch, caps_units, flattened_branch_output).
        """
        n_samples = x.size(0)
        branch_outputs = [self.units[idx](x) for idx in range(self.caps_units)]
        # (batch, unit, channels, height, width)
        stacked = torch.stack(branch_outputs, dim=1)
        # (batch, unit, channels * height * width)
        flattened = stacked.view(n_samples, self.caps_units, -1)
        return squash(flattened, self.n_iter)
    
class Digits_Caps(nn.Module):
    """Class capsule layer with three rounds of agreement routing.

    Args:
        in_units: number of primary capsule units (last axis of ``W``).
        in_channels: flattened feature count of each primary unit.
        num_units: number of output capsules (classes).
        unit_size: length of each output capsule before reshaping.
    """

    def __init__(self, in_units, in_channels, num_units, unit_size):
        super(Digits_Caps, self).__init__()
        self.n_iter = n_iter
        self.in_units = in_units
        self.in_channels = in_channels
        self.num_units = num_units

        # Shape (1, in_channels, num_units, unit_size, in_units),
        # e.g. [1, 576, 7, 16, 8] with the defaults used by CapsNet.
        self.W = nn.Parameter(torch.randn(1, in_channels, self.num_units, unit_size, in_units))

    #routing
    def forward(self, x):
        """Route primary capsules to class capsules.

        Args:
            x: tensor of shape (batch, in_units, in_channels).

        Returns:
            Tensor of shape (batch, num_units, 4, unit_size // 4): each
            class capsule's ``unit_size`` axis is split into 4 groups.
        """
        # (batch, in_units, features) -> (batch, features, 1, in_units, 1).
        # The singleton axes broadcast against W, so neither x nor W has to
        # be physically tiled per class / per sample.
        x = x.transpose(1, 2).unsqueeze(2).unsqueeze(4)
        # (batch, features, num_units, unit_size, 1)
        u_hat = torch.matmul(self.W, x)
        # Routing logits start at zero, on the same device/dtype as the data
        # rather than on a module-level device global.
        b_ij = u_hat.new_zeros(1, self.in_channels, self.num_units, 1)

        num_iterations = 3
        for iteration in range(num_iterations):
            # Softmax over the feature axis -> (1, features, num_units, 1, 1).
            c_ij = b_ij.softmax(dim=1).unsqueeze(4)

            # Weighted sum over features -> (batch, 1, num_units, unit_size, 1).
            s_j = torch.sum(c_ij * u_hat, dim=1, keepdim=True)

            # NOTE(review): squash() normalises over dim 2, which is the
            # num_units axis of s_j here - confirm this is intended.
            v_j = squash(s_j, self.n_iter)

            # Agreement u_hat^T v_j, averaged over the batch:
            # (batch, features, num_units, 1, 1) -> (1, features, num_units, 1).
            u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j).squeeze(4).mean(dim=0, keepdim=True)

            # NOTE(review): the logits are replaced by the agreement, not
            # accumulated (b_ij + u_vj1); kept exactly as in the original.
            b_ij = u_vj1

        # return v_j.squeeze(1)
        return rearrange(v_j.squeeze(1), 'b c (g h) w -> b c g (h w)',g=4)
                
class ConvUnit(nn.Module):
    """One primary-capsule branch: a 9x1 convolution followed by a grouped 1x9 stride-2 convolution.

    The second convolution reduces the channel count to ``in_channels // ratio``.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced_channels = in_channels // ratio  # 16 when in_channels is 128 and ratio is 8
        # nn.Conv2d(in_channels,Caps_out,(9,9),stride=2,groups=Caps_out, bias=False),
        column_conv = nn.Conv2d(in_channels, in_channels, kernel_size=(9, 1),
                                stride=1, bias=False)
        row_conv = nn.Conv2d(in_channels, reduced_channels, kernel_size=(1, 9),
                             stride=2, groups=reduced_channels, bias=False)
        # Attribute name (including its spelling) is part of the state_dict keys.
        self.Cpas = nn.Sequential(column_conv, row_conv)

    def forward(self, x):
        """Return the branch output for ``x``."""
        return self.Cpas(x)

class Conv_CBAM(nn.Module):
    """Conv -> BatchNorm -> activation, then channel and spatial attention.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding; ``None`` lets ``autopad`` derive it from ``k``.
        g: convolution groups.
        act: use Hardswish when True, otherwise no activation.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, kernel_size=k, stride=s,
                              padding=autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act:
            self.act = nn.Hardswish()
        else:
            self.act = nn.Identity()
        self.ca = ChannelAttention(c2, reduction=reduction)
        self.sa = SpatialAttention()

    def forward(self, x):
        """Return the feature map re-weighted by channel and then spatial attention."""
        features = self.conv(x)
        features = self.bn(features)
        features = self.act(features)
        features = self.ca(features) * features
        features = self.sa(features) * features
        return features
    
def autopad(k, p=None):  # kernel, padding
    """Return ``p`` if given, otherwise half the kernel size.

    For an ``int`` kernel the result is ``k // 2``; for any other iterable
    kernel it is a list with each side halved (integer division).
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [side // 2 for side in k]

# SAM
class SpatialAttention(nn.Module):
    """Spatial attention map: sigmoid(conv([channel-mean, channel-max])).

    ``kernel_size`` must be 3 or 7; the padding is chosen so the spatial
    size of the input is preserved.
    """

    def __init__(self, kernel_size=3): # default:3
        super().__init__()

        assert kernel_size in (3, 7)
        same_padding = {3: 1, 7: 3}[kernel_size]

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (batch, 1, H, W) attention map with values in (0, 1)."""
        # This is different from the paper[S. Woo, et al. "CBAM: Convolutional Block Attention Module,"].
        # Collapse the channel axis twice: once by mean, once by max.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(descriptor))

# CAM
class ChannelAttention(nn.Module):
    """Channel attention: a shared two-layer 1x1-conv MLP over avg- and max-pooled features.

    Args:
        channels: number of input (and output) channels.
        reduction: the hidden width is ``channels // reduction``.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden_channels = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1   = nn.Conv2d(channels, hidden_channels, 1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2   = nn.Conv2d(hidden_channels, channels, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel weights of shape (batch, channels, 1, 1) in (0, 1)."""
        def shared_mlp(pooled):
            # Same fc1 -> ReLU -> fc2 weights for both pooled descriptors.
            return self.fc2(self.relu1(self.fc1(pooled)))

        from_avg = shared_mlp(self.avg_pool(x))
        from_max = shared_mlp(self.max_pool(x))
        return self.sigmoid(from_avg + from_max)

In [5]:
!pip install thop

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238


In [6]:
import torch
import numpy as np
import seaborn as sns
import prettytable
import matplotlib.pyplot as plt
from thop.profile import profile
import time, random

offset = 16 #draw_size_acc
class ImageShow(object):
    """Holds the per-epoch metric histories and plots / summarises them.

    The seven lists are stored by reference, so values appended to them
    later are visible to this object.
    """

    def __init__(self,
                train_loss_list,train_acc_list,
                test_loss_list,test_acc_list,
                test_auc_list,
                val_loss_list,val_acc_list):
        self.trainll = train_loss_list
        self.trainacl = train_acc_list
        self.testll = test_loss_list
        self.testacl = test_acc_list
        self.testauc = test_auc_list
        self.valll = val_loss_list
        self.valacl = val_acc_list

    def train(self,opt="Loss",write=True,custom_path=None,img_title=None,suf=None):
        """Plot the training accuracy (``opt='Acc'``) or loss (``opt='Loss'``) curve."""
        shared = dict(opt=opt, write=write, img_title=img_title, suf=suf)
        if opt == 'Acc':
            img_portray(dates=self.trainacl,
                        label='Train_Acc', col='red',
                        **shared)
        elif opt == 'Loss':
            img_portray(dates=self.trainll,
                        linestyle="--",
                        label='Train_Loss', col='green',
                        **shared)
        if write:
            save_images(img_title=img_title, suf=suf, opt=opt)
        plt.show()

    def test(self,opt='Acc',write=True,custom_path=None,img_title=None,suf=None,**kwargs):
        """Plot the test accuracy (``opt='Acc'``) or loss (``opt='Loss'``) curve."""
        shared = dict(opt=opt, write=write, img_title=img_title, suf=suf)
        if opt == 'Acc':
            img_portray(dates=self.testacl,
                        label='Test_Acc', col='red',
                        xlabel="Batch_Size",
                        **shared)
        elif opt == 'Loss':
            img_portray(dates=self.testll,
                        linestyle="-.",
                        label='Test_Loss', col='green',
                        **shared)
        if write:
            save_images(split='test', img_title=img_title, suf=suf, opt=opt)
        plt.show()

    def val(self,opt='Acc',write=True,custom_path=None,img_title=None,suf=None):
        """Plot the validation accuracy (``opt='Acc'``) or loss (``opt='Loss'``) curve."""
        shared = dict(opt=opt, write=write, img_title=img_title, suf=suf)
        if opt == 'Acc':
            img_portray(dates=self.valacl,
                        linestyle="dotted", col='red',
                        label='Val_Acc',
                        **shared)
        elif opt == 'Loss':
            img_portray(dates=self.valll,
                        linestyle="-.", col='green',
                        label='Val_Loss',
                        **shared)
        if write:
            save_images(split='Val', img_title=img_title, suf=suf, opt=opt)
        plt.show()

    def conclusion(self,opt="test",img_title=None):
        """Print the best epoch for the chosen metric.

        ``opt`` selects the history that is maximised: ``"test"`` (test
        accuracy), ``"val"`` (validation accuracy) or ``"auc"`` (test AUC).
        Nothing is printed when the selected history is empty or ``opt`` is
        not one of those values. Epochs are reported 1-based.
        """
        if opt == "test" and len(self.testacl) != 0:
            print(f'\033[31m=================Conclusion====================\033[0m')
            top = self.testacl.index(max(self.testacl))
            print(f"Dataset:[\033[1;31m{img_title}\033[0m]")
            print(f"Best_Epoch [\033[1;31m{top + 1}\033[0m]")
            print(f"[Test] \033[31mACC:{round(float(self.testacl[top]),2)}%\033[0m.")

        if opt == "val" and len(self.valacl) != 0:
            print(f'\033[31m=================Conclusion====================\033[0m')
            top = self.valacl.index(max(self.valacl))
            print(f"Dataset:[\033[1;31m{img_title}\033[0m]")
            print(f"Best_Epoch [\033[1;31m{top + 1}\033[0m]")
            print(f"[Val] \033[31mACC:{round(float(self.valacl[top]),2)}%\033[0m.")

        if opt == "auc" and len(self.testauc) != 0:
            print(f'\033[31m=================Conclusion====================\033[0m')
            top = self.testauc.index(max(self.testauc))
            epoch_no = top + 1
            print(f"Dataset:[\033[1;31m{img_title}\033[0m]")
            # The training loss is looked up at the same (0-based) position.
            print(f"Best_Epoch [\033[1;31m{epoch_no}\033[0m]\n[Train] loss:{self.trainll[epoch_no-1]};")
            print(f"[Test] Loss:{self.testll[top]}, \033[32mACC:{round(float(self.testacl[top]),2)}%\033[0m.")
            print(f"[Test]:\033[32m AUC:{round(float(self.testauc[top]),2)}%\033[0m.")
            

def draw_size_acc(data_dict,custom_path='./tmp',write=True,fn='Batch_Size',img_title=None,suf=None):
    """Plot accuracy against a swept setting and mark the best point.

    Args:
        data_dict: mapping ``x value -> accuracy``; plotted in ascending key order.
        custom_path: root directory for the saved figure.
        write: save the figure to ``{custom_path}/{img_title}/{suf}/{fn}.png`` when True.
        fn: x-axis label and output file stem.
        img_title, suf: sub-directory components of the output path.

    Raises:
        ValueError: if ``data_dict`` is empty.
    """
    if not data_dict:
        raise ValueError("data_dict must contain at least one (x, accuracy) pair")

    # Sort once instead of re-sorting the whole dict twice per element.
    ordered = sorted(data_dict.items(), key=lambda item: item[0])
    sx = [key for key, _ in ordered]
    sy = [value for _, value in ordered]

    best = int(np.argmax(sy))
    # Put the marker at the x value that actually produced the maximum; a
    # fixed "index + offset" is only right for consecutive integer keys.
    best_x = sx[best]
    show_max = round(float(sy[best]),3)
    plt.plot(best_x,show_max ,'8')
    plt.annotate(show_max,xy=(best_x,show_max),xytext=(best_x,show_max))

    try:
        plt.style.use("seaborn-paper")
    except OSError:
        # Newer matplotlib releases ship this style only under the v0_8 name.
        plt.style.use("seaborn-v0_8-paper")
    plt.plot(sx, sy,label="Test_Data")
    plt.ylabel("Accuracy")
    plt.xlabel(fn)
    plt.legend(loc="best")
    if write:
        plt.savefig(f'{custom_path}/{img_title}/{suf}/{fn}.png',dpi=300)
    plt.show()
    
    
def one_hot(x, length):
    """Return a float32 CPU tensor of shape (batch, length) with ones at the label positions.

    Args:
        x: integer label tensor whose first dimension is the batch; any
            further dimensions list additional positions to set per row.
        length: number of classes (width of the result).

    Raises:
        IndexError: if a label lies outside ``[-length, length)``.
    """
    batch_size = x.size(0)
    x_one_hot = torch.zeros(batch_size, length)
    if batch_size == 0:
        # reshape(0, -1) is ambiguous; an empty batch has nothing to set.
        return x_one_hot
    columns = x.detach().reshape(batch_size, -1).to(device="cpu", dtype=torch.long)
    rows = torch.arange(batch_size).unsqueeze(1)
    # One vectorised assignment instead of a Python loop over the batch.
    x_one_hot[rows, columns] = 1.0
    return x_one_hot


def setup_seed(seed):
    """Seed the PyTorch (CPU and all CUDA devices), NumPy and Python RNGs.

    Also puts cuDNN into deterministic mode and disables its autotuner,
    because benchmark mode may select different algorithms between runs.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    
def confusion_matrix(evl_result,n_cla,cla_dict,data,img_title=None,suf=None):
    """Draw ``evl_result`` as an annotated heat map and save it as a PNG.

    The file is written to ``./tmp/{img_title}/{suf}/Confusion_Matrix_{kn}.png``
    where ``kn`` is ``'test'`` if the matrix total equals ``len(data)`` and
    ``'val'`` otherwise.
    """
    plt.figure(figsize=(12,9))

    sns.heatmap(evl_result,
                cmap="Blues",
                annot=True,
                cbar=True,
                fmt="g",
                annot_kws={"size": 20})

    # Centre each class label on its heat-map cell.
    tick_centres = [cls_idx + 0.5 for cls_idx in range(n_cla)]
    plt.yticks(tick_centres, cla_dict.values(), fontsize=16)
    plt.xticks(tick_centres, cla_dict.values(), fontsize=16)

    plt.title("Confusion Matrix",fontsize=24)
    # The colour bar is the last axes seaborn added to the figure.
    colorbar_axes = plt.gcf().axes[-1]
    colorbar_axes.tick_params(labelsize=12)

    # vname = lambda v,nms: [ vn for vn in nms if id(v)==id(nms[vn])][0]
    # kn = vname(evl_result,locals())
    kn = 'test' if evl_result.sum().item() == len(data) else 'val'

    plt.savefig(f"./tmp/{img_title}/{suf}/Confusion_Matrix_{kn}.png",dpi=300)
    
    
def pff(m_name,model,inputes):
    """Print parameter count (M), FLOPs (G) and single-pass FPS for ``model``.

    Args:
        m_name: label printed in the first column.
        model: module to profile; it is left in eval mode.
        inputes: example input passed to both the profiler and the timed run.

    NOTE(review): FPS comes from one un-warmed forward pass, so it is noisy.
    """
    print("%s | %s | %s | %s" % ("  Model  ", "Params(M)", "FLOPs(G)","FPS"))
    print("----------|-----------|----------|-----")

    total_ops, total_params = profile(model, (inputes,), verbose=False)
    model.eval()
    # synchronize() raises on hosts without CUDA, so only call it when usable.
    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        model(inputes)
        if use_cuda:
            torch.cuda.synchronize()
        end = time.perf_counter()
    # perf_counter is monotonic and high resolution; the floor only guards
    # against a zero interval on very coarse clocks.
    single_fps = 1 / max(end - start, 1e-12)

    print(
        "%s |    %.2f   |   %.2f   | %.1f" % (m_name, total_params / (1000 ** 2),
                                    total_ops / (1000 ** 3),
                                    single_fps)
        )
    
    
def metrics_scores(evl_result,n_classes,cla_dict):
    """Print per-class precision / recall / F1 and overall accuracy.

    Args:
        evl_result: square confusion-matrix tensor. Column sums are used as
            the precision denominator and row sums as the recall denominator.
        n_classes: number of classes (rows/columns) to report.
        cla_dict: mapping ``class index -> class name`` for the first column.

    The final row holds the unweighted (macro) averages and the accuracy.
    """
    eps = 1e-12
    P,R,F = 0,0,0
    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Type','Precision', 'Recall', 'F1','Accuracy']
    accuracy = float(torch.sum(evl_result.diagonal()) / (torch.sum(evl_result) + eps))

    # Loop-invariant sums, computed once.
    column_totals = torch.sum(evl_result, 0)
    row_totals = torch.sum(evl_result, 1)

    for i in range(n_classes):
        hits = evl_result[i][i]
        # eps belongs in the denominator: added to the quotient it cannot
        # prevent 0/0 = NaN for a class with no predictions or no samples.
        pre = float(hits / (column_totals[i] + eps))
        P += pre
        recall = float(hits / (row_totals[i] + eps))
        R += recall
        F1 = pre * recall * 2 / (pre + recall + eps)
        F += F1
        result_table.add_row([cla_dict[i], round(pre, 4), round(recall, 4), round(F1, 4)," "])
    P_avg,R_avg,F_avg = P/n_classes,R/n_classes,F/n_classes
    result_table.add_row(["Total:",round(P_avg, 4),round(R_avg, 4),round(F_avg, 4),round(accuracy,4)])
    print(result_table)


In [7]:
!pip install torch-summary
import torch
import sys, os
import json
import torch.nn as nn  
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim

# from torch.utils.tensorboard import SummaryWriter
import prettytable
import time, random,timeit
sys.setrecursionlimit(15000)
from thop.profile import profile

from PIL import Image
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary
from tqdm.notebook import tqdm
import seaborn as sns

# Seed the RNGs once, before any dataset, model or optimizer is created.
setup_seed(3047)


Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl.metadata (18 kB)
Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5


In [8]:
# Make modules in the parent directory importable.
sys.path.append(os.pardir)
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
# Dataset label used in log lines and output paths.
img_title = "HAM10000"#"Skin_Cancer"
# Running bests, updated through `global` by test() and train():
# best_acc follows the 'val' split, eval_acc the test split,
# best_train the training accuracy.
best_acc = 0.
eval_acc = 0.
best_train = 0.
dict_batch = {}
dict_imgSize = {}


# Re-running this cell in a live kernel keeps the existing histories (and
# prints the current length); on a fresh kernel the NameError path creates them.
try:
    print(len(train_acc_list))
except NameError:
    train_loss_list = []
    train_acc_list = []
    test_loss_list = []
    test_acc_list = []
    test_auc_list = []
    val_loss_list = []
    val_acc_list = []
#activate ImageShow
# ImageShow keeps references to these lists, so later appends are visible to `show`.
show = ImageShow(train_loss_list = train_loss_list,
                 train_acc_list = train_acc_list,
                test_loss_list = test_loss_list,
                test_acc_list = test_acc_list,
                test_auc_list = test_auc_list,
                val_loss_list = val_loss_list,
                val_acc_list = val_acc_list,
                )

In [9]:
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                  std=[0.229, 0.224, 0.225])
# Mean 0.5 / std 0.5 on each of the three channels; used by every
# transform pipeline built in get_data().
normalize = transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
# Resize = transforms.Resize((299,299))

def get_data(mode='ALL'):
    """Build the datasets/loaders and publish them as module-level globals.

    With ``mode='ALL'`` the train, validation and test ImageFolder datasets
    and their loaders are created; with any other value only the test
    dataset and loader are. In both cases ``cla_dict`` (index -> class name)
    and ``n_classes`` are derived from the test dataset.

    Reads the globals ``train_dir``, ``val_dir``, ``test_dir``, ``BatchSize``,
    ``V_size``, ``T_size``, ``pin_memory``, ``nw`` and ``normalize``.
    """
    global test_dataset,train_loader,val_loader,test_loader
    global train_num,val_num,test_num,n_classes,cla_dict
    # vt = int(trans)

    def to_normalized_tensor(*augmentations):
        # Optional augmentations first, then ToTensor + normalisation.
        return transforms.Compose([*augmentations,
                                   transforms.ToTensor(),
                                   normalize])

    data_transform = {
        "train": to_normalized_tensor(transforms.RandomVerticalFlip()),
        "val": to_normalized_tensor(),
        "test": to_normalized_tensor(),
    }

    def build_loader(dataset, batch_size, shuffle):
        return DataLoader(dataset, batch_size=batch_size,
                          pin_memory=pin_memory,
                          shuffle=shuffle, num_workers=nw)

    if mode != 'ALL':
        test_dataset = datasets.ImageFolder(root=test_dir,transform=data_transform["test"])
        test_num = len(test_dataset)
        test_loader = build_loader(test_dataset, T_size, shuffle=False)
        print(f"using {test_num} images for testing.")
    else:
        train_dataset = datasets.ImageFolder(root=train_dir,transform=data_transform["train"])
        val_dataset = datasets.ImageFolder(root=val_dir,transform=data_transform["val"])
        test_dataset = datasets.ImageFolder(root=test_dir,transform=data_transform["test"])

        train_num, val_num, test_num = len(train_dataset), len(val_dataset), len(test_dataset)

        # Only the training loader shuffles.
        train_loader = build_loader(train_dataset, BatchSize, shuffle=True)
        val_loader = build_loader(val_dataset, V_size, shuffle=False)
        test_loader = build_loader(test_dataset, T_size, shuffle=False)

        print("using {} images for training, {} images for validation, {} images for testing.".format(train_num,
                                                                                                      val_num,
                                                                                                      test_num))

    data_list = test_dataset.class_to_idx
    # Invert name -> index into index -> name.
    cla_dict = {idx: name for name, idx in data_list.items()}
    n_classes  = len(data_list)

In [10]:

# Batch sizes for the train / validation / test loaders built by get_data().
BatchSize = 168
V_size = 21
T_size = 21


# NOTE(review): absolute Kaggle input paths - this cell only runs where that
# dataset is mounted at the same location.
train_dir='/kaggle/input/ham10000-augmented/ham10000_augmented/train525e384'
val_dir='/kaggle/input/ham10000-augmented/ham10000_augmented/val525e384png'
test_dir='/kaggle/input/ham10000-augmented/ham10000_augmented/test525png384'
pin_memory = True
# Number of DataLoader worker processes (fixed; the adaptive formula is kept commented out).
nw = 6#min([os.cpu_count(), BatchSize if BatchSize > 1 else 0, 6]) 
print(f'Using {nw} dataloader workers every process.')
# Populates the loader / count / class globals (n_classes is read on the next line).
get_data()
print(f'Using {n_classes} classes.')

Using 6 dataloader workers every process.
using 51646 images for training, 1006 images for validation, 828 images for testing.
Using 7 classes.




In [20]:
n_channels = 3 #RGB

# Build the model with one output capsule per dataset class and run the
# custom weight initialisation, then move it to the selected device.
network = CapsNet(conv_inputs=n_channels, 
                     num_classes=n_classes,# category_number
                     init_weights=True,)
network = network.to(device)

In [21]:
# NOTE(review): 0.123 is far above Adam's default of 1e-3 - confirm it is intentional.
learning_rate = 0.123
optimizer = optim.Adam(network.parameters(), lr=learning_rate)
# T_max=5: train() steps the scheduler once per epoch, so the learning rate
# follows a cosine between `learning_rate` and eta_min with a 5-epoch half period.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, 5, eta_min=1e-8, last_epoch=-1)

In [17]:
def train(epoch):
    """Train ``network`` for one epoch and record loss / accuracy.

    Appends the epoch loss and accuracy (Python floats) to
    ``train_loss_list`` / ``train_acc_list``, steps the LR scheduler once,
    and when the epoch accuracy beats ``best_train`` stores a copy of the
    epoch's confusion matrix in ``train_evl_result``.

    Args:
        epoch: epoch number, used only in log lines.
    """
    network.train()
    global best_train,train_evl_result#,evl_tmp_result
    running_loss = 0.
    correct_seen, samples_seen = 0, 0
    steps_num = len(train_loader)
    # At least 1, so a single-batch loader does not cause "modulo by zero".
    print_step = max(1, steps_num // 2)
    print(f'\033[1;32m[Train Epoch:[{epoch}]{img_title} ==> Training]\033[0m ...')
    optimizer.zero_grad()
    # Rows are true classes, columns are predicted classes.
    train_tmp_result = torch.zeros(n_classes,n_classes)
    startT = timeit.default_timer()

    for batch_idx, (data, target) in enumerate(train_loader, start=1):
        target_indices = target
        target_one_hot = one_hot(target, length=n_classes)
        data, target = data.to(device), target_one_hot.to(device)

        output = network(data)
        loss = network.loss(output, target, size_average=True)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()

        # Prediction = class capsule with the largest nuclear norm.
        # NOTE(review): the loss uses the spectral norm (ord=2) while the
        # prediction uses ord='nuc' - confirm the mismatch is intended.
        v_mag = LA.norm(output.detach(),ord='nuc',dim=(2,3), keepdim=True)
        # view(-1) keeps a 1-D result even for a batch of one; squeeze()
        # would collapse it to a scalar.
        pred = v_mag.max(1, keepdim=True)[1].cpu().view(-1)
        correct_seen += int(pred.eq(target_indices.view_as(pred)).sum())
        samples_seen += target_indices.size(0)
        tmp_pre = correct_seen / samples_seen

        if batch_idx % print_step == 0 and batch_idx != steps_num:
            print("[{}/{}] Loss{:.5f},ACC:{:.5f}".format(batch_idx,steps_num,
                                                         loss.item(),tmp_pre))

        # Iterate over the samples actually in this batch instead of
        # deriving the last batch's size from remainder arithmetic.
        for true_cls, pred_cls in zip(target_indices.tolist(), pred.tolist()):
            train_tmp_result[true_cls][pred_cls] += 1

        #if best_train < tmp_pre and tmp_pre >= 90:
        #    torch.save(network.state_dict(), iter_path)

    epoch_acc = correct_seen / train_num
    epoch_loss = running_loss / steps_num
    train_loss_list.append(epoch_loss)
    train_acc_list.append(epoch_acc)
    scheduler.step()
    if best_train < epoch_acc:
        best_train = epoch_acc
        train_evl_result = train_tmp_result.clone()
        #torch.save(network.state_dict(), last_path)
        #torch.save(train_evl_result, f'./tmp/{img_title}/{dirs}/train_evl_result.pth')

    endT = timeit.default_timer()
    run_time = endT-startT
    print("Train Epoch:[{}] Running:[{:.2f}s], Loss:{:.5f},Acc:{:.5f},Best_train:{:.5f}".format(epoch,run_time,
                                                                                                 epoch_loss,
                                                                                                 epoch_acc,best_train))

In [18]:
def test(split="test"):
    """Evaluate ``network`` on the validation or test loader.

    Args:
        split: ``'val'`` uses ``val_loader``; any other value uses ``test_loader``.

    Side effects: sets the global ``test_acc`` (percent), appends it to
    ``val_acc_list`` or ``test_acc_list``, and when it is at least the best
    so far updates ``best_acc`` / ``val_evl_result`` (val) or ``eval_acc`` /
    ``test_evl_result`` (test) with a copy of the confusion matrix.
    """
    network.eval()
    global test_acc,eval_acc,best_acc,net_parameters
    global test_evl_result,val_evl_result#,evl_tmp_result
    # Rows are true classes, columns are predicted classes.
    evl_tmp_result = torch.zeros(n_classes,n_classes)

    data_loader = val_loader if split == 'val' else test_loader

    print(f'\033[35m{img_title} ==> {split} ...\033[0m')

    with torch.no_grad():
        for data, target in data_loader:
            # Labels stay on the CPU; only the images go to the device. The
            # one-hot targets were never used here, so they are not built.
            output = network(data.to(device))
            # Prediction = class capsule with the largest nuclear norm.
            v_mag = LA.norm(output,ord='nuc',dim=(2,3), keepdim=True)
            pred = v_mag.max(1, keepdim=True)[1].cpu().view(-1)

            # Count exactly the samples in this batch.
            for true_cls, pred_cls in zip(target.tolist(), pred.tolist()):
                evl_tmp_result[true_cls][pred_cls] += 1

    diag_sum = torch.sum(evl_tmp_result.diagonal())
    all_sum = torch.sum(evl_tmp_result)
    test_acc = 100. * float(torch.div(diag_sum,all_sum))
    print(f"{split}_Acc:\033[1;32m{round(float(test_acc),3)}%\033[0m")

    if split == 'val':
        val_acc_list.append(test_acc)
        if test_acc >= best_acc:
            best_acc = test_acc
            val_evl_result = evl_tmp_result.clone()#copy.deepcopy(input)
        print(f"Best_val:\033[1;32m[{round(float(best_acc),3)}%]\033[0m")
    else:
        test_acc_list.append(test_acc)
        if test_acc >= eval_acc:
            eval_acc = test_acc
            test_evl_result = evl_tmp_result.clone()#copy.deepcopy(input)
        print(f"Best_eval:\033[1;32m[{round(float(eval_acc),3)}%]\033[0m")

In [15]:
# Upper bound on training epochs for the loop below.
num_epochs= 128

In [22]:
# One training pass followed by a validation pass per epoch; the test split
# is not evaluated in this loop.
# NOTE(review): network.update_n_iter() is never called here, so the squash
# schedule position stays at its initial value - confirm that is intended.
for epoch in range(1, num_epochs + 1): 
    train(epoch)
    test('val')
    
print('Finished Training')

[1;32m[Train Epoch:[1]HAM10000 ==> Training][0m ...
[154/308] Loss0.30685,ACC:0.30492
Train Epoch:[1] Running:[457.36s], Loss:0.35631,Acc:0.41637,Best_train:0.41637
[35mHAM10000 ==> val ...[0m
val_Acc:[1;32m74.652%[0m
Best_val:[1;32m[74.652%][0m
[1;32m[Train Epoch:[2]HAM10000 ==> Training][0m ...
[154/308] Loss0.26527,ACC:0.59427
Train Epoch:[2] Running:[454.32s], Loss:0.26089,Acc:0.60857,Best_train:0.60857
[35mHAM10000 ==> val ...[0m
val_Acc:[1;32m68.787%[0m
Best_val:[1;32m[74.652%][0m
[1;32m[Train Epoch:[3]HAM10000 ==> Training][0m ...
[154/308] Loss0.28835,ACC:0.66068
Train Epoch:[3] Running:[454.41s], Loss:0.22544,Acc:0.66789,Best_train:0.66789
[35mHAM10000 ==> val ...[0m
val_Acc:[1;32m81.71%[0m
Best_val:[1;32m[81.71%][0m
[1;32m[Train Epoch:[4]HAM10000 ==> Training][0m ...
[154/308] Loss0.20080,ACC:0.70485
Train Epoch:[4] Running:[455.09s], Loss:0.20069,Acc:0.70941,Best_train:0.70941
[35mHAM10000 ==> val ...[0m
val_Acc:[1;32m84.791%[0m
Best_val:[1;32m

KeyboardInterrupt: 

In [24]:
# Saves the current in-memory weights (not a tracked best checkpoint).
# NOTE(review): the file name says 78 epochs, but the recorded output above
# stops with a KeyboardInterrupt during epoch 4 - confirm which run this is.
torch.save(network.state_dict(), 'gpu_78_epochs_capsnet.pth')