In [1]:
# RUN THIS CELL FIRST!

import albumentations as A
import argparse
import cv2
import numpy as np
import os
import pandas as pd
import random
import re
import shutil
import torchvision.transforms as transforms
import xml.dom.minidom
import zipfile
from matplotlib import pyplot as plt
from PIL import Image
from PIL import ImageEnhance
from xml.dom.minidom import parse

In [2]:
# ---------------PARAMS------------------
# Destination: the yolov5 dataset folders filled by the copy/augment cells below.
root_dst_path = r'/root/yolov5/dataset/'
imgDst = root_dst_path + 'images/'
annoDst = root_dst_path + 'annotation/'
# Source: raw images and their XML annotations (bndbox xmin/ymin/xmax/ymax + type).
imgPath = r'/root/CMDI_data/image/'
annoPath = r'/root/CMDI_data/xml/'
# Per-device image lists (<device>.txt) and the k-fold partition files.
txtSavePath = r'/root/partitionSet/'
txtlist = ['BBU3900', 'FSIH', 'B8300', 'Airscale', 'BBU3910', 'BB6648', 'BBU5900', 'BB6630', 'V9200']
partitionSavePath = txtSavePath + 'partition/'
# One file per fold: set1.txt .. set5.txt
partitionSet = [partitionSavePath + f'set{fold}.txt' for fold in range(1, 6)]
# ---------------------------------------

# clear folders
def clear(folders, root=None):
    """Empty each given sub-folder of the dataset root by deleting and recreating it.

    :param folders: sub-folder names relative to ``root`` (e.g. 'images', 'partition/main')
    :param root: base directory; defaults to the module-level ``root_dst_path``
    """
    if root is None:
        root = root_dst_path
    for kind in folders:
        filepath = os.path.join(root, kind)
        # A missing folder (e.g. on a first run) is simply created instead of crashing rmtree.
        if os.path.isdir(filepath):
            shutil.rmtree(filepath)
        os.makedirs(filepath)

def doAug(imgName, annoName, newImgName, newAnnoName, times=None, high_threshold=20):
    """Write one augmented copy of an image/annotation pair.

    :param imgName: source image path
    :param annoName: source XML annotation path
    :param newImgName: output path of the augmented image
    :param newAnnoName: output path of the updated XML
    :param times: number of augmented copies the caller makes of this image. When it is
        above ``high_threshold``, the stronger ``augment_high`` pipeline is used.
        Defaults to the module-level ``times`` set by the augmentation cell
        (the original implicit contract).
    :param high_threshold: ``times`` values above this select ``augment_high``
    :raises NameError: if ``times`` is neither passed nor defined at module level
    """
    if times is None:
        if 'times' not in globals():
            raise NameError("doAug: pass `times` or define a module-level `times` first")
        times = globals()['times']
    t = transform()
    t.open(imgName, annoName)
    if times > high_threshold:
        t.augment_high()
    else:
        t.augment_low()
    t.close(newImgName, newAnnoName)

class transform:
    """Load one image and its XML annotation, augment them together, and write them out.

    Objects whose <name> is '接地' or '接地线' (grounding objects) are skipped when
    reading and are left untouched in the written XML.

    Attributes:
        img: path of the source image.
        image: the image as an RGB array (converted from OpenCV's BGR on load).
        xml: path of the source XML annotation.
        width, height: values of the XML <width>/<height> elements.
        bboxes: list of [xmin, ymin, xmax, ymax, type] for the non-skipped objects.
        is_opened: True between open() and close().
    """

    # <object> names excluded from augmentation / bbox bookkeeping.
    IGNORED_NAMES = ('接地', '接地线')

    def __init__(self):
        self.bboxes = []  # [xmin, ymin, xmax, ymax, type]
        self.is_opened = False

    def _is_ignored(self, obj):
        """Return True if the <object> element's <name> is in IGNORED_NAMES."""
        name = obj.getElementsByTagName("name")[0].childNodes[0].data
        return name in self.IGNORED_NAMES

    def open(self, img, xml):
        """Read the image and its XML annotation.

        :param img: image file path
        :param xml: XML annotation file path
        :raises FileNotFoundError: if OpenCV cannot read ``img``
        """
        self.img = img
        bgr = cv2.imread(self.img)
        if bgr is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f'cannot read image: {img}')
        self.image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        self.xml = xml
        root = parse(self.xml).documentElement
        self.width = int(root.getElementsByTagName("width")[0].childNodes[0].data)
        self.height = int(root.getElementsByTagName("height")[0].childNodes[0].data)
        # Reset so re-opening the same instance does not accumulate boxes.
        self.bboxes = []
        for obj in root.getElementsByTagName("object"):
            if self._is_ignored(obj):
                continue
            box = obj.getElementsByTagName("bndbox")[0]
            coords = [int(box.getElementsByTagName(tag)[0].childNodes[0].data)
                      for tag in ("xmin", "ymin", "xmax", "ymax")]
            typ = obj.getElementsByTagName("type")[0].childNodes[0].data
            self.bboxes.append(coords + [typ])
        self.is_opened = True

    def close(self, imgtarget='', xmltarget=''):
        """Write the (augmented) image and an updated copy of the source XML.

        :param imgtarget: output image path; '' overwrites the source image
        :param xmltarget: output XML path; '' overwrites the source XML
        """
        if imgtarget == '':
            imgtarget = self.img
        if xmltarget == '':
            xmltarget = self.xml
        self.is_opened = False

        # JPEG quality 100 so that saving adds as little extra compression as possible.
        cv2.imwrite(imgtarget, cv2.cvtColor(self.image, cv2.COLOR_RGB2BGR), [int(cv2.IMWRITE_JPEG_QUALITY), 100])

        tree = parse(self.xml)
        root = tree.documentElement
        # The augmented image may have a different size (e.g. after RandomRotate90).
        root.getElementsByTagName("width")[0].childNodes[0].data = str(np.size(self.image, 1))
        root.getElementsByTagName("height")[0].childNodes[0].data = str(np.size(self.image, 0))

        # NOTE(review): boxes are matched to <object>s by position. If augmentation dropped
        # a box (min_visibility), later boxes shift onto earlier objects; only coords and
        # <type> are rewritten (not <name>), and surplus objects are removed.
        i = 0
        for obj in root.getElementsByTagName("object"):
            if self._is_ignored(obj):
                continue
            if i < len(self.bboxes):
                elements = self.bboxes[i]
                assert elements[0] < elements[2] and elements[1] < elements[3], \
                    f'degenerate bbox {list(elements)} in {self.xml}'
                box = obj.getElementsByTagName("bndbox")[0]
                box.getElementsByTagName("xmin")[0].childNodes[0].data = str(int(elements[0]))
                box.getElementsByTagName("ymin")[0].childNodes[0].data = str(int(elements[1]))
                box.getElementsByTagName("xmax")[0].childNodes[0].data = str(int(elements[2]))
                box.getElementsByTagName("ymax")[0].childNodes[0].data = str(int(elements[3]))
                obj.getElementsByTagName("type")[0].childNodes[0].data = str(elements[4])
            else:
                obj.parentNode.removeChild(obj)
            i += 1
        # The declaration says utf-8, so the file must actually be written as utf-8
        # (the default locale encoding could mangle the Chinese names/types).
        with open(xmltarget, 'w', encoding='utf-8') as f:
            tree.writexml(f, addindent='  ', encoding='utf-8')

    def __str__(self):
        assert self.is_opened
        return str(self.bboxes)

    def _apply(self, transforms, dropout_p):
        """Run ``transforms`` on image + boxes, then CoarseDropout on the image only.

        :param transforms: list of albumentations transforms for A.Compose
        :param dropout_p: probability of applying CoarseDropout
        """
        pipeline = A.Compose(transforms, bbox_params=A.BboxParams(format='pascal_voc', min_visibility=0.5))
        transformed = pipeline(image=self.image, bboxes=self.bboxes)
        # Dropout only sees the image; boxes are taken from the first pipeline.
        dropout = A.CoarseDropout(max_holes=10, max_height=50, max_width=50, p=dropout_p)
        self.image = dropout(image=transformed['image'])['image']
        self.bboxes = transformed['bboxes']

    def augment_low(self):
        """Milder augmentation: smaller rotation, higher JPEG quality, lower probabilities."""
        self._apply([
            # image-wise aug
            A.Flip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, p=0.5),
            A.ImageCompression(quality_lower=75, quality_upper=100, p=1),
            # pixel-wise aug
            A.OneOf([
                A.ISONoise(),
                A.GaussNoise(),
            ], p=1),
            A.MotionBlur(p=0.2),
            A.HueSaturationValue(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.1, p=0.5),
        ], dropout_p=0.1)

    def augment_high(self):
        """Stronger augmentation: larger rotation, lower JPEG quality, always-on color jitter."""
        self._apply([
            # image-wise aug
            A.Flip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=20, p=0.5),
            A.ImageCompression(quality_lower=60, quality_upper=100, p=1),
            # pixel-wise aug
            A.OneOf([
                A.ISONoise(),
                A.GaussNoise(),
            ], p=1),
            A.MotionBlur(p=0.5),
            A.HueSaturationValue(p=1),
            A.RandomBrightnessContrast(p=1),
        ], dropout_p=0.4)

In [19]:
# Classify image-annotation pairs by device: each image is listed in the <short>.txt
# file (under txtSavePath) of the FIRST device type below that it contains.
DEVICE_TYPES = {  # short name (txt file stem) -> <type> value in the XML
    'BBU3900': '华为BBU3900',
    'FSIH': '诺基亚 FSIH',
    'B8300': '中兴B8300',
    'Airscale': '诺基亚 Airscale',
    'BBU3910': '华为BBU3910',
    'BB6648': '爱立信BB6648',
    'BBU5900': '华为BBU5900',
    'BB6630': '爱立信BB6630',
    'V9200': '中兴V9200',
}
IGNORED_NAMES = ('接地', '接地线')  # grounding objects, excluded from device detection

device_images = {short: [] for short in DEVICE_TYPES}
unmatched = []
count = 0
# sorted() makes the per-device lists (and hence the folds) reproducible across runs.
for image in sorted(os.listdir(imgPath)):
    imgname = imgPath + image
    if os.path.isdir(imgname):  # e.g. Jupyter's .ipynb_checkpoints
        continue
    annoname = annoPath + image.replace('.jpg', '.xml')
    objectlist = xml.dom.minidom.parse(annoname).documentElement.getElementsByTagName('object')
    devices = [
        obj.getElementsByTagName('type')[0].childNodes[0].data
        for obj in objectlist
        if obj.getElementsByTagName('name')[0].childNodes[0].data not in IGNORED_NAMES
    ]
    print(image, set(devices), len(devices))

    for short, dev_type in DEVICE_TYPES.items():
        if dev_type in devices:
            device_images[short].append(imgname)
            count += 1
            break
    else:
        unmatched.append(image)

# Write every per-device list (possibly empty); `with` guarantees the files are closed.
for short, names in device_images.items():
    with open(txtSavePath + short + '.txt', 'w') as fh:
        fh.writelines(name + '\n' for name in names)

assert not unmatched, f'images with no known device type: {unmatched}'

100000263_b11uj182.jpg {'爱立信BB6630'} 2
100000263_gc4iae2b.jpg {'爱立信BB6630'} 1
100000263_4140avcf.jpg {'中兴V9200'} 1
100000263_2hv48jxi.jpg {'华为BBU3910'} 1
100000263_p6uc632e.jpg {'中兴V9200'} 1
100000263_un4c6677.jpg {'爱立信BB6648'} 1
100000263_a3v81n4p.jpg {'华为BBU3910'} 1
100000263_g4hy5572.jpg {'华为BBU5900'} 1
100000263_v1fmgw7w.jpg {'中兴V9200'} 1
100000263_j3ct28q5.jpg {'爱立信BB6630'} 1
100000263_18li3y4f.jpg {'华为BBU5900'} 1
100000263_36u445ce.jpg {'华为BBU5900'} 1
100000263_45v4p81s.jpg {'爱立信BB6630'} 1
100000263_71k60d55.jpg {'爱立信BB6630'} 1
100000263_6m4u2a4b.jpg {'爱立信BB6630'} 2
100000263_5ee74156.jpg {'华为BBU5900'} 1
100000263_0nf2827i.jpg {'华为BBU5900'} 1
100000263_t41qdm3a.jpg {'华为BBU5900'} 1
100000263_1666m8ha.jpg {'中兴V9200'} 1
100000263_oal74s6t.jpg {'华为BBU5900'} 1
100000263_j5gt71a0.jpg {'华为BBU5900'} 1
100000263_607bf6k4.jpg {'爱立信BB6630'} 1
100000263_34r6wigd.jpg {'中兴V9200'} 1
100000263_3s88opn5.jpg {'爱立信BB6630'} 3
100000263_bu7w2po7.jpg {'爱立信BB6630'} 2
100000263_a638y87w.jpg {'中兴V9200'} 

In [22]:
# Split each device's image list round-robin across the k partition files, so every
# fold gets a near-equal share of every device. The index restarts for each device,
# so any remainder always lands in the lowest-numbered folds.
k = 5  # k-fold cross validation
assert len(partitionSet) == k, 'partitionSet must list exactly one file per fold'
os.makedirs(partitionSavePath, exist_ok=True)
# partitionSet holds PATHS (later cells re-open them for reading), so open separate
# write handles here instead of calling .write()/.close() on the strings.
set_files = [open(path, 'w') for path in partitionSet]
try:
    for txt in txtlist:
        with open(txtSavePath + txt + '.txt', 'r') as f:
            filelist = f.readlines()
        counter = np.zeros(k,)
        for index, file in enumerate(filelist):
            set_files[index % k].write(file)
            counter[index % k] += 1
        print('Partitioned ' + txt, f', {len(filelist)} files in total.')
        print(counter)
finally:
    for set_i in set_files:
        set_i.close()

Partitioned BBU3900 , 13 files in total.
[3. 3. 3. 2. 2.]
Partitioned FSIH , 20 files in total.
[4. 4. 4. 4. 4.]
Partitioned B8300 , 32 files in total.
[7. 7. 6. 6. 6.]
Partitioned Airscale , 31 files in total.
[7. 6. 6. 6. 6.]
Partitioned BBU3910 , 60 files in total.
[12. 12. 12. 12. 12.]
Partitioned BB6648 , 78 files in total.
[16. 16. 16. 15. 15.]
Partitioned BBU5900 , 758 files in total.
[152. 152. 152. 151. 151.]
Partitioned BB6630 , 623 files in total.
[125. 125. 125. 124. 124.]
Partitioned V9200 , 1078 files in total.
[216. 216. 216. 215. 215.]


In [9]:
# Augment the training folds and send originals + augmented copies to the yolov5 dataset.
k = 5  # k-fold cross validation
test_set = 5  # 1-based index of the held-out fold; change per run
n = 2  # augment strength constant
train_set = [str(i + 1) for i in range(k)]
train_set.pop(test_set - 1)
print('test: ', test_set, ', train: ', train_set)

# (device <type> value, approx. images per fold, read off the partition cell's output).
# Each image gets round(REFERENCE_PER_FOLD / per_fold * n) augmented copies, capped at
# MAX_AUG_TIMES, so rarer devices are oversampled towards the largest class (V9200).
# List order is the priority when an image contains several device types.
AUG_DEVICES = [
    ('华为BBU3900', 3),
    ('诺基亚 FSIH', 4),
    ('中兴B8300', 6),
    ('诺基亚 Airscale', 6),
    ('华为BBU3910', 12),
    ('爱立信BB6648', 16),
    ('华为BBU5900', 152),
    ('爱立信BB6630', 125),
    ('中兴V9200', 216),
]
REFERENCE_PER_FOLD = 216
MAX_AUG_TIMES = 50

counter = np.zeros(len(AUG_DEVICES),)
for fold_id in train_set:
    with open(partitionSet[int(fold_id) - 1], 'r') as f:
        imgList = [line.rstrip('\n') for line in f]
    for imgName in imgList:
        base = os.path.basename(imgName)
        if not imgName or base == '.ipynb_checkpoints':
            continue
        stem = os.path.splitext(base)[0]
        # Build the XML path from annoPath; str.replace('image', 'xml') would also
        # rewrite 'image' occurring inside a file name.
        annoName = annoPath + stem + '.xml'
        objectlist = xml.dom.minidom.parse(annoName).documentElement.getElementsByTagName('object')
        devices = [
            obj.getElementsByTagName('type')[0].childNodes[0].data
            for obj in objectlist
            if obj.getElementsByTagName('name')[0].childNodes[0].data not in ('接地', '接地线')
        ]

        for slot, (device, per_fold) in enumerate(AUG_DEVICES):
            if device in devices:
                break
        else:
            continue  # no known device type: not copied (unchanged behavior)

        # Module-level on purpose: doAug() reads the global `times` to choose its pipeline.
        times = min(MAX_AUG_TIMES, round(REFERENCE_PER_FOLD / per_fold * n))
        shutil.copy(imgName, imgDst)
        shutil.copy(annoName, annoDst)
        counter[slot] += 1
        for i in range(times):
            counter[slot] += 1
            newImgName = imgDst + f'{stem}_aug{i+1}.jpg'
            newAnnoName = annoDst + f'{stem}_aug{i+1}.xml'
            doAug(imgName, annoName, newImgName, newAnnoName)
print(counter)

test:  5 , train:  ['1', '2', '3', '4']
[ 561.  816. 1326. 1275. 1776. 1764. 2428. 1996. 2589.]


In [53]:
# Send the held-out test fold (un-augmented) to the yolov5 dataset.
k = 5  # k-fold cross validation
test_set = 5  # 1-based index of the held-out fold; keep in sync with the augment cell
train_set = [str(i + 1) for i in range(k)]
train_set.pop(test_set - 1)
print('test: ', test_set, ', train: ', train_set)

with open(partitionSet[test_set - 1], 'r') as f:
    imgList = [line.rstrip('\n') for line in f]
for imgName in imgList:
    base = os.path.basename(imgName)
    if not imgName or base == '.ipynb_checkpoints':
        continue
    # Build the XML path from annoPath; str.replace('image', 'xml') would also
    # rewrite 'image' occurring inside a file name.
    annoName = annoPath + os.path.splitext(base)[0] + '.xml'
    shutil.copy(imgName, imgDst)
    shutil.copy(annoName, annoDst)

test:  5 , train:  ['1', '2', '3', '4']


In [54]:
# Run the repo's split script (source not shown here). Presumably it writes the
# train/val/test image lists for the dataset folder — TODO confirm.
!python /root/yolov5/split_train_val_test.py

In [55]:
# Run the repo's label-conversion script (source not shown here). Presumably it
# converts the XML annotations into yolov5 label files under dataset/labels/ — TODO confirm.
!python /root/yolov5/voc_label.py

/root


In [52]:
# Delete and recreate the dataset output folders under root_dst_path.
# NOTE: this wipes everything the copy/augment cells wrote there — run it only
# before regenerating the dataset for a new fold.
clear(['annotation', 'images', 'labels', 'partition/main'])

In [39]:
# Sanity check: image, annotation and generated-label counts are expected to match.
image_count, anno_count = (len(os.listdir(folder)) for folder in (imgDst, annoDst))
print(image_count, anno_count)
print(len(os.listdir(root_dst_path + 'labels')))

542 542
542


In [27]:
# Copy every raw image and its XML annotation, unaugmented, into the yolov5 dataset.
for image in os.listdir(imgPath):
    annoName = annoPath + image.replace('.jpg', '.xml')
    if annoName == annoPath + '.ipynb_checkpoints':
        continue  # Jupyter's checkpoint folder, not an image
    shutil.copy(imgPath + image, imgDst)
    shutil.copy(annoName, annoDst)