In [1]:
!pip install pretrainedmodels
!pip install SimpleITK
!pip install tensorboardX

Collecting pretrainedmodels
  Using cached pretrainedmodels-0.7.4-py3-none-any.whl
Collecting munch (from pretrainedmodels)
  Using cached munch-4.0.0-py2.py3-none-any.whl.metadata (5.9 kB)
Using cached munch-4.0.0-py2.py3-none-any.whl (9.9 kB)
Installing collected packages: munch, pretrainedmodels
Successfully installed munch-4.0.0 pretrainedmodels-0.7.4
Collecting SimpleITK
  Using cached SimpleITK-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Using cached SimpleITK-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)
Installing collected packages: SimpleITK
Successfully installed SimpleITK-2.3.1
Collecting tensorboardX
  Using cached tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)
Using cached tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


In [9]:
import os
from os.path import join
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pretrainedmodels
import os
from os.path import join
from PIL import Image
import torchvision
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import ImageFolder
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import SimpleITK as sitk
import sys
import time
import functools
from os.path import join, exists
from datetime import datetime
import logging
import logging.handlers
import torch
import ssl
from tensorboardX import SummaryWriter

pblog = None


def get_pblog(*args, **kwargs):
    global pblog
    if pblog is None:
        pblog = ProgressBarLog(*args, **kwargs)
    return pblog

def log_decorator(call_fn):
    @functools.wraps(call_fn)
    def log(self, *args, move=False):
        if self.current < self.total:
            sys.stdout.write(' ' * (self.width + 26) + '\r')
            sys.stdout.flush()
        call_fn(self, *args)
        if move:
            self._current += 1
            temp = datetime.now()
            delta = temp - self.last_time
            self.last_time = temp
            temp = temp + delta * (self.total - self.current)
            self.ok_time = str(temp).split('.')[0]
        if self.current < self.total:
            progress = int(self.width * self.current / self.total)
            temp = '{:2}%][{}]\r'.format(int(100 * self.current / self.total),
                                         self.ok_time)
            sys.stdout.write('[' + '=' * progress + '>' + '-' * (
                    self.width - progress - 1) + temp)
            sys.stdout.flush()

    return log


class ProgressBarLog:
    def __init__(self, total=50, width=76, current=0, logger=None):
        self.width = width - 26
        self.total = total
        self._current = current
        if logger is None:
            log_path = join(config['log_dir'], config['cmd'])
            desc = '{}_{}_{}_{}'. \
                format(config['dataset'], config['model'], config['action'], config['desc'])
            self.logger = gen_logger(gen_t_name(log_path, desc, '.log'))
        else:
            self.logger = logger
        self.last_time = datetime.now()
        self.ok_time = None
        self.pb_last_time = time.time()
        self.pb_begin_time = self.pb_last_time

    @property
    def current(self):
        return self._current

    @current.setter
    def current(self, value):
        if not isinstance(value, int):
            raise ValueError
        if value < 0 or value > self.total:
            raise ValueError
        self._current = value

    @log_decorator
    def debug(self, msg):
        self.logger.debug(msg)

    @log_decorator
    def info(self, msg):
        self.logger.info(msg)

    @log_decorator
    def warning(self, msg):
        self.logger.warning(msg)

    @log_decorator
    def error(self, msg):
        self.logger.error(msg)

    @log_decorator
    def exception(self, msg):
        self.logger.exception(msg)

    @log_decorator
    def print(self, *args):
        print(*args)

    @log_decorator
    def refresh(self):
        pass

    def pb(self, current, total, msg):
        TOTAL_BAR_LENGTH = 45.
        # _, term_width = os.popen('stty size', 'r').read().split()
        term_width = 94
        if current == 0:
            self.pb_begin_time = time.time()  # Reset for new bar.

        cur_len = int(TOTAL_BAR_LENGTH * current / total)
        rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

        sys.stdout.write(' [')
        for _ in range(cur_len):
            sys.stdout.write('=')
        sys.stdout.write('>')
        for _ in range(rest_len):
            sys.stdout.write('.')
        sys.stdout.write(']')

        cur_time = time.time()
        self.pb_last_time = cur_time
        tot_time = cur_time - self.pb_begin_time

        L = []
        L.append('Tot: %s' % format_time(tot_time))
        if msg:
            L.append(' | ' + msg)

        msg = ''.join(L)
        sys.stdout.write(msg)
        for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
            sys.stdout.write(' ')

        # Go back to the center of the bar.
        for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2):
            sys.stdout.write('\b')
        sys.stdout.write(' %d/%d ' % (current + 1, total))

        if current < total - 1:
            sys.stdout.write('\r')
        else:
            sys.stdout.write('\n')
        sys.stdout.flush()


def gen_logger(file_path, log_name=None):
    if log_name is None:
        log_name = config['log_name']
    cmd_fmt = '[%(asctime)s] @%(name)s %(levelname)-8s%(message)s'
    cmd_datefmt = '%Y-%m-%d %H:%M:%S'
    formatter = ColoredFormatter(cmd_fmt, cmd_datefmt)
    file_handler = logging.FileHandler(file_path)
    file_handler.formatter = formatter
    console_handler = logging.StreamHandler()
    console_handler.formatter = formatter
    logger = logging.getLogger(log_name)
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    logger.setLevel(logging.DEBUG)

    return logger


def gen_t_name(base_dir, desc, ext):
    check_mkdir(base_dir)
    while True:
        dt = datetime.now()
        temp = join(base_dir, desc + dt.strftime('_%y%m%d_%H%M%S') + ext)
        if exists(temp):
            time.sleep(0.000001)
        else:
            break
    return temp


class ColoredFormatter(logging.Formatter):
    '''A colorful formatter.'''

    def __init__(self, fmt=None, datefmt=None):
        logging.Formatter.__init__(self, fmt, datefmt)

    def format(self, record):
        # Color escape string
        COLOR_RED = '\033[1;31m'
        COLOR_GREEN = '\033[1;32m'
        COLOR_YELLOW = '\033[1;33m'
        COLOR_BLUE = '\033[1;34m'
        COLOR_PURPLE = '\033[1;35m'
        COLOR_CYAN = '\033[1;36m'
        COLOR_GRAY = '\033[1;37m'
        COLOR_WHITE = '\033[1;38m'
        COLOR_RESET = '\033[1;0m'
        # Define log color
        LOG_COLORS = {
            'DEBUG': COLOR_BLUE + '%s' + COLOR_RESET,
            'INFO': COLOR_GREEN + '%s' + COLOR_RESET,
            'WARNING': COLOR_YELLOW + '%s' + COLOR_RESET,
            'ERROR': COLOR_RED + '%s' + COLOR_RESET,
            'CRITICAL': COLOR_RED + '%s' + COLOR_RESET,
            'EXCEPTION': COLOR_RED + '%s' + COLOR_RESET,
        }
        level_name = record.levelname
        msg = logging.Formatter.format(self, record)
        return LOG_COLORS.get(level_name, '%s') % msg


def format_time(seconds):
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f

In [10]:
class Se_resnet(nn.Module):
    def __init__(self, num_classes=1000):
        super(Se_resnet, self).__init__()
        self.features = pretrainedmodels.se_resnet152(num_classes=1000, pretrained='imagenet')
        self.features.last_linear = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.features(x)
        return x

In [11]:
class GallbladderDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.frame = pd.read_csv(csv_file, encoding='utf-8', header=None)
        self.root_dir = root_dir
        # print('csv_file source----->', csv_file)
        # print('root_dir source----->', root_dir)
        self.transform = transform

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.frame.iloc[idx, 0])
        img_name = os.path.basename(img_path)
        # print(img_name)
        _, extension = os.path.splitext(self.frame.iloc[idx, 0])
        # print(extension)
        image = self.image_loader(img_path, extension)
        # print(image)
        label = int(self.frame.iloc[idx, 1])
        if self.transform is not None:
            image = self.transform(image)
        sample = {'image': image, 'label': label, 'img_name': img_name}
        return sample

    def image_loader(self, img_name, extension):
        if extension == '.JPG':
            # print('读取jpg')
            return self.read_jpg(img_name)
        elif extension == '.jpg':
            # print('读取jpg')
            return self.read_jpg(img_name)
        elif extension == '.DCM':
            # print('读取dcm')
            return self.read_dcm(img_name)
        elif extension == '.dcm':
            # print('读取dcm')
            return self.read_dcm(img_name)
        elif extension == '.Bmp':
            # print('读取Bmp')
            return self.read_bmp(img_name)
        elif extension == '.png':
            return self.read_png(img_name)

    def read_jpg(self, img_name):
        return Image.open(img_name).convert('RGB')

    def read_dcm(self, img_name):
        ds = sitk.ReadImage(img_name)
        img_array = sitk.GetArrayFromImage(ds)
        img_bitmap = Image.fromarray(img_array[0])
        return img_bitmap

    def read_bmp(self, img_name):
        return Image.open(img_name)

    def read_png(self, img_name):
        return Image.open(img_name)


def call_gallbladder_dataset():
    tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    dataset = GallbladderDataset(
        csv_file='training_data/label.csv',
        root_dir="training_data",
        transform=tf
    )
    dataloader = DataLoader(dataset=dataset, batch_size=int(len(dataset) * 0.1), shuffle=False, num_workers=2)
    for item in dataloader:
        images = item['image']
        images = images.numpy()
        mean = np.mean(images, axis=(0, 2, 3))
        std = np.std(images, axis=(0, 2, 3))
        break
    print(mean, std)
class BaseLoader:
    def __init__(self, args):
        # special params
        self.num_workers = 1
        if 'batch_size' in args:
            self.batch_size = 16

        # custom properties
        self._dataset_train = None
        self._dataset_eval = None
        self._dataset_test = None
        self._dataset_norm = None
        self._dataset_attack = None
        self.dataloader_train = None
        self.dataloader_eval = None
        self.dataloader_test = None

    @property
    def dataset_train(self):
        raise NotImplementedError

    @property
    def dataset_eval(self):
        raise NotImplementedError

    @property
    def dataset_test(self):
        raise NotImplementedError

    @property
    def dataset_norm(self):
        raise NotImplementedError

    @property
    def train(self):
        if self.dataloader_train is None:
            self.dataloader_train = Data.DataLoader(self.dataset_train,
                                                    batch_size=self.batch_size,
                                                    shuffle=True,
                                                    num_workers=self.num_workers,
                                                    pin_memory=True)
        return self.dataloader_train

    @property
    def eval(self):
        if self.dataloader_eval is None:
            self.dataloader_eval = Data.DataLoader(self.dataset_eval,
                                                   batch_size=self.batch_size,
                                                   num_workers=self.num_workers,
                                                   pin_memory=True)
        return self.dataloader_eval

    @property
    def test(self):
        if self.dataloader_test is None:
            self.dataloader_test = Data.DataLoader(self.dataset_test,
                                                   batch_size=self.batch_size,
                                                   num_workers=self.num_workers,
                                                   pin_memory=True)
        return self.dataloader_test

    def cal_norm(self, n=1):
        return None

    @staticmethod
    def random_sample_base(base_dir, transform, size):
        classes = os.listdir(base_dir)
        classes.sort()
        images = []
        for c in classes:
            folder = join(base_dir, c)
            for file_name in np.random.choice(os.listdir(folder), size, False):
                img = join(folder, file_name)
                images.append(transform(Image.open(img)))
        return classes, torch.stack(images)



def get_dataloader(args):
    if args['dataset'] == 'Gallbladder':
        return Gallbladder(args)
    else:
        raise ValueError('No dataset: {}'.format(args['dataset']))


class Gallbladder(BaseLoader):
    # contain empty
    mean = [0.359, 0.361, 0.379]
    std = [0.190, 0.190, 0.199]
    # except empty
    # mean = [0.359, 0.361, 0.380]
    # std = [0.191, 0.190, 0.200]

    num_classes = 2
    class_names = ('Biliary atresia', 'Non-biliary atresia')

    def __init__(self, args):
        super(Gallbladder, self).__init__(args)
        self.train_dir = 'training_data'
        self.test_dir = 'test_data'
        self.train_csv = 'training_data/label.csv'
        self.test_csv = 'test_data/label.csv'
        self.mean = Gallbladder.mean
        self.std = Gallbladder.std
        self.img_size = args['img_size']
        self.convert = transforms.Grayscale(3)

    @property
    def dataset_train(self):
        if self._dataset_train is None:
            tf = transforms.Compose([
                transforms.RandomResizedCrop(self.img_size, scale=(0.8, 1)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std)
            ])
            self._dataset_train = GallbladderDataset(csv_file= self.train_csv,
                                                                         root_dir=self.train_dir,
                                                                         transform=tf)
        return self._dataset_train

    @property
    def dataset_eval(self):
        if self._dataset_eval is None:
            tf = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                self.gray_to_rgb,
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std)
            ])
            self._dataset_eval = GallbladderDataset(csv_file=self.test_csv,
                                                                        root_dir=self.test_dir,
                                                                        transform=tf)
        return self._dataset_eval

    def gray_to_rgb(self, img):
        size = np.array(img).transpose().shape
        # print(size)
        if size[0] != 3:
            img = self.convert(img)

        return img



def call_dataloader():
    args = {'dataset': 'Gallbladder',
            'Gallbladder_train_dir': '/home/cccc/Desktop/share/deeplearning_project/pnasnet/data/DataSets/final_train',
            'Gallbladder_test_dir': '/home/cccc/Desktop/share/deeplearning_project/pnasnet/data/DataSets/final_train',
            'train_csv': './label.csv',
            'test_csv': './label.csv',
            'num_workers': 4,
            'batch_size': 16,
            'img_size': 331}  # pnasnet input size=331
    dl = get_dataloader(args).train
    print(len(dl))


In [12]:
class BaseAction:
    loss_legend = ['| loss: {:0<10.8f}']
    eval_on_train = True
    eval_legend = ['| acc: {:0<5.3f}%']
    eval_personal_legend = ['| sensitivity: {:0<5.3f}%', '| specitivity: {:0<5.3f}%']

    @staticmethod
    def cal_logits(x, net):
        return net(x)

    @staticmethod
    def cal_loss(y, y_hat, weight):
        #weigth = torch.tensor([weight]).cuda()
        weigth = torch.tensor([weight])
        loss = F.cross_entropy(y_hat, y, weight=weigth.reshape((2,)))
        return loss,

    @staticmethod
    def cal_eval(y, y_hat):
        count_right = np.empty(1, np.float32)
        count_sum = np.empty(1, np.float32)
        y_hat = y_hat.argmax(1)
        count_right[0] = (y_hat == y).sum().item()
        count_sum[0] = y.size(0)
        return 100 * count_right, count_sum

    @staticmethod
    def update_opt(epoch, net, opt_type, lr=1e-2, lr_epoch=35):
        decay = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
        if epoch % lr_epoch == 0:
            times = int(epoch / lr_epoch)
            times = len(decay) - 1 if times >= len(decay) else times
            if opt_type == 'sgd':
                return torch.optim.SGD(net.parameters(), lr=lr * decay[times],
                                       momentum=0.9,
                                       weight_decay=5e-4)
            elif opt_type == 'adam':
                return torch.optim.Adam(net.parameters(), lr=lr * decay[times])
        else:
            return None

    @staticmethod
    def save_model(ism, model, path, *args):
        ''''
        ism, model, path, msg, pblog, acc, epoch
        '''
        acc, epoch = args
        if ism:
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        state = {
            'net': state_dict,
            'acc': acc,
            'epoch': epoch}
        torch.save(state, path)

    @staticmethod
    def save_graph(model, img_size, tblog, pblog):
        #dummyInput = torch.randn([1, 3, img_size, img_size]).cuda()
        dummyInput = torch.randn([1, 3, img_size, img_size])
        tblog.add_graph(model, dummyInput)
        pblog.debug('Graph saved')

    @staticmethod
    def cal_scalars(metric, metric_legend, msg, pblog):
        scalars = dict()
        for n, s in zip(metric, metric_legend):
            msg += s.format(n)
            scalars[s.split(':')[0][2:]] = n
        pblog.info(msg)
        return scalars

    # @staticmethod
    # def log_confusion_matrix(labels, predictions, class_names):
    #     cm = confusion_matrix(labels, predictions)
    #     # normalise confusion matrix for diff sized groups
    #     cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
    #     cm_image = plot_confusion_matrix(cm_norm, class_names)
    #     return cm_image


def get_action(args):
    action = args['action']
    if action == 'base':
        return BaseAction()
    else:
        raise ValueError('No action: {}'.format(action))
def check_mkdir(dir_name):
    if not exists(dir_name):
        os.makedirs(dir_name)

In [13]:
class Trainer:

    def __init__(self, args):
        # cuda setting
        #os.environ["CUDA_VISIBLE_DEVICES"] = args['cuda']

        # dir setting
        self.model_dir = args['model_dir']
        self.best_model_dir = args['best_model_dir']
        self.step_model_dir = args['step_model_dir']
        check_mkdir(self.model_dir)
        check_mkdir(self.best_model_dir)

        # dataset setting
        self.dataloader = get_dataloader(args)
        self.no_eval = args['no_eval']
        self.personal_eval = args['personal_eval']
        self.img_size = args['img_size']
        args['mean'] = self.dataloader.mean
        args['std'] = self.dataloader.std
        args['num_classes'] = self.dataloader.num_classes

        # basic setting
        self.opt_type = args['optimizer']
        self.lr = args['lr']
        self.lr_epoch = args['lr_epoch']
        self.epoch = args['epoch']
        self.weight = args['weight']
        self.eval_best = 0
        self.eval_best_epoch = 0
        self.save_cm = args['save_cm']  # save confusion matrix

        # model name config
        self.model_desc = '{}_{}_{}_{}'. \
            format(args['dataset'], args['model'], args['action'], args['desc'])
        self.model_pkl = self.model_desc + '.ckpt'

        # logger setup
        self.pblog = get_pblog()
        self.pblog.total = self.epoch
        self.tblog = SummaryWriter(join(args['tb_dir'], self.model_desc))

        # model setup
        self.action = get_action(args)
        self.model = Se_resnet(2)


    def __del__(self):
        if hasattr(self, 'tb_log'):
            self.tblog.close()

    def train(self):
        self.pblog.info(self.model_desc)
        optimizer = None
        for epoch in range(self.epoch):
            # get optimizer
            temp = self.action.update_opt(epoch, self.model, self.opt_type,
                                          self.lr, self.lr_epoch)
            if temp is not None:
                optimizer = temp

            self.model.train()
            loss_l = []
            loss_n = []
            dl_len = len(self.dataloader.train)
            ll = len(self.action.eval_legend)
            c_right = np.zeros(ll, np.float32)
            c_sum = np.zeros(ll, np.float32)
            main_loss = 0
            for idx, item in enumerate(self.dataloader.train):
                tx, ty = item['image'], item['label']
                #tx, ty = tx.cuda(non_blocking=True), ty.cuda(non_blocking=True)
                # get network output logits
                logits = self.action.cal_logits(tx, self.model)
                # cal loss
                loss = self.action.cal_loss(ty, logits, self.weight)
                # cal acc
                right_e, sum_e = self.action.cal_eval(ty, logits)
                # backward
                optimizer.zero_grad()
                loss[0].backward()
                optimizer.step()

                c_right += right_e
                c_sum += sum_e
                loss_l.append([ii.item() for ii in loss])
                loss_n.append(ty.size(0))
                main_loss += loss[0].item()
                self.pblog.pb(idx, dl_len, 'Loss: %.5f | Acc: %.3f%%' % (
                    main_loss / (idx + 1), c_right / c_sum))
            loss_l = np.array(loss_l).T
            loss_n = np.array(loss_n)
            loss = (loss_l * loss_n).sum(axis=1) / loss_n.sum()
            c_res = c_right / c_sum

            msg = 'Epoch: {:>3}'.format(epoch)
            loss_scalars = self.action.cal_scalars(loss,
                                                   self.action.loss_legend, msg,
                                                   self.pblog)
            self.tblog.add_scalars('loss', loss_scalars, epoch)

            msg = 'train->   '
            acc_scalars = self.action.cal_scalars(c_res,
                                                  self.action.eval_legend, msg,
                                                  self.pblog)
            self.tblog.add_scalars('eval/train', acc_scalars, epoch)

            if not self.no_eval:
                if not self.personal_eval:
                    with torch.no_grad():
                        self.eval(epoch)
                else:
                    with torch.no_grad():
                        self.eval_personal(epoch)

        path = os.path.join(self.model_dir, self.model_desc)
        self.action.save_model(self.ism, self.model, path, self.eval_best,
                               self.eval_best_epoch)
        self.pblog.debug('Training completed, save the last epoch model')
        temp = 'Result, Best: {:.2f}%, Epoch: {}'.format(self.eval_best,
                                                         self.eval_best_epoch)
        self.tblog.add_text('best', temp, self.epoch)
        self.pblog.info(temp)

    def eval(self, epoch):
        self.model.eval()
        ll = len(self.action.eval_legend)
        c_right = np.zeros(ll, np.float32)
        c_sum = np.zeros(ll, np.float32)
        dl_len = len(self.dataloader.eval)
        labels = []
        predictions = []
        for idx, item in enumerate(self.dataloader.eval):
            x, y = item['image'], item['label']
            #x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
            logits = self.action.cal_logits(x, self.model)
            right_e, sum_e = self.action.cal_eval(y, logits)
            c_right += right_e
            c_sum += sum_e
            labels.extend(y.cpu().data)
            predictions.extend(logits.argmax(1).cpu().data)
            self.pblog.pb(idx, dl_len, 'Acc: %.3f %%' % (c_right / c_sum))
        msg = 'eval->    '
        c_res = c_right / c_sum
        acc_scalars = self.action.cal_scalars(c_res, self.action.eval_legend,
                                              msg, self.pblog)
        self.tblog.add_scalars('eval/eval', acc_scalars, epoch)

        # if self.save_cm:
        #     cm_figure = self.action.log_confusion_matrix(labels, predictions,
        #                                                  self.dataloader.class_names)
        #     self.tblog.add_figure('Confusion Matrix', cm_figure, epoch)

        if c_res[0] > self.eval_best and epoch > 30:
            self.eval_best_epoch = epoch
            self.eval_best = c_res[0]
            path = os.path.join(self.best_model_dir, 'Best_' + self.model_desc)
            self.action.save_model(self.ism, self.model, path, self.eval_best,
                                   self.eval_best_epoch)
            self.pblog.debug('Update the best model')

    def eval_personal(self, epoch):
        self.model.eval()
        ll = len(self.action.eval_legend)
        c_right = np.zeros(ll, np.float32)
        c_sum = np.zeros(ll, np.float32)
        dl_len = len(self.dataloader.eval)

        labels = []
        predictions = []
        preindex = None
        prelabel = None
        personal_vote = [0. for i in range(2)]
        class_correct = list(0. for i in range(2))
        class_total = list(0. for i in range(2))
        for idx, item in enumerate(self.dataloader.eval):
            x, y, img_names = item['image'], item['label'], item['img_name']
            #x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
            logits = self.action.cal_logits(x, self.model)
            _, prediction = torch.max(logits.data, 1)
            for i, name in enumerate(img_names):
                index, *_ = name.split('_')
                # init pre
                if preindex is None:
                    preindex = index
                if prelabel is None:
                    prelabel = y[0]

                if index != preindex:
                    if personal_vote[0] >= personal_vote[1]:
                        predictions.append(0)
                        if prelabel == 0:
                            class_correct[0] += 1
                    else:
                        predictions.append(1)
                        if prelabel == 1:
                            class_correct[1] += 1
                    labels.append(prelabel.item())
                    class_total[prelabel] += 1
                    personal_vote = [0. for i in range(2)]
                    preindex = index
                    prelabel = y[i]
                    personal_vote[prediction[i]] += 1

                else:
                    personal_vote[prediction[i]] += 1

            self.pblog.pb(idx, dl_len, 'Sen: %.3f %%  | Spe: %.3f %%' % (
                100 * class_correct[0] / (class_total[0] + 1e-6), 100 * class_correct[1] / (class_total[1] + 1e-6)))

        # deal the last patient
        if personal_vote[0] >= personal_vote[1]:
            predictions.append(0)
            if prelabel == 0:
                class_correct[0] += 1
        else:
            predictions.append(1)
            if prelabel == 1:
                class_correct[1] += 1

        labels.append(prelabel.item())
        class_total[prelabel] += 1

        msg = 'eval->    '
        c_res = [100 * class_correct[0] / class_total[0], 100 * class_correct[1] / class_total[1]]
        acc_scalars = self.action.cal_scalars(c_res, self.action.eval_personal_legend,
                                              msg, self.pblog)
        self.tblog.add_scalars('eval/eval', acc_scalars, epoch)

        # if self.save_cm:
        #     cm_figure = self.action.log_confusion_matrix(labels, predictions,
        #                                                  self.dataloader.class_names)
        #     self.tblog.add_figure('Confusion Matrix', cm_figure, epoch)

        if c_res[0] > self.eval_best and epoch > 30 and class_correct[1] / class_total[1] >= 0.85:
            self.eval_best_epoch = epoch
            self.eval_best = c_res[0]
            path = os.path.join(self.best_model_dir, 'Best_' + self.model_desc)
            self.action.save_model(self.ism, self.model, path, self.eval_best,
                                   self.eval_best_epoch)
            self.pblog.debug('Update the best model')

In [None]:
data_dir = ''
config = {
    'optimizer': 'sgd',
    'dataset': 'ImageNet100',
    'img_size': 224,
    'model': 'Se_resnet',
    'dataset': 'Gallbladder',
    'lr': 0.01,
    'lr_epoch': 35,
    'epoch': 100,
    'pre_train': False,
    'batch_size': 64,
    'weight': [5.0, 1.0],
    'save_cm': False,
    'log_dir': join(data_dir, 'log'),
    'loss_dir': join(data_dir, 'loss'),
    'model_dir': join(data_dir, 'model'),
    'best_model_dir': join(data_dir, 'best_model'),
    'step_model_dir': join(data_dir, 'step_temp'),
    'tb_dir': join(data_dir, 'tb'),
    'action': 'base',
    'desc': '',
    'cmd': 'train',
    'log_name': 'Modelo base',
    'cuda': '0',
    'num_workers': 2,
    'no_eval': False,
    'personal_eval': False
}

ssl._create_default_https_context = ssl._create_unverified_context
pblog = get_pblog(total = 100)
Trainer(config).train()

                                                                            

[1;32m[2024-04-04 15:06:21] @Modelo base INFO    Gallbladder_Se_resnet_base_[1;0m
[1;32m[2024-04-04 15:06:21] @Modelo base INFO    Gallbladder_Se_resnet_base_[1;0m


 [>............................................]Tot: 0ms | Loss: 0.69673 | Acc: 81.250%      1/232 

  self.pblog.pb(idx, dl_len, 'Loss: %.5f | Acc: %.3f%%' % (


                                                                            

[1;32m[2024-04-04 15:36:20] @Modelo base INFO    Epoch:   0| loss: 0.55861174[1;0m
[1;32m[2024-04-04 15:36:20] @Modelo base INFO    Epoch:   0| loss: 0.55861174[1;0m


                                                                            

[1;32m[2024-04-04 15:36:20] @Modelo base INFO    train->   | acc: 66.019%[1;0m
[1;32m[2024-04-04 15:36:20] @Modelo base INFO    train->   | acc: 66.019%[1;0m


 [>............................................]Tot: 0ms | Acc: 62.500 %                     1/53 

  self.pblog.pb(idx, dl_len, 'Acc: %.3f %%' % (c_right / c_sum))


                                                                            

[1;32m[2024-04-04 15:38:23] @Modelo base INFO    eval->    | acc: 60.642%[1;0m
[1;32m[2024-04-04 15:38:23] @Modelo base INFO    eval->    | acc: 60.642%[1;0m


 [>............................................]Tot: 0ms | Loss: 0.52100 | Acc: 68.750%      1/232 

  self.pblog.pb(idx, dl_len, 'Loss: %.5f | Acc: %.3f%%' % (


