In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import crack_dataset as DS
import copy
import os
import time
import logging
from tools.focal_loss import focal_binary_cross_entropy
from torch.nn import BCEWithLogitsLoss
print("PyTorch Version: ",torch.__version__)

PyTorch Version:  1.10.0


In [24]:
def getDataSet(patch_size, batch_size, workers=8):
    class Args:
      # dataset_path = "/storage/data/classification_dataset_balanced/"
      dataset_path = "../p2_data/data/classification_dataset_balanced/"
      patch_size = 1
      batch_size = 1
      workers = 1
      def __init__(self, patch_size, batch_size, workers):
        self.patch_size = patch_size
        self.batch_size = batch_size
        self.workers = workers
    args = Args(patch_size, batch_size, workers)
    dataset = DS.CODEBRIM(torch.cuda.is_available(),args)
    dataLoaders = {'train': dataset.train_loader, 'val': dataset.val_loader, 'test':dataset.test_loader}
    return dataLoaders

In [25]:
def get_zennet(arch_path, num_classes, use_SE):
    """
    load the Zen-NAS searched model from stored arch planetext

    :param arch_path: path for model architecture description file.txt
    :param num_classes: the data class number
    :param use_SE: whether to use Squeeze-and-Excitation module
    """
    from ref_codes.ZenNAS.ZenNet import masternet
    
    with open(arch_path, 'r') as fid:
        model_plainnet_str = fid.readline().strip()
    
    model = masternet.PlainNet(num_classes=num_classes, plainnet_struct=model_plainnet_str, use_se=use_SE)
    return model

def get_ZenNet_pretrained(model_name, num_classes):
    from ref_codes.ZenNAS.ZenNet import get_ZenNet
    from ref_codes.ZenNAS.PlainNet import basic_blocks
    model = get_ZenNet(model_name, pretrained=True)
    
    # adjust the last layer to adapt to the new class number
    model.fc_linear = basic_blocks.Linear(in_channels=model.fc_linear.in_channels, out_channels=num_classes)
    return model


In [26]:
def log_creater(output_dir):
    """
    create logger object for registering staffs
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    log_name = '{}.log'.format(time.strftime('%Y-%m-%d-%H-%M'))
    final_log_file = os.path.join(output_dir,log_name)
 
 
    # creat a log
    log = logging.getLogger('train_log')
    log.setLevel(logging.DEBUG)
 
    # FileHandler
    file = logging.FileHandler(final_log_file)
    file.setLevel(logging.DEBUG)
 
    # StreamHandler
    stream = logging.StreamHandler()
    stream.setLevel(logging.DEBUG)
 
    # Formatter
    formatter = logging.Formatter(
        '[%(asctime)s][line: %(lineno)d] ==> %(message)s')
 
    # setFormatter
    file.setFormatter(formatter)
    stream.setFormatter(formatter)

     # addHandler
    log.addHandler(file)
    log.addHandler(stream)
 
    log.info('creating {}'.format(final_log_file))
    return log

In [30]:
def train(root_dir, model, logger, lr_h, lr_l, dataLoaders, num_epochs = 300, resume=False, 
    checkpoint = None, device = "cpu"):
    start_epoch = 1
    optimizer = optim.SGD(model.parameters(), lr=lr_h, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=10, T_mult=2, eta_min=lr_l)
    best_acc_hard = 0.0
    best_acc_soft = 0.0
    criterion = torch.nn.BCELoss()
    # criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0,0.9854830961388338,1.5000748374861417,5.214233129729914,2.1126041872854255,2.0173556568167372], device=torch.device('cuda')))  
    
    save_path_hard = root_dir + '/hard.pth'
    save_path_soft = root_dir + '/soft.pth'
    iters = len(dataLoader['train'])
    if resume:
        path_checkpoint = root_dir + checkpoint  # checkpoint path
        checkpoint = torch.load(path_checkpoint)  # load the checkpoint
        model.load_state_dict(checkpoint['net'])  # load the learnable params
        scheduler.load_state_dict(checkpoint['scheduler'])
        optimizer.load_state_dict(checkpoint['optimizer'])  # load the params for optimizers
        start_epoch = checkpoint['epoch']  # set the start epoch
        best_acc_soft = checkpoint['best_acc_soft']
        best_acc_hard = checkpoint['best_acc_hard']


    for epoch in range(start_epoch, num_epochs+1):  # loop over the dataset multiple times

        if epoch % 20 == 0:
            checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            "epoch": epoch,
            "best_acc_soft": best_acc_soft,
            "best_acc_hard": best_acc_hard
            }
            if not os.path.isdir(root_dir + "/checkpoint"):
                os.mkdir(root_dir + "/checkpoint")
            torch.save(checkpoint, root_dir + '/checkpoint/ckpt_best_%s.pth' %(str(epoch)))

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects_hard = 0
            running_corrects_soft = 0

            for i, sample in enumerate(dataLoaders[phase]):
                inputs, labels = sample
                inputs = inputs.to(device)
                if inputs.shape[0] < 2:  # avoid batch norm bug
                    continue
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = model(inputs)
                #== original loss
                outputs = torch.sigmoid(outputs)
                loss = criterion(outputs, labels)
                
                #== BCE loss with logits
                # loss = criterion(outputs, labels)
                # outputs = torch.sigmoid(outputs)
            
                #== focal loss
                # loss = focal_binary_cross_entropy(outputs, labels)  # multi-label focal loss
                # outputs = torch.sigmoid(outputs)

                outputs = outputs >= 0.5  # binarizing sigmoid output by thresholding with 0.5
                equality_matrix = (outputs.float() == labels).float()
                hard = torch.sum(torch.prod(equality_matrix, dim=1))
                soft = torch.mean(equality_matrix)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                #adjustment in scheduler
                    scheduler.step(epoch + i / iters)
        
                running_loss += loss.item() * inputs.size(0)
                running_corrects_hard += hard.item()
                running_corrects_soft += soft.item()

            epoch_loss = running_loss / len(dataLoaders[phase].dataset)
            epoch_acc_hard = running_corrects_hard / len(dataLoaders[phase].dataset)
            epoch_acc_soft = running_corrects_soft / len(dataLoaders[phase])
            logger.info('{} Epoch:[{}/{}]\t loss={:.5f}\t acc_hard={:.3f} acc_soft={:.3f} lr={:.7f}'.format\
            (phase, epoch , num_epochs, epoch_loss, epoch_acc_hard, epoch_acc_soft, \
            optimizer.state_dict()['param_groups'][0]['lr'] ))

            # deep copy the model
            if epoch >= 150 and phase == 'val' and epoch_acc_hard > best_acc_hard:
                best_acc_hard = epoch_acc_hard
                #   best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), save_path_hard)

            if epoch >= 150 and phase == 'val' and epoch_acc_soft > best_acc_soft:
                best_acc_soft = epoch_acc_soft
                #   best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), save_path_soft)

    # model = get_zennet('model_scripts/zennet_imagenet1k_flops400M_res224.txt', 6, True)
    model = get_ZenNet_pretrained('zennet_imagenet1k_flops400M_SE_res224')
    
    model.load_state_dict(torch.load(root_dir + '/hard.pth'))
    model.to(device)
    model.eval()
    logger.info("hard:")
    evaluation(dataLoaders, device, model, logger)

    model.load_state_dict(torch.load(root_dir + '/soft.pth'))
    model.to(device)
    model.eval()
    logger.info("soft:")
    evaluation(dataLoaders, device, model, logger)



def evaluation(dataLoaders, device, model, logger):
    # criterion = torch.nn.BCELoss()
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0,0.9854830961388338,1.5000748374861417,5.214233129729914,2.1126041872854255,2.0173556568167372], device=torch.device('cuda')))  
    for phase in ['train', 'val', 'test']:
        running_loss = 0.0
        running_corrects_hard = 0
        running_corrects_soft = 0
      
        for i, sample in enumerate(dataLoaders[phase]):
            inputs, labels = sample
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            
            #== original loss
            outputs = torch.sigmoid(outputs)
            loss = criterion(outputs, labels)
             
            #== BCE loss with logits
            # loss = criterion(outputs, labels)
            # outputs = torch.sigmoid(outputs)
            
            #== focal loss
            # loss = focal_binary_cross_entropy(outputs, labels)  # multi-label focal loss
            # outputs = torch.sigmoid(outputs)
            
            outputs = outputs >= 0.5  # binarizing sigmoid output by thresholding with 0.5
            equality_matrix = (outputs.float() == labels).float()
            hard = torch.sum(torch.prod(equality_matrix, dim=1))
            soft = torch.mean(equality_matrix)
            running_loss += loss.item() * inputs.size(0)
            running_corrects_hard += hard.item()
            running_corrects_soft += soft.item()

        epoch_loss = running_loss / len(dataLoaders[phase].dataset)
        epoch_acc_hard = running_corrects_hard / len(dataLoaders[phase].dataset)
        epoch_acc_soft = running_corrects_soft / len(dataLoaders[phase])
        logger.info("{}: loss:{:.5f} acc_soft:{:.3f} acc_hard:{:.3f}".format(phase, epoch_loss, epoch_acc_soft, epoch_acc_hard))     



In [31]:
import numpy as np

logger = log_creater("./train_log")
batch_size = 4
patch_size = 224
dataLoader = getDataSet(patch_size, batch_size)

lr = (1e-2,1e-5)
root_dir = './' + str(batch_size) + '-' + str(patch_size) + '-' + str(lr[0])
if not os.path.isdir(root_dir):
    os.mkdir(root_dir)
logger.info("batch_size:" + str(batch_size))
logger.info("patch_size:" + str(patch_size))
logger.info("learning rate high:" + str(lr[0]))
logger.info("learning rate low:" + str(lr[1]))
# model = get_zennet('model_scripts/zennet_imagenet1k_flops400M_res224.txt', 6, True)
model = get_ZenNet_pretrained('zennet_imagenet1k_flops400M_SE_res224')

# get the model parameters
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(f'model parameter number is: {params}')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    model = model.to(device)
lr_h = lr[0]
lr_l = lr[1]
train(root_dir, model, logger, lr_h, lr_l, dataLoader, num_epochs = 300, resume=False, 
checkpoint = None, device = device)




[2021-12-09 11:04:59,774][line: 35] ==> creating ./train_log/2021-12-09-11-04.log
[2021-12-09 11:04:59,774][line: 35] ==> creating ./train_log/2021-12-09-11-04.log
[2021-12-09 11:04:59,774][line: 35] ==> creating ./train_log/2021-12-09-11-04.log
[2021-12-09 11:04:59,774][line: 35] ==> creating ./train_log/2021-12-09-11-04.log
[2021-12-09 11:04:59,774][line: 35] ==> creating ./train_log/2021-12-09-11-04.log
[2021-12-09 11:04:59,774][line: 35] ==> creating ./train_log/2021-12-09-11-04.log
[2021-12-09 11:04:59,774][line: 35] ==> creating ./train_log/2021-12-09-11-04.log
[2021-12-09 11:04:59,774][line: 35] ==> creating ./train_log/2021-12-09-11-04.log
[2021-12-09 11:05:00,688][line: 12] ==> batch_size:4
[2021-12-09 11:05:00,688][line: 12] ==> batch_size:4
[2021-12-09 11:05:00,688][line: 12] ==> batch_size:4
[2021-12-09 11:05:00,688][line: 12] ==> batch_size:4
[2021-12-09 11:05:00,688][line: 12] ==> batch_size:4
[2021-12-09 11:05:00,688][line: 12] ==> batch_size:4
[2021-12-09 11:05:00,688][

---debug use_se in SuperResIDWE1K7(24,48,2,48,1)
---debug use_se in SuperResIDWE2K7(48,72,2,72,1)
---debug use_se in SuperResIDWE6K7(72,96,2,88,5)
---debug use_se in SuperResIDWE4K7(96,192,2,168,5)
model parameter number is: 9274958


RuntimeError: DataLoader worker (pid(s) 93726, 93728, 93729, 93730, 93731, 93732, 93733, 93734) exited unexpectedly