In [None]:
## Part 1: utils
from matplotlib import pyplot as plt
from torchvision import transforms
import numpy as np
from torchvision.models import resnet50
import torch
from torch.backends import cudnn
import tqdm
import random
from PIL import Image
import cv2
import numpy as np
import torchvision.transforms.functional as TF
# Enable cuDNN's auto-tuner so it can pick convolution algorithms by benchmarking them.
# NOTE(review): this is a process-wide switch set as a side effect of running this cell.
cudnn.benchmark = True

class Average_IOU_Meter():
    """Collects one per-class IoU vector per batch and averages them column-wise.

    The buffer holds ``batches - 1`` rows of ``args.num_classes - 1`` values, so
    the caller is expected to feed one row fewer than the number of batches.
    """

    def __init__(self, args, batches):
        """Store the configuration and allocate an empty buffer.

        :param args: object exposing ``num_classes``
        :param batches: number of batches in the loader being measured
        """
        self.args = args
        self.cal_size = batches - 1
        self.reset()

    def update(self, niov):
        """Write one per-class IoU vector into the next free row."""
        row = self.idx
        self.iou_vars[row, :] = niov
        self.idx = row + 1

    def reset(self):
        """Zero the buffer and rewind the write cursor."""
        shape = (self.cal_size, self.args.num_classes - 1)
        self.iou_vars = np.zeros(shape)
        self.idx = 0

    def cal_avg(self):
        """Return the NaN-ignoring mean of every class column as a 1-D array."""
        class_count = self.args.num_classes - 1
        column_means = [np.nanmean(self.iou_vars[:, ci]) for ci in range(class_count)]
        result = np.zeros((class_count, ))
        result[:] = column_means
        return result

class data_augment:
    """Augmentation and label-remapping helpers for Cityscapes-style samples.

    Holds the crop size, per-channel normalisation statistics and the table
    that collapses raw ``labelIds`` values into contiguous training ids. Every
    raw id without a training id is mapped to ``ignore_label``, which is
    ``args.num_classes - 1``.
    """

    # Raw labelId -> contiguous training id for the classes that are kept.
    _TRAIN_IDS = {
        7: 0,    # road
        8: 1,    # sidewalk
        11: 2,   # building
        17: 3,   # pole
        19: 4,   # traffic light
        20: 5,   # traffic sign
        21: 6,   # vegetation
        22: 7,   # terrain
        23: 8,   # sky
        24: 9,   # person
        25: 10,  # rider
        26: 11,  # car
    }

    def __init__(self, args, split = 'train'):
        """
        :param args: object exposing ``num_classes``
        :param split: dataset split name; both branches currently use the same crop size
        """
        self.args = args
        # Presumably the ImageNet channel statistics -- only used by input_transform.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.scale_factor = 16
        self.base_size = 2048
        self.ignore_label = args.num_classes-1
        # Crop size is (height, width); the training crop and the evaluation
        # crop are currently identical.
        if split == "train":
            self.crop_size = (256, 512)
        else:
            self.crop_size = (256, 512)

        # Raw ids -1..33 in ascending order: everything is ignored unless it
        # has an entry in _TRAIN_IDS. The order matters for the inverse
        # mapping in convert_label, where the last raw id sharing a training
        # id wins.
        self.label_mapping = {raw_id: self.ignore_label for raw_id in range(-1, 34)}
        self.label_mapping.update(self._TRAIN_IDS)


    def pad_image(self, image, h, w, size, padvalue):
        """Pad ``image`` on the bottom/right with ``padvalue`` up to ``size``.

        :param image: array to pad
        :param h: current height
        :param w: current width
        :param size: (height, width) the result must reach at least
        :param padvalue: constant border value handed to OpenCV
        :return: a copy of ``image``, padded only if it was smaller than ``size``
        """
        pad_image = image.copy()
        pad_h = max(size[0] - h, 0)
        pad_w = max(size[1] - w, 0)
        if pad_h > 0 or pad_w > 0:
            pad_image = cv2.copyMakeBorder(image, 0, pad_h, 0,
                                           pad_w, cv2.BORDER_CONSTANT,
                                           value=padvalue)

        return pad_image

    def rand_crop(self, image, label):
        """Pad both arrays to the crop size, then cut the same random window.

        :param image: [H, W, C]
        :param label: [H, W]
        :return: (image, label) cropped to ``self.crop_size``
        """
        h, w = image.shape[:-1]
        image = self.pad_image(image, h, w, self.crop_size,
                          (0.0, 0.0, 0.0))
        # Padded label pixels get the ignore id.
        label = self.pad_image(label, h, w, self.crop_size,
                          (self.ignore_label,))

        new_h, new_w = label.shape
        x = random.randint(0, new_w - self.crop_size[1])
        y = random.randint(0, new_h - self.crop_size[0])
        image = image[y:y + self.crop_size[0], x:x + self.crop_size[1]]
        label = label[y:y + self.crop_size[0], x:x + self.crop_size[1]]

        return image, label

    def multi_scale_aug(self,image, label=None,
                        rand_scale=1, rand_cro=True):
        """Resize so the longer side equals ``base_size * rand_scale``.

        :param image: [H, W, C] array
        :param label: optional [H, W] array, resized with nearest-neighbour
        :param rand_scale: scale factor applied to ``self.base_size``
        :param rand_cro: when a label is given, also apply ``rand_crop``
        :return: the resized image alone when ``label`` is None, otherwise
                 the (image, label) pair
        """
        # The builtin int replaces the ``np.int`` alias, which NumPy 1.24 removed.
        long_size = int(self.base_size * rand_scale + 0.5)
        h, w = image.shape[:2]
        if h > w:
            new_h = long_size
            new_w = int(w * long_size / h + 0.5)
        else:
            new_w = long_size
            new_h = int(h * long_size / w + 0.5)

        image = cv2.resize(image, (new_w, new_h),
                           interpolation=cv2.INTER_LINEAR)
        if label is not None:
            # Nearest-neighbour keeps label values discrete.
            label = cv2.resize(label, (new_w, new_h),
                               interpolation=cv2.INTER_NEAREST)
        else:
            return image

        if rand_cro:
            image, label = self.rand_crop(image, label)

        return image, label

    def label_transform(self, label):
        """Return ``label`` as an int32 NumPy array."""
        return np.array(label).astype('int32')

    def gen_sample_resize(self, img, label):
        """Random-scale (0.5 .. 2.1 in steps of 0.1), crop, and move channels first."""
        rand_scale = 0.5 + random.randint(0, self.scale_factor) / 10.0
        image, label = self.multi_scale_aug(img, label,
                                            rand_scale=rand_scale)
        image = image.transpose((2, 0, 1))
        return image, label

    def gen_sample_simplified(self,image, target):
        """Resize, random-crop to (128, 256) and randomly flip a sample.

        NOTE(review): cv2.resize takes ``dsize`` as (width, height), the target
        is resized with the default interpolation, and the tensors are not
        moved to channel-first layout before the torchvision helpers run.
        This path is not called from the dataset -- verify before enabling it.
        """
        image = torch.from_numpy(cv2.resize(image,dsize=[138,266]))
        target = torch.from_numpy(cv2.resize(target,dsize=[138,266]))
        # Random crop
        i, j, h, w = transforms.RandomCrop.get_params(image, output_size=(128, 256))
        image = TF.crop(image, i, j, h, w)
        target = TF.crop(target, i, j, h, w)
        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)
            target = TF.hflip(target)

        return image, target

    def gen_sample(self, image, label, is_flip=True):
        """Random-crop a sample and move the image channels first.

        :param image: [H, W, C] array
        :param label: [H, W] array
        :param is_flip: accepted for interface compatibility; currently unused
        :return: (image as [C, h, w], label as [h, w])
        """
        image, label = self.rand_crop(image, label)
        image = image.transpose((2,0,1))

        return image, label

    def input_transform(self, image):
        """Reverse the channel order, scale to [0, 1] and normalise per channel.

        :param image: [H, W, 3] array
        :return: float32 array of the same shape
        """
        image = image.astype(np.float32)[:, :, ::-1]
        image = image / 255.0
        image -= self.mean
        image /= self.std
        return image

    def convert_label(self,label, inverse=False):
        """Remap label ids in place using ``self.label_mapping``.

        :param label: integer array; it is modified and also returned
        :param inverse: map training ids back to raw ids instead; several raw
                        ids share the ignore id, so the last one in the table wins
        :return: the same ``label`` array
        """
        # Compare against a snapshot so already-rewritten pixels are never remapped twice.
        temp = label.copy()
        if inverse:
            for v, k in self.label_mapping.items():
                label[temp == k] = v
        else:
            for k, v in self.label_mapping.items():
                label[temp == k] = v
        return label

class data_sampler:
    """Loads an (image, label) pair from two parallel lists of file paths."""

    def __init__(self, images, labels):
        """
        :param images: sequence of image file paths
        :param labels: sequence of label file paths, index-aligned with ``images``
        """
        self.images = images
        self.labels = labels

    def cv2_sampler(self, index) :
        """Read sample ``index`` with OpenCV.

        :param index: position in the path lists
        :return: (colour image, grayscale label) as returned by ``cv2.imread``
        :raises FileNotFoundError: if either file cannot be read. ``cv2.imread``
            reports that by returning None, which would otherwise surface
            later as an unrelated error.
        """
        image_path = self.images[index]
        label_path = self.labels[index]

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError('Could not read image file: {}'.format(image_path))

        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError('Could not read label file: {}'.format(label_path))

        return image, label

def class_to_rgb(pred):
    """Colour a 2-D map of class ids with a fixed 16-entry palette.

    :param pred: 2-D class-id map (the first two dimensions give height and width)
    :return: uint8 tensor of shape (3, H, W); ids outside 0..15 stay black
    """
    palette = {
        0: (0,0,0),
        1: (0, 255, 0),
        2: (0, 0, 255),
        3: (0, 255, 255),
        4: (128, 0, 0),
        5: (0, 128, 0),
        6: (0, 0, 128),
        7: (128, 128, 0),
        8: (128, 0, 128),
        9: (128, 128, 128),
        10: (64, 0, 0),
        11: (0, 64, 0),
        12: (255, 255, 0),
        13: (0, 128, 128),
        14: (0,0,64),
        15:(255,255,255)
    }

    height, width = pred.shape[0], pred.shape[1]
    rgbimg = torch.zeros((3, height, width), dtype=torch.uint8)
    for class_id, colour in palette.items():
        # Paint one channel plane at a time for the pixels of this class.
        for channel in range(3):
            rgbimg[channel][pred == class_id] = colour[channel]

    return rgbimg
def visual(tensor):
    """Convert ``tensor`` to a PIL image, show it with matplotlib and pause for 4 seconds."""
    to_pil = transforms.ToPILImage()
    plt.imshow(to_pil(tensor))
    plt.pause(4)

def reconstruct_mask(re_key:list, range_cls:list):
    """Build (old_id, new_id) pairs that merge the ids in ``re_key`` into one bucket.

    Ids of ``range_cls`` that are not in ``re_key`` are renumbered 0..k-1 in
    set-iteration order; every id in ``re_key`` is assigned the shared id k.

    :param re_key: ids to merge into the final bucket
    :param range_cls: all ids under consideration
    :return: list of (old_id, new_id) tuples, kept ids first, then ``re_key`` ids
    """
    kept_ids = list(set(range_cls) - set(re_key))
    merged_id = len(kept_ids)
    pairs = [(old_id, new_id) for new_id, old_id in enumerate(kept_ids)]
    pairs.extend((old_id, merged_id) for old_id in re_key)
    return pairs

def manual_inference_time(model, batch, repetitions=300, device=None):
    """
    Measure the average forward-pass time of ``model`` on ``batch``.

    One untimed warm-up pass is run first, then ``repetitions`` timed passes.
    On CUDA the passes are timed with CUDA events and a synchronize after each
    pass; on CPU a wall-clock timer is used instead.

    The model's train/eval mode is left untouched.

    :param model: module to time; it is moved to ``device``
    :param batch: input tensor; it is moved to ``device``
    :param repetitions: number of timed forward passes (must be >= 1)
    :param device: target device; defaults to 'cuda:0' when CUDA is available,
                   otherwise 'cpu'
    :return: average time per forward pass in milliseconds
    :raises ValueError: if ``repetitions`` is smaller than 1
    """
    import time  # local import: the cell defining this helper does not import time

    if repetitions < 1:
        raise ValueError('repetitions must be >= 1')
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    use_cuda = torch.device(device).type == 'cuda'

    model = model.to(device)
    batch = batch.to(device)
    timings = np.zeros(repetitions)

    print('warm up ...\n')
    with torch.no_grad():
        # Untimed pass so one-off initialisation cost is not measured.
        model(batch)
        if use_cuda:
            torch.cuda.synchronize()

        for rep in range(repetitions):
            if use_cuda:
                starter = torch.cuda.Event(enable_timing=True)
                ender = torch.cuda.Event(enable_timing=True)
                starter.record()
                model(batch)
                ender.record()
                torch.cuda.synchronize()  # wait for the GPU work to finish
                timings[rep] = starter.elapsed_time(ender)  # milliseconds
            else:
                tick = time.perf_counter()
                model(batch)
                timings[rep] = (time.perf_counter() - tick) * 1000.0

    avg = float(timings.sum() / repetitions)
    print('\navg={}\n'.format(avg))
    return avg

In [None]:
## Part 2: Prepare the dataset

import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os
import numpy as np
import torch

class CityscapesDataset(Dataset):
    """Dataset over a Cityscapes-layout directory tree.

    Images are expected under ``<root>/leftImg8bit/<split>/<city>/`` and label
    maps under ``<root>/gtFine|gtCoarse/<split>/<city>/`` with the
    ``*_labelIds.png`` suffix.
    """

    def __init__(self, root, args,split='train', mode='fine', augment=True):
        """
        :param root: dataset root directory (``~`` is expanded)
        :param args: configuration object, passed on to ``data_augment``
        :param split: 'train', 'test' or 'val' for mode 'fine';
                      'train', 'train_extra' or 'val' for mode 'coarse'
        :param mode: 'fine' or 'coarse'
        :param augment: stored and reported by ``__repr__``
        :raises ValueError: on an unknown ``mode`` or a ``split`` invalid for it
        :raises RuntimeError: if the image or target directory is missing
        """
        self.root = os.path.expanduser(root)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.split = split
        self.augment = augment
        self.images = []
        self.targets = []
        self.args =args
        # =============================================
        # Check that inputs are valid
        # =============================================
        if mode not in ['fine', 'coarse']:
            raise ValueError('Invalid mode! Please use mode="fine" or mode="coarse"')
        if mode == 'fine' and split not in ['train', 'test', 'val']:
            raise ValueError('Invalid split for mode "fine"! Please use split="train", split="test" or split="val"')
        elif mode == 'coarse' and split not in ['train', 'train_extra', 'val']:
            raise ValueError(
                'Invalid split for mode "coarse"! Please use split="train", split="train_extra" or split="val"')
        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')

        # =============================================
        # Read in the paths to all images
        # =============================================
        # os.listdir returns entries in arbitrary order, so sort to keep the
        # index -> sample mapping stable across runs and machines.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            if not os.path.isdir(img_dir):
                # Stray files next to the city folders would make listdir fail.
                continue
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in sorted(os.listdir(img_dir)):
                # Skip notebook checkpoint artefacts.
                if '.ipynb' in file_name:
                    continue
                self.images.append(os.path.join(img_dir, file_name))
                target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
                                             '{}_labelIds.png'.format(self.mode))
                self.targets.append(os.path.join(target_dir, target_name))

    def __repr__(self):
        """Multi-line summary: class name, size, split, mode, augment flag and root."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of images: {}\n'.format(self.__len__())
        fmt_str += '    Split: {}\n'.format(self.split)
        fmt_str += '    Mode: {}\n'.format(self.mode)
        fmt_str += '    Augment: {}\n'.format(self.augment)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        return fmt_str

    def __len__(self):
        """Number of image paths collected at construction time."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, remap and crop sample ``index``.

        :return: (image as a float tensor in channel-first layout,
                  target as a long tensor of training ids)
        """
        # Load the image and its label map from disk.
        data_slr = data_sampler(self.images, self.targets)
        image, target = data_slr.cv2_sampler(index)

        # Collapse raw label ids to training ids, then crop. gen_sample is the
        # active augmentation; gen_sample_resize and gen_sample_simplified
        # exist on data_augment but are not used here.
        # NOTE(review): data_augment.input_transform is never applied, so the
        # image values are passed on unnormalised -- confirm this is intended.
        data_handler = data_augment(self.args, self.split)
        target = data_handler.convert_label(target)
        image, target = data_handler.gen_sample(image, target)

        image = torch.from_numpy(image.copy()).float()

        target = torch.from_numpy(np.array(target, dtype=np.uint8))
        target = target.long()

        return image, target



def train_loader(dataset, batch_size, num_workers=4, pin_memory=False, normalize=None):
    """Wrap ``dataset`` in a shuffling DataLoader.

    ``normalize`` is accepted but not used.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return data.DataLoader(dataset=dataset, **loader_options)



def test_loader(dataset, batch_size, num_workers=4, pin_memory=False):

    return data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory)

# Cell-completion marker: printed once the dataset definitions above have been executed.
print("Done.")

In [None]:
## Part 3: Evaluation metrics
import numpy as np
def cal_confusion_matrix(output, target, n_class):  ## columns: ground truth, rows: prediction
    """Count (prediction, ground truth) pairs into an ``n_class x n_class`` matrix.

    Entry ``[p, g]`` is the number of pixels predicted as class ``p`` whose
    ground-truth class is ``g``. Column sums (``np.sum(hist, 0)``) are
    therefore per-ground-truth-class totals, which is what the metrics in this
    notebook divide by.

    :param output: integer array of predicted class ids
    :param target: integer array of ground-truth class ids, same shape as ``output``
    :param n_class: number of classes; pixels whose prediction or target falls
                    outside ``[0, n_class)`` are skipped
    :return: integer array of shape (n_class, n_class)
    """
    output = np.asarray(output)
    target = np.asarray(target)
    # Out-of-range ids on either side would index outside the matrix, so drop them.
    valid = (output >= 0) & (output < n_class) & (target >= 0) & (target < n_class)
    flat_index = n_class * output[valid].astype(int) + target[valid].astype(int)
    hist = np.bincount(flat_index, minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


def cal_pa(mask, n ):
    """Mean over the first ``n - 1`` classes of diagonal count / column sum.

    :param mask: square confusion matrix
    :param n: number of classes; the last class is excluded
    :return: NaN-ignoring mean of the per-class ratios
    """
    # NOTE(review): dividing by column sums gives per-class accuracy only if
    # columns are indexed by ground truth -- confirm against the matrix builder.
    kept = n - 1
    correct = np.diag(mask)[:kept]
    column_totals = np.sum(mask,0)[:kept]
    per_class = correct / column_totals
    return np.nanmean(per_class)
def cal_iou(mask, n, iou_vars):
    """Per-class IoU for the first ``n - 1`` classes of a confusion matrix.

    :param mask: square confusion matrix
    :param n: number of classes; the last class is excluded
    :param iou_vars: running accumulator; the per-class IoU is added to it in place
    :return: (NaN-ignoring mean IoU, the same ``iou_vars`` object)
    """
    kept = n - 1
    intersection = np.diag(mask)[:kept]
    # Union = column total + row total - intersection.
    union = np.sum(mask,0)[:kept] + np.sum(mask,1)[:kept] - intersection
    iou_var = intersection / union
    result = np.nanmean( iou_var )
    iou_vars += iou_var
    return result, iou_vars

# Cell-completion marker: printed once the metric helpers above have been defined.
print("Done")

In [None]:
## Part 4: Training pipeline
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from timm.utils import AverageMeter
from torchvision import datasets
from tqdm import tqdm
from tensorboardX import SummaryWriter
import torchvision.utils as v_util
import time

class Manager(object):
    """Owns the loaders, loss, optimizer and meters used to train, evaluate
    and optionally quantize a segmentation model."""

    def __init__(self, args, model):
        """
        :param args: configuration object (see ``Args``)
        :param model: the network to train / evaluate
        """
        self.args = args
        self.cuda = args.cuda
        self.model = model
        self.savename = args.savename
        self.epoch_idx = 0

        train_data = CityscapesDataset(root="./data", args = self.args, split='train')

        test_data = CityscapesDataset(root='./data', args = self.args, split='val')


        self.train_loader = train_loader(train_data, batch_size=args.batch_size, num_workers=args.num_workers)
        self.test_loader = test_loader(test_data, batch_size=args.batch_size, num_workers=args.num_workers)

        # NOTE(review): the dataset marks ignored pixels with
        # ``args.num_classes - 1`` while the loss ignores label 15. With the
        # default 13 classes no pixel is ignored by the loss -- confirm that
        # training the ignore bucket as a class is intended.
        self.criterion = nn.CrossEntropyLoss(ignore_index=15)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.args.lr)

        self.train_loss = AverageMeter()
        self.test_loss = AverageMeter()
        self.test_miou = AverageMeter()
        self.test_pa = AverageMeter()
        self.inference_time = AverageMeter()
        self.fin_ious = Average_IOU_Meter(self.args, batches= len(self.test_loader))

        self.best_accuracy = 0

        # Per-epoch histories written to ./saved by eval().
        self.list_train_loss = []
        self.list_test_loss = []
        self.list_test_pa = []
        self.list_test_miou = []
        self.list_inference_time = []
        self.arr_fin_ious = np.zeros( (args.epoches, args.num_classes -1 ))
        # Scratch buffers: per-image IoUs of the current batch and their class-wise mean.
        self.iou_vars = np.zeros( (args.batch_size, args.num_classes-1 ) )
        self.ious = np.zeros( (args.num_classes-1, ) )
        self.writer = SummaryWriter('./logs')

        if args.intel_optimize:
            import intel_extension_for_pytorch as ipex
            self.model, self.optimizer = ipex.optimize(self.model, optimizer = self.optimizer, dtype=torch.bfloat16)

        if args.pytorch_quant == True and args.intel_quant == False:
            self.quant_pyt_helper()

        if args.intel_quant:
            self.quant_intel_helper()

    def do_epoch(self, eidx):
        """Trains model for one epoch and logs the mean training loss under step ``eidx``."""
        self.train_loss.reset()
        for _, ( batch, target_mask ) in tqdm(enumerate(self.train_loader),
                                              total= len(self.train_loader) ):
            self.do_batch( batch, target_mask )

        self.list_train_loss.append(self.train_loss.avg)
        self.writer.add_scalar('train/loss', self.train_loss.avg, eidx)
        self.writer.flush()


    def do_batch(self, batch, label):
        """Runs one optimisation step on a batch and returns its loss tensor."""
        if self.cuda:
            batch = batch.cuda()
            label = label.cuda()
        self.model.zero_grad()
        output = self.model(batch)
        loss = self.criterion(output, label)
        self.train_loss.update(loss.item(), output.shape[0])
        loss.backward()
        self.optimizer.step()

        return loss

    def train(self):
        """Performs training and evaling for ``args.epoches`` epochs."""

        if self.args.cuda:
            self.model = self.model.cuda()

        for idx in range(self.args.epoches):
            self.epoch_idx = idx + 1
            print('Epoch: %d' % (self.epoch_idx))

            ## Training
            self.model.train()
            self.do_epoch(self.epoch_idx)

            # NOTE(review): this unconditional save writes to the same file as
            # the "best model" save below, so the best checkpoint is
            # overwritten by the next epoch -- confirm which one should be kept.
            self.save_model()

            ## Testing
            self.eval(idx)

            ## Save model
            if self.test_miou.avg > self.best_accuracy:
                print(" Best model trained at {} epoches.".format( self.epoch_idx))
                self.best_accuracy = self.test_miou.avg
                self.save_model()

        if self.args.save_loss:
            print(self.list_train_loss)
            np.save('./saved/test_loss.npy', self.list_test_loss)
            np.save('./saved/pa.npy', self.list_test_pa)
            np.save('./saved/iou.npy', self.list_test_miou)

    def cal_iou(self, mask, ci):
        """Per-class IoU of one image.

        :param mask: square confusion matrix of the image
        :param ci: row of ``self.iou_vars`` that receives the per-class IoU
        :return: NaN-ignoring mean IoU over the first ``num_classes - 1`` classes
        """
        kept = self.args.num_classes - 1
        intersection = np.diag(mask)[:kept]
        iou_var = intersection / (
                    np.sum(mask, 0)[:kept] + np.sum(mask, 1)[:kept]
                    - intersection)
        result = np.nanmean(iou_var)
        self.iou_vars[ci,:] = iou_var
        return result

    def eval(self, idx, record = True, visual = False):
        """Evaluate ``self.model`` on the test loader.

        :param idx: row of ``self.arr_fin_ious`` to fill when ``record`` is True
        :param record: append the metrics to the history lists, log them to
                       TensorBoard and save them under ./saved
        :param visual: show the prediction and the input for every image of
                       every 10th batch
        :return: mean IoU averaged over all evaluated images
        """
        self.model.eval()

        self.test_loss.reset()
        self.test_miou.reset()
        self.test_pa.reset()
        self.fin_ious.reset()
        self.inference_time.reset()

        last_batch = len(self.test_loader) - 1
        with torch.no_grad(), torch.autocast(device_type='cpu'):
            for bidx, (batch, label) in enumerate(tqdm(self.test_loader, desc='Eval')):
                if self.cuda:
                    batch = batch.cuda()
                    label = label.cuda()
                    # CUDA kernels run asynchronously; drain pending work so
                    # the wall-clock timer only covers the forward pass.
                    torch.cuda.synchronize()

                start = time.time()
                output = self.model(batch)  ## [batch_size, num_classes, H, W]
                if self.cuda:
                    torch.cuda.synchronize()
                self.inference_time.update(time.time() - start)
                self.test_loss.update( self.criterion(output, label).item() )

                miou = 0
                pa = 0
                for ci in range(output.shape[0]): ## to each img
                    label_pred = np.array ( torch.squeeze( torch.argmax( torch.squeeze(output[ci]), dim= 0) ).to('cpu') )
                    if bidx%10 == 0 and visual:
                        # class_to_rgb is the notebook-level helper; ``utils``
                        # here would resolve to torchvision.utils, which has no
                        # such function.
                        visual_tensor = class_to_rgb(label_pred)
                        img_pred = v_util.make_grid(visual_tensor.long()).cpu().numpy()
                        img_label = v_util.make_grid(batch[ci].long()).cpu().numpy()
                        plt.imshow(np.transpose(img_pred, (1, 2, 0)))
                        plt.show()
                        plt.imshow(np.transpose(img_label, (1, 2, 0)))
                        plt.show()

                    label_true = np.array ( label[ci].to('cpu') )
                    mask = cal_confusion_matrix(label_pred, label_true, self.args.num_classes)
                    iou = self.cal_iou(mask,ci)
                    miou += iou
                    pa += cal_pa(mask, self.args.num_classes)

                for i in range(self.args.num_classes-1):
                    self.ious[i] = np.nanmean(self.iou_vars[:,i])

                # The last batch is left out of the per-class average: it may
                # be partial, in which case self.iou_vars still holds rows from
                # the previous batch.
                if bidx != last_batch:
                    self.fin_ious.update(self.ious)
                self.test_miou.update(miou/output.shape[0], output.shape[0])
                self.test_pa.update(pa/output.shape[0], output.shape[0])

        avg_ivs = self.fin_ious.cal_avg()
        print("avg_iou_var", avg_ivs)
        print("avg_miou", self.test_miou.avg )
        print("avg_pa", self.test_pa.avg)

        if record:
            self.list_test_loss.append(self.test_loss.avg)
            self.list_test_miou.append(self.test_miou.avg)
            self.list_test_pa.append(self.test_pa.avg)
            # Store the number, not the meter object: the meter is reset and
            # reused on every call.
            self.list_inference_time.append(self.inference_time.avg)
            self.arr_fin_ious[idx, :] = avg_ivs

            self.writer.add_scalar('test/loss', self.test_loss.avg , self.epoch_idx)
            self.writer.add_scalar('test/PA', self.test_pa.avg, self.epoch_idx)
            self.writer.add_scalar('test/IOU', self.test_miou.avg, self.epoch_idx)
            self.writer.flush()

            print( self.list_test_loss)
            np.save('./saved/test_loss.npy', self.list_test_loss)
            np.save('./saved/pa.npy', self.list_test_pa)
            np.save('./saved/iou.npy', self.list_test_miou)
            np.save('./saved/inference_time.npy', self.list_inference_time)
            np.save('./saved/arr_ious.npy', self.arr_fin_ious)

        return self.test_miou.avg

    def save_model(self):
        """Saves model to file: a dict with 'args' and the whole 'model' object at ``<savename>.pt``."""
        ckpt = {
            'args': self.args,
            'model': self.model,
        }

        # Save to file.
        torch.save(ckpt, self.savename + '.pt')

    def quant_intel_helper(self):
        """Run Intel Neural Compressor post-training quantization on ``self.model``.

        The test loader is used for calibration and the mean IoU from
        ``eval`` is the accuracy metric. The result is stored on
        ``self.q_model``; ``self.model`` is left unchanged.
        """
        import neural_compressor
        from neural_compressor import PostTrainingQuantConfig
        from neural_compressor import quantization

        def _evaluate(candidate):
            # eval_func must be a callable that scores the model it is given,
            # so evaluate the candidate in place of self.model and restore it.
            original = self.model
            self.model = candidate
            try:
                return self.eval(idx=0, record=False)
            finally:
                self.model = original

        conf = PostTrainingQuantConfig(backend="ipex")
        # NOTE(review): the quantized model is kept separately because how the
        # returned wrapper is called is not visible here -- confirm before
        # swapping it in for self.model.
        self.q_model = quantization.fit(self.model,
                                        conf,
                                        calib_dataloader=self.test_loader,
                                        eval_func=_evaluate)

    def quant_pyt_helper(self):
        """Replace ``self.model`` with the result of PyTorch dynamic int8 quantization."""
        # NOTE(review): only nn.Conv2d is listed; check that dynamic
        # quantization actually converts this layer type.
        self.model = torch.quantization.quantize_dynamic(
        model=self.model,
        qconfig_spec= {nn.Conv2d},
        dtype=torch.qint8)

# Cell-completion marker: printed once the Manager class above has been defined.
print("Done.")

In [None]:
## Part 5: Hyperparameters
class Args():
    """Plain container for the hyperparameters and switches read by the pipeline."""

    def __init__(self):
        # Model / data
        self.num_classes = 13
        self.cuda = False
        self.batch_size = 8
        # ``num_worker`` and ``num_workers`` both exist; Manager reads ``num_workers``.
        self.num_worker = 4
        # Optimisation
        self.lr = 0.01
        self.epoches = 400
        # Acceleration and bookkeeping switches
        self.intel_optimize = False
        self.save_loss = True
        self.savename = './saved/hrnetv2'
        self.intel_quant = self.pytorch_quant = False
        self.num_workers = 4

    def set_save_path(self, save_path):
        """Set the checkpoint path prefix ('.pt' is appended when saving)."""
        self.savename = save_path

In [None]:
## Part 6: Training
from nets import seg_unet
from nets import hrnet
args = Args()

# Train the U-Net model; its checkpoint goes to <save path>.pt.
args.set_save_path('./saved/saved_model/seg_unet')
model_unet = seg_unet.Seg_unet(3,args.num_classes,64)
# Pass the model built on the line above: the bare name ``model`` is not
# defined at this point on a fresh kernel.
manager = Manager(args, model_unet)
manager.train()

# Train HRNet with the same settings, saved under its own path.
args.set_save_path('./saved/saved_model/hrnetv2')
model_hrnet = hrnet.HRnet(num_classes=args.num_classes, backbone="hrnetv2_w48", pretrained=True)
manager = Manager(args, model_hrnet)
manager.train()

In [None]:
## Part 7: Evaluation
# Reload a checkpoint in the format written by Manager.save_model (a dict
# holding 'args' and the whole 'model' object) and evaluate it once.
# NOTE(review): Part 6 sets the save path to './saved/saved_model/hrnetv2',
# so this file is not the one that cell writes -- confirm the intended path.
ckpt = torch.load("./saved/hrnetv2.pt")
model = ckpt["model"]
manager = Manager(args, model)
# record=False: metrics are printed but not appended to the histories or saved.
manager.eval(idx= 0 , record=False, visual=False)

In [None]:
## Part 8: Quantization

# Setting this flag makes Manager.__init__ call quant_intel_helper().
# NOTE(review): ``args`` is shared notebook state, so the flag stays set for
# any Manager built after this cell.
args.intel_quant = True  ## try the Intel quantization tool
# NOTE(review): same checkpoint path as Part 7 -- confirm it matches the
# save path used during training.
ckpt = torch.load("./saved/hrnetv2.pt")
model = ckpt["model"]
manager = Manager(args, model)
# record=False: metrics are printed but not appended to the histories or saved.
manager.eval(idx= 0 , record=False, visual=False)