In [1]:
from osgeo import gdal
import numpy as np
import os
import cv2

In [2]:
#  Read tif data set
def readTif(fileName, xoff = 0, yoff = 0, data_width = 0, data_height = 0):
    """Open a raster with GDAL and read a window of its pixel data.

    Parameters
    ----------
    fileName : str
        Path of the raster file to open.
    xoff, yoff : int
        Column / row offset (in pixels) of the top-left corner of the window.
    data_width, data_height : int
        Window size in pixels. When both are 0 (the default) the window runs
        from (xoff, yoff) to the raster edge, i.e. the whole raster for the
        default offsets.

    Returns
    -------
    tuple
        ``(width, height, bands, data, geotrans, proj)``: the full raster's
        column count, row count and band count (not the window's), the array
        returned by ``Dataset.ReadAsArray`` for the window (per the GDAL docs:
        2-D for single-band rasters, ``(bands, rows, cols)`` otherwise), the
        affine geotransform and the projection string.

    Raises
    ------
    OSError
        If GDAL cannot open ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Fail fast: previously this only printed and then crashed with an
        # AttributeError on the next line. The message means "file cannot be opened".
        raise OSError(fileName + "文件无法打开")
    #  The number of columns of the grid matrix
    width = dataset.RasterXSize
    #  The number of rows of the grid matrix
    height = dataset.RasterYSize
    #  The number of channels
    bands = dataset.RasterCount
    # A 0 x 0 window means "read up to the raster edge". Subtracting the offsets
    # keeps the window inside the raster; with default offsets this is the
    # full raster, exactly as before.
    if data_width == 0 and data_height == 0:
        data_width = width - xoff
        data_height = height - yoff
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)
    #  Get affine matrix information
    geotrans = dataset.GetGeoTransform()
    #  Get projection information
    proj = dataset.GetProjection()
    return width, height, bands, data, geotrans, proj

#  Save tif file function
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write a numpy array to ``path`` as a GeoTIFF using GDAL.

    Parameters
    ----------
    im_data : numpy.ndarray
        Either a 2-D array ``(rows, cols)``, written as a single band, or a
        3-D array ``(bands, rows, cols)``, one band per leading index.
    im_geotrans : sequence
        Affine geotransform passed to ``SetGeoTransform``.
    im_proj : str
        Projection string passed to ``SetProjection``.
    path : str
        Output file path.

    Raises
    ------
    ValueError
        If ``im_data`` is neither 2-D nor 3-D (previously an UnboundLocalError).
    OSError
        If the GTiff driver cannot create ``path`` (previously an
        AttributeError on ``None``).
    """
    # Choose the GDAL pixel type from the exact numpy dtype name. The old
    # substring test sent signed int16 to UInt16 (no negative values) and
    # pushed 32-bit integers through Float32 (precision loss above 2**24).
    dtype_name = im_data.dtype.name
    if dtype_name in ("uint8", "int8"):
        # GDAL < 3.7 has no signed 8-bit type, so int8 stays Byte as before;
        # negative int8 values are not representable there.
        datatype = gdal.GDT_Byte
    elif dtype_name == "uint16":
        datatype = gdal.GDT_UInt16
    elif dtype_name == "int16":
        datatype = gdal.GDT_Int16
    elif dtype_name == "uint32":
        datatype = gdal.GDT_UInt32
    elif dtype_name == "int32":
        datatype = gdal.GDT_Int32
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 2:
        # Treat a single 2-D plane as a one-band raster.
        im_data = im_data[np.newaxis, :, :]
    elif im_data.ndim != 3:
        raise ValueError("im_data must be 2-D or 3-D, got shape %r" % (im_data.shape,))
    im_bands, im_height, im_width = im_data.shape

    # Create the output file
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        raise OSError("GDAL could not create " + path)
    try:
        dataset.SetGeoTransform(im_geotrans)  # Write affine transformation parameters
        dataset.SetProjection(im_proj)  # Write projection
        for band_index in range(im_bands):
            # GDAL band numbers are 1-based.
            dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    finally:
        # Dropping the last reference closes the dataset, which is when GDAL
        # flushes it to disk, even if a write above failed.
        del dataset

#### Path parameters (training image and label folders)

In [3]:
# Training data folders. os.path.join keeps the paths portable: on Windows the
# values equal the former r"comp\train\vi" / r"comp\train\vl" literals, and
# elsewhere they use the native separator instead of a literal backslash.
train_image_path = os.path.join("comp", "train", "vi")
train_label_path = os.path.join("comp", "train", "vl")

# Optional separate output folders for augmented data (currently unused).
# save_image_path = os.path.join("comp", "train", "image_boost")
# save_label_path = os.path.join("comp", "train", "label_boost")

#### Data augmentation: flip each image/label pair (main loop)

In [4]:
# Data augmentation: for every (image, label) pair write three flipped copies
# (horizontal, vertical, both axes) back into the training folders under new
# sequential numbers starting at len(imageList) + 1.
# NOTE(review): the outputs go into the same folders listed here, so running
# this cell a second time also augments the already-augmented files. Run it
# once on the original data, or write to separate output folders.
imageList = sorted(os.listdir(train_image_path))
labelList = sorted(os.listdir(train_label_path))
# Images and labels are paired by position. os.listdir order is arbitrary,
# so both listings are sorted, and they must have the same length.
if len(imageList) != len(labelList):
    raise ValueError("found %d images but %d labels"
                     % (len(imageList), len(labelList)))
tran_num = len(imageList) + 1

# Each entry is (numpy axes for the image, cv2.flip code for the label).
# Negative axes address (rows, cols) for both a single-band (rows, cols)
# array and a multi-band (bands, rows, cols) array.
FLIPS = (
    ((-1,), 1),      # horizontal flip: mirror columns
    ((-2,), 0),      # vertical flip: mirror rows
    ((-2, -1), -1),  # both axes, i.e. a 180-degree rotation (formerly called "diagonal")
)

for image_name, label_name in zip(imageList, labelList):
    #  image
    img_file = os.path.join(train_image_path, image_name)
    im_width, im_height, im_bands, im_data, im_geotrans, im_proj = readTif(img_file)
    #  label: IMREAD_UNCHANGED keeps the label's own channel count and bit
    #  depth, so augmented labels match the originals. The default flag
    #  converts to 3-channel 8-bit BGR.
    label_file = os.path.join(train_label_path, label_name)
    label = cv2.imread(label_file, cv2.IMREAD_UNCHANGED)
    if label is None:
        raise OSError("could not read label " + label_file)

    # splitext keeps the full extension, including the dot (".tif", ".tiff", ".png").
    image_ext = os.path.splitext(image_name)[1]
    label_ext = os.path.splitext(label_name)[1]
    for image_axes, label_code in FLIPS:
        # NOTE(review): the unflipped geotransform is reused, so flipped rasters
        # are not correctly georeferenced; presumably fine because only the
        # pixel content is used for training. Confirm.
        image_out = os.path.join(train_image_path, str(tran_num) + image_ext)
        writeTiff(np.flip(im_data, axis=image_axes), im_geotrans, im_proj, image_out)
        label_out = os.path.join(train_label_path, str(tran_num) + label_ext)
        if not cv2.imwrite(label_out, cv2.flip(label, label_code)):
            raise OSError("could not write label " + label_out)
        tran_num += 1