In [None]:
import torch
import sys, os
import json
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

import torch.nn.functional as F
import torch.optim as optim

import prettytable
import time
sys.setrecursionlimit(15000)
from thop.profile import profile

from PIL import Image
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary
from tqdm import tqdm
import seaborn as sns

from utils import ImageShow,draw_size_acc,one_hot
from utils import confusion_matrix,metrics_scores
from model import FixCapsNet

def get_data(trans_size=308,
             train_root='D:/ACSII_proyecto/FixCaps-main/augmentation/train525s8',
             val_root='D:/ACSII_proyecto/FixCaps-main/augmentation/val525s8',
             test_root='D:/ACSII_proyecto/FixCaps-main/augmentation/test_dir'):
    """Build the train/val/test ImageFolder datasets and their DataLoaders.

    Results are published through module-level globals (``test_dataset``,
    ``train_loader``, ``val_loader``, ``test_loader``, ``train_num``,
    ``val_num``, ``test_num``, ``n_classes``, ``cla_dict``) because the
    ``train``/``test`` functions read them from there. Also reads the
    globals ``img_title``, ``BatchSize``, ``V_size``, ``T_size`` and ``nw``.

    Args:
        trans_size: side length the *test* images are resized to before the
            299x299 center crop. Accepts an int or a numeric string; it is
            converted with ``int()`` because ``transforms.Resize`` expects
            integer sizes (the previous default was the string ``'308'``).
        train_root: ImageFolder root of the training images.
        val_root: ImageFolder root of the validation images.
        test_root: ImageFolder root of the test images.

    Side effects:
        Writes the {class index: class name} mapping to ``f'{img_title}.json'``
        in the current working directory.
    """
    global test_dataset, train_loader, val_loader, test_loader, train_num, val_num, test_num, n_classes, cla_dict
    trans_size = int(trans_size)
    # Normalize is stateless, so one instance can be shared by all pipelines.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop((299, 299)),
                                     transforms.RandomVerticalFlip(),
                                     transforms.ToTensor(),
                                     normalize]),
        "val": transforms.Compose([transforms.Resize((308, 308)),
                                   transforms.CenterCrop((299, 299)),
                                   transforms.ToTensor(),
                                   normalize]),
        "test": transforms.Compose([transforms.Resize((trans_size, trans_size)),
                                    transforms.CenterCrop((299, 299)),
                                    transforms.ToTensor(),
                                    normalize]),
    }

    train_dataset = datasets.ImageFolder(root=train_root, transform=data_transform["train"])
    val_dataset = datasets.ImageFolder(root=val_root, transform=data_transform["val"])
    test_dataset = datasets.ImageFolder(root=test_root, transform=data_transform["test"])

    train_num = len(train_dataset)
    val_num = len(val_dataset)
    test_num = len(test_dataset)

    # Class indices come from the training set's folder names.
    data_list = train_dataset.class_to_idx
    cla_dict = {idx: name for name, idx in data_list.items()}
    n_classes = len(data_list)
    print(f'Using {n_classes} classes.')
    # Persist the index -> class-name mapping (json turns the int keys into strings).
    with open(f'{img_title}.json', 'w', encoding='utf-8') as json_file:
        json.dump(cla_dict, json_file, indent=4)

    pin_memory = True
    train_loader = DataLoader(train_dataset, batch_size=BatchSize,
                              pin_memory=pin_memory,
                              shuffle=True, num_workers=nw)
    val_loader = DataLoader(val_dataset, batch_size=V_size,
                            pin_memory=pin_memory,
                            shuffle=False, num_workers=nw)
    test_loader = DataLoader(test_dataset, batch_size=T_size,
                             pin_memory=pin_memory,
                             shuffle=False, num_workers=nw)

    print("using {} images for training, {} images for validation, {} images for testing.".format(train_num,
                                                                                                  val_num,
                                                                                                  test_num))
def train(epoch, iter_save_acc=0.80):
    """Run one training epoch over the global ``train_loader``.

    Relies on module-level globals prepared in the ``__main__`` section:
    ``network``, ``optimizer``, ``scheduler``, ``device``, ``train_loader``,
    ``train_num``, ``n_classes``, ``img_title``, ``suf``, ``iter_path``,
    ``last_path`` and the ``train_loss_list`` / ``train_acc_list`` histories.

    Args:
        epoch: epoch number, used only for logging.
        iter_save_acc: running training accuracy (a fraction in [0, 1]) at or
            above which an intermediate checkpoint is written to
            ``iter_path``. The running accuracy must also exceed the best epoch
            accuracy recorded so far.

    Side effects:
        Updates the globals ``best_train`` (fraction in [0, 1]) and
        ``train_evl_result`` (confusion matrix of the best epoch). Appends the
        epoch loss/accuracy to the history lists and steps ``scheduler`` once.
        May save checkpoints and the confusion matrix to disk.
    """
    global best_train, train_evl_result
    network.train()
    running_loss = 0.
    r_pre = 0   # correct predictions so far in this epoch
    seen = 0    # samples processed so far in this epoch
    steps_num = len(train_loader)
    # max(1, ...) prevents a modulo-by-zero when the loader has a single batch.
    print_step = max(1, steps_num // 2)
    print(f'\033[1;32m[Train Epoch:[{epoch}]{img_title} ==> Training]\033[0m ...')
    optimizer.zero_grad()
    # Confusion matrix for this epoch: rows = true class, columns = predicted class.
    train_tmp_result = torch.zeros(n_classes, n_classes)

    for batch_idx, (data, target) in enumerate(tqdm(train_loader), start=1):
        target_indices = target  # integer class labels, kept on the CPU
        target_one_hot = one_hot(target, length=n_classes)
        data, target = data.to(device), target_one_hot.to(device)

        output = network(data)
        loss = network.loss(output, target, size_average=True)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        running_loss += loss.item()

        # Predicted class = capsule whose output vector (summed over dim 2) is longest.
        v_mag = torch.sqrt(torch.sum(output ** 2, dim=2, keepdim=True))
        pred = v_mag.detach().max(1, keepdim=True)[1].cpu().view(-1)
        r_pre += int(pred.eq(target_indices.view_as(pred)).sum())
        seen += pred.size(0)
        # Running accuracy over the samples actually seen (the last batch may be short).
        tmp_pre = r_pre / seen

        if batch_idx % print_step == 0:
            print("[{}/{}] Loss{:.5f},ACC:{:.5f}".format(batch_idx, steps_num,
                                                         loss.item(), tmp_pre))

        # Vectorized confusion-matrix update over the real batch size, so a
        # short final batch needs no special casing.
        train_tmp_result.index_put_((target_indices, pred),
                                    torch.ones(pred.size(0)), accumulate=True)

        # tmp_pre is a fraction in [0, 1], so the threshold must be a fraction
        # too (the former ``>= 80`` could never be satisfied).
        if best_train < tmp_pre and tmp_pre >= iter_save_acc:
            torch.save(network.state_dict(), iter_path)

    epoch_acc = r_pre / train_num
    epoch_loss = running_loss / steps_num
    train_loss_list.append(epoch_loss)
    train_acc_list.append(epoch_acc)
    scheduler.step()
    if best_train < epoch_acc:
        best_train = epoch_acc
        train_evl_result = train_tmp_result.clone()
        torch.save(network.state_dict(), last_path)
        torch.save(train_evl_result, f'D:/ACSII_proyecto/FixCaps-main/{img_title}/{suf}/train_evl_result.pth')

    print("Train Epoch:[{}] Loss:{:.5f},Acc:{:.5f},Best_train:{:.5f}".format(epoch, epoch_loss,
                                                                     epoch_acc, best_train))

def test(split="test"):
    """Evaluate ``network`` on the validation (``split='val'``) or test split.

    Builds an ``n_classes x n_classes`` confusion matrix (rows = true class,
    columns = predicted class) and derives the accuracy in percent from it.

    Reads module-level globals set up elsewhere in this script: ``network``,
    ``device``, ``val_loader``/``test_loader``, ``n_classes``, ``img_title``,
    ``suf``, ``save_PATH`` and the ``val_acc_list``/``test_acc_list`` histories.

    Side effects:
        Sets the globals ``test_acc`` (this run's accuracy, percent) and
        ``evl_tmp_result`` (this run's confusion matrix).

        For ``'val'``: appends to ``val_acc_list``. On a new best it updates
        ``best_acc``/``val_evl_result`` and saves the weights to ``save_PATH``
        plus the matrix to disk.

        Otherwise: appends to ``test_acc_list``. On a new best it updates
        ``eval_acc``/``test_evl_result`` and saves weights and matrix.
    """
    global test_acc, eval_acc, best_acc, test_evl_result, val_evl_result, evl_tmp_result
    network.eval()
    data_loader = val_loader if split == 'val' else test_loader
    evl_tmp_result = torch.zeros(n_classes, n_classes)
    print(f'\033[35m{img_title} ==> {split} ...\033[0m')

    with torch.no_grad():
        for data, target in tqdm(data_loader):
            output = network(data.to(device))
            # Predicted class = capsule whose output vector (summed over dim 2) is longest.
            v_mag = torch.sqrt(torch.sum(output ** 2, dim=2, keepdim=True))
            pred = v_mag.max(1, keepdim=True)[1].cpu().view(-1)
            # Vectorized update using the real batch size; the previous
            # last-batch fix-up used test_num even for the 'val' split.
            evl_tmp_result.index_put_((target, pred),
                                      torch.ones(pred.size(0)), accumulate=True)

    diag_sum = torch.sum(evl_tmp_result.diagonal())
    all_sum = torch.sum(evl_tmp_result)
    # Guard against an empty loader (0/0 would give NaN).
    test_acc = 100. * float(diag_sum / all_sum) if all_sum > 0 else 0.
    print(f"{split}_Acc:\033[1;32m{round(float(test_acc),3)}%\033[0m")

    if split == 'val':
        val_acc_list.append(test_acc)
        if test_acc > best_acc:
            best_acc = test_acc
            val_evl_result = evl_tmp_result.clone()
            torch.save(network.state_dict(), save_PATH)
            torch.save(val_evl_result, f'D:/ACSII_proyecto/FixCaps-main/{img_title}/{suf}/best_evl_result.pth')
        print(f"Best_val:\033[1;32m[{round(float(best_acc),3)}%]\033[0m")
    else:
        test_acc_list.append(test_acc)
        if test_acc > eval_acc:
            eval_acc = test_acc
            test_evl_result = evl_tmp_result.clone()
            torch.save(network.state_dict(), f'D:/ACSII_proyecto/FixCaps-main/{img_title}/{suf}/{split}_best_{img_title}_{suf}.pth')
            torch.save(test_evl_result, f'D:/ACSII_proyecto/FixCaps-main/{img_title}/{suf}/{split}_evl_result.pth')
        print(f"Best_eval:\033[1;32m[{round(float(eval_acc),3)}%]\033[0m")
if __name__ == '__main__':
    sys.path.append(os.pardir)
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    img_title = "HAM10000"  # alternatives: "COVID-19", "ISIC2019", "skin_lesion"
    best_acc = 0.      # best validation accuracy (percent), updated by test('val')
    eval_acc = 0.      # best test accuracy (percent), updated by test()
    best_train = 0.    # best epoch training accuracy (fraction), updated by train()
    # Reuse the history lists if they already exist (e.g. this notebook cell
    # is re-run); otherwise create them.
    try:
        print(len(train_acc_list))
    except NameError:
        train_loss_list = []
        train_acc_list = []
        test_loss_list = []
        test_acc_list = []
        test_auc_list = []
        val_loss_list = []
        val_acc_list = []
    # Plotting/summary helper holding references to the history lists.
    show = ImageShow(train_loss_list=train_loss_list,
                     train_acc_list=train_acc_list,
                     test_loss_list=test_loss_list,
                     test_acc_list=test_acc_list,
                     test_auc_list=test_auc_list,
                     val_loss_list=val_loss_list,
                     val_acc_list=val_acc_list,
                     )

    BatchSize = 128
    V_size = 31   # validation batch size
    T_size = 31   # test batch size
    learning_rate = 0.123
    # os.cpu_count() may return None, hence the "or 1".
    nw = min([os.cpu_count() or 1, BatchSize if BatchSize > 1 else 0, 4])
    print(f'Using {nw} dataloader workers every process.')
    get_data(308)

    # Create capsule network.
    n_channels = 3
    conv_outputs = 128  # feature maps
    num_primary_units = 8
    primary_unit_size = 16 * 6 * 6  # fixme get from conv2d
    output_unit_size = 16
    img_size = 299
    mode = 'DS'
    network = FixCapsNet(conv_inputs=n_channels,
                         conv_outputs=conv_outputs,
                         primary_units=num_primary_units,
                         primary_unit_size=primary_unit_size,
                         num_classes=n_classes,
                         output_unit_size=output_unit_size,
                         init_weights=True,
                         mode=mode)
    network = network.to(device)
    # torchsummary defaults to device="cuda"; pass the actual device so this
    # also works on CPU-only machines.
    summary(network, (n_channels, img_size, img_size), device=device.type)

    print("%s | %s | %s" % ("Model", "Params(M)", "FLOPs(G)"))
    print("---|---|---")
    name = "FixCaps"
    inputs = torch.randn(1, n_channels, img_size, img_size).to(device)
    total_ops, total_params = profile(network, (inputs,), verbose=False)
    print(
        "%s | %.2f | %.2f" % (name, total_params / (1000 ** 2), total_ops / (1000 ** 3))
    )
    # FLOPs(G)--> 0.07(0.08).

    print(network.Convolution)

    # Reuse an existing run suffix (notebook re-run); otherwise timestamp a new one.
    try:
        print(f"suf:{suf}")
    except NameError:
        suf = time.strftime("%m%d_%H%M%S", time.localtime())
        print(f"suf:{suf}")
    save_dir = f'D:/ACSII_proyecto/FixCaps-main/{img_title}/{suf}'
    if os.path.exists(save_dir):
        print(f'Store: "{save_dir}"')
    else:
        # makedirs also creates missing parents (e.g. the per-dataset folder).
        os.makedirs(save_dir)
    iter_path = f'{save_dir}/train_{img_title}_{suf}.pth'
    save_PATH = f'{save_dir}/best_{img_title}_{suf}.pth'
    last_path = f'{save_dir}/last_{img_title}_{suf}.pth'
    print(save_PATH)

    num_epochs = 125
    def_betas = (0.9, 0.999)
    optimizer = optim.Adam(network.parameters(), lr=learning_rate, betas=def_betas)
    # Alternative tried: optim.AdamW(network.parameters(), lr=learning_rate, weight_decay=0.01)
    # Cosine schedule with a period of 5 epochs; train() steps it once per epoch.
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, 5, eta_min=1e-8, last_epoch=-1)

    for epoch in range(1, num_epochs + 1):
        train(epoch)
        test('val')

    print('Finished Training')

    # Reload the best validation checkpoint, then sweep the test batch size (j)
    # and the test resize side length (i), keeping the best test accuracy per value.
    network.load_state_dict(torch.load(save_PATH, map_location=device))
    dict_size = {}
    dict_test = {}
    for j in range(21, 31):
        print(f"size:{j}")
        T_size = j
        for i in range(300, 325):
            get_data(i)
            # Repeated runs: presumably results vary between evaluations
            # (e.g. a stochastic pooling layer) — TODO confirm.
            for _ in range(5):
                test()
                dict_test[i] = max(dict_test.get(i, test_acc), test_acc)
                dict_size[j] = max(dict_size.get(j, test_acc), test_acc)

    show.conclusion(img_title=img_title)
    # Best sweep entries, highest accuracy first.
    print(sorted(dict_size.items(), key=lambda x: x[1], reverse=True)[0:9])
    print(sorted(dict_test.items(), key=lambda x: x[1], reverse=True)[0:9])

    # draw_size_acc(dict_test,custom_path='./tmp',img_title=img_title,suf=suf)

    metrics_scores(test_evl_result, n_classes, cla_dict)

    # Save the accuracy/loss histories next to the checkpoints.
    np.save(f'{save_dir}/{img_title}_train_acc_{suf}.npy', np.array(train_acc_list))
    np.save(f'{save_dir}/{img_title}_train_loss_{suf}.npy', np.array(train_loss_list))
    np.save(f'{save_dir}/{img_title}_val_acc_{suf}.npy', np.array(val_acc_list))
    # Previously saved the validation history here by mistake (s2 instead of s3).
    np.save(f'{save_dir}/{img_title}_test_acc_{suf}.npy', np.array(test_acc_list))

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

0
Using 4 dataloader workers every process.
Using 7 classes.
using 48322 images for training, 5970 images for validation, 828 images for testing.
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 3, 141, 141]             975
            Conv2d-2          [-1, 128, 71, 71]             512
       BatchNorm2d-3          [-1, 128, 71, 71]             256
              ReLU-4          [-1, 128, 71, 71]               0
FractionalMaxPool2d-5          [-1, 128, 20, 20]               0
            Conv2d-6          [-1, 128, 20, 20]          16,384
       BatchNorm2d-7          [-1, 128, 20, 20]             256
         Hardswish-8          [-1, 128, 20, 20]               0
 AdaptiveAvgPool2d-9            [-1, 128, 1, 1]               0
           Conv2d-10            [-1, 128, 1, 1]          16,384
             ReLU-11            [-1, 128, 1, 1]               0
           Conv2d-12

 50%|█████     | 189/378 [20:35<19:12,  6.10s/it]

[189/378] Loss0.36791,ACC:0.28799


100%|██████████| 378/378 [40:07<00:00,  5.22s/it]

[378/378] Loss0.36450,ACC:0.35489


100%|██████████| 378/378 [40:08<00:00,  6.37s/it]


Train Epoch:[1] Loss:0.40855,Acc:0.35535,Best_train:0.35535
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:46<00:00,  1.82it/s]


val_Acc:[1;32m49.129%[0m
Best_val:[1;32m[49.129%][0m
[1;32m[Train Epoch:[2]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [19:16<19:16,  6.12s/it]

[189/378] Loss0.36772,ACC:0.44585


100%|██████████| 378/378 [38:36<00:00,  5.03s/it]

[378/378] Loss0.31728,ACC:0.45974


100%|██████████| 378/378 [38:37<00:00,  6.13s/it]


Train Epoch:[2] Loss:0.34895,Acc:0.46033,Best_train:0.46033
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:31<00:00,  2.11it/s]


val_Acc:[1;32m60.218%[0m
Best_val:[1;32m[60.218%][0m
[1;32m[Train Epoch:[3]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [19:50<19:37,  6.23s/it]

[189/378] Loss0.31724,ACC:0.49111


100%|██████████| 378/378 [38:20<00:00,  4.87s/it]

[378/378] Loss0.33644,ACC:0.49733


100%|██████████| 378/378 [38:22<00:00,  6.09s/it]


Train Epoch:[3] Loss:0.32626,Acc:0.49797,Best_train:0.49797
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:30<00:00,  2.13it/s]


val_Acc:[1;32m65.946%[0m
Best_val:[1;32m[65.946%][0m
[1;32m[Train Epoch:[4]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [18:26<17:46,  5.64s/it]

[189/378] Loss0.35469,ACC:0.51649


100%|██████████| 378/378 [36:48<00:00,  4.84s/it]

[378/378] Loss0.31603,ACC:0.52350


100%|██████████| 378/378 [36:50<00:00,  5.85s/it]


Train Epoch:[4] Loss:0.31225,Acc:0.52417,Best_train:0.52417
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [02:15<00:00,  1.42it/s]


val_Acc:[1;32m68.325%[0m
Best_val:[1;32m[68.325%][0m
[1;32m[Train Epoch:[5]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [19:31<18:21,  5.83s/it]

[189/378] Loss0.28873,ACC:0.53977


100%|██████████| 378/378 [37:49<00:00,  4.75s/it]

[378/378] Loss0.29799,ACC:0.54076


100%|██████████| 378/378 [37:50<00:00,  6.01s/it]


Train Epoch:[5] Loss:0.30128,Acc:0.54145,Best_train:0.54145
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:34<00:00,  2.03it/s]


val_Acc:[1;32m68.074%[0m
Best_val:[1;32m[68.325%][0m
[1;32m[Train Epoch:[6]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [18:41<19:23,  6.16s/it]

[189/378] Loss0.30869,ACC:0.55179


100%|██████████| 378/378 [39:06<00:00,  4.82s/it]

[378/378] Loss0.28152,ACC:0.54727


100%|██████████| 378/378 [39:07<00:00,  6.21s/it]


Train Epoch:[6] Loss:0.29805,Acc:0.54797,Best_train:0.54797
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:37<00:00,  1.98it/s]


val_Acc:[1;32m68.342%[0m
Best_val:[1;32m[68.342%][0m
[1;32m[Train Epoch:[7]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [21:26<26:13,  8.33s/it]

[189/378] Loss0.27012,ACC:0.54229


100%|██████████| 378/378 [45:47<00:00,  4.90s/it]

[378/378] Loss0.31724,ACC:0.54861


100%|██████████| 378/378 [45:49<00:00,  7.27s/it]


Train Epoch:[7] Loss:0.29663,Acc:0.54932,Best_train:0.54932
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:43<00:00,  1.87it/s]


val_Acc:[1;32m69.615%[0m
Best_val:[1;32m[69.615%][0m
[1;32m[Train Epoch:[8]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [28:28<1:59:06, 37.81s/it] 

[189/378] Loss0.31672,ACC:0.54923


100%|██████████| 378/378 [54:55<00:00,  7.71s/it]  

[378/378] Loss0.33027,ACC:0.55068


100%|██████████| 378/378 [54:57<00:00,  8.72s/it]


Train Epoch:[8] Loss:0.29646,Acc:0.55138,Best_train:0.55138
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:53<00:00,  1.71it/s]


val_Acc:[1;32m71.457%[0m
Best_val:[1;32m[71.457%][0m
[1;32m[Train Epoch:[9]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [27:19<26:40,  8.47s/it]

[189/378] Loss0.28970,ACC:0.55576


100%|██████████| 378/378 [46:33<00:00,  4.62s/it]

[378/378] Loss0.27276,ACC:0.55975


100%|██████████| 378/378 [46:34<00:00,  7.39s/it]


Train Epoch:[9] Loss:0.29358,Acc:0.56047,Best_train:0.56047
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:39<00:00,  1.95it/s]


val_Acc:[1;32m70.921%[0m
Best_val:[1;32m[71.457%][0m
[1;32m[Train Epoch:[10]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [20:24<25:53,  8.22s/it]

[189/378] Loss0.34139,ACC:0.56448


100%|██████████| 378/378 [46:40<00:00,  7.55s/it]

[378/378] Loss0.31465,ACC:0.56802


100%|██████████| 378/378 [46:42<00:00,  7.41s/it]


Train Epoch:[10] Loss:0.28865,Acc:0.56875,Best_train:0.56875
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:53<00:00,  1.70it/s]


val_Acc:[1;32m74.439%[0m
Best_val:[1;32m[74.439%][0m
[1;32m[Train Epoch:[11]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [27:16<26:58,  8.57s/it]

[189/378] Loss0.29770,ACC:0.58156


100%|██████████| 378/378 [53:43<00:00,  7.21s/it]

[378/378] Loss0.30516,ACC:0.58468


100%|██████████| 378/378 [53:45<00:00,  8.53s/it]


Train Epoch:[11] Loss:0.27998,Acc:0.58543,Best_train:0.58543
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [02:02<00:00,  1.58it/s]


val_Acc:[1;32m72.982%[0m
Best_val:[1;32m[74.439%][0m
[1;32m[Train Epoch:[12]HAM10000 ==> Training][0m ...


 50%|█████     | 189/378 [26:55<24:48,  7.88s/it]

[189/378] Loss0.28020,ACC:0.59619


100%|██████████| 378/378 [51:38<00:00,  5.76s/it]

[378/378] Loss0.27183,ACC:0.59981


100%|██████████| 378/378 [51:42<00:00,  8.21s/it]


Train Epoch:[12] Loss:0.27209,Acc:0.60058,Best_train:0.60058
[35mHAM10000 ==> val ...[0m


100%|██████████| 193/193 [01:52<00:00,  1.71it/s]


val_Acc:[1;32m68.141%[0m
Best_val:[1;32m[74.439%][0m
[1;32m[Train Epoch:[13]HAM10000 ==> Training][0m ...


 42%|████▏     | 157/378 [17:22<24:27,  6.64s/it]


KeyboardInterrupt: 