In [1]:
import gdal
import numpy as np
import pandas as pd
import os
import geopandas as gpd
import matplotlib.pyplot as plt
import torch
from glob import glob
import shapely
import random
from tqdm import tqdm
from collections import Counter
from pybob.GeoImg import GeoImg
import osr
from pybob.image_tools import create_mask_from_shapefile
import gc
import imageio
import time
import shutil

In [2]:
def get_mask(maskshp,geotrf,shape,p='4326'):
    img=np.zeros((shape[0],shape[1]))
    
    if p=='4326':
        sref = osr.SpatialReference()
        sref.ImportFromEPSG(4326)
        proj = sref.ExportToWkt()
    else:
        proj=p
    
    trf=(geotrf[0], geotrf[1], geotrf[2], geotrf[3], geotrf[4], geotrf[5])
    drv = gdal.GetDriverByName('MEM')
    dst = drv.Create('', shape[1], shape[0], 1, gdal.GDT_UInt16)

    sp = dst.SetProjection(proj)
    sg = dst.SetGeoTransform(trf)

    wa = dst.GetRasterBand(1).WriteArray(img)
    del wa, sg, sp
    img_bob = GeoImg(dst)
    mask = create_mask_from_shapefile(img_bob, maskshp)
    return mask
def read_img(path):
    ds=gdal.Open(path)
    trf=ds.GetGeoTransform()
    img=ds.ReadAsArray()
    return {'trf':trf,'img':img}
def tif_save_bands(img,save_name,trf,p='4326'):
    driver=gdal.GetDriverByName('GTiff')
    new_img=driver.Create(save_name,img.shape[2],img.shape[1],img.shape[0],6,['COMPRESS=LZW','BIGTIFF=YES'])
    new_img.SetGeoTransform(trf)
    
    if p=='4326':
        sref = osr.SpatialReference()
        sref.ImportFromEPSG(4326)
        proj = sref.ExportToWkt()
    else:
        proj=p
    new_img.SetProjection(proj)
    for i in range(img.shape[0]):
        new_img.GetRasterBand(i+1).WriteArray(img[i,:,:])
    new_img.FlushCache()
    del new_img
def tif_save(img,save_name,trf,p='4326'):
    driver=gdal.GetDriverByName('GTiff')
    new_img=driver.Create(save_name,img.shape[1],img.shape[0],1,6,['COMPRESS=LZW','BIGTIFF=YES'])
    new_img.SetGeoTransform(trf)
    
    if p=='4326':
        sref = osr.SpatialReference()
        sref.ImportFromEPSG(4326)
        proj = sref.ExportToWkt()
    else:
        proj=p
    new_img.SetProjection(proj)
    new_img.GetRasterBand(1).SetNoDataValue(0)
    new_img.GetRasterBand(1).WriteArray(img)
    new_img.FlushCache()
    del new_img

# 图像预处理

# 转换图像为Patch

In [3]:
def read_img(path):
    ds=gdal.Open(path)
    trf=ds.GetGeoTransform()
    img=ds.ReadAsArray()
    return {'trf':trf,'img':img}
def get_patch(data,size):
    img=data['img']
    trf=data['trf']
    shape=img.shape
    img=np.pad(img,((0,0),(0,size-shape[1]%size),(0,size-shape[2]%size)),'constant',constant_values=0)
    shape=img.shape
    img=img.reshape([shape[0],int(shape[1]/size),-1,shape[2]])
    shape=img.shape
    img=img.reshape([shape[0],shape[1],shape[2],int(shape[3]/size),-1])
    return {'img':img.transpose((0,1,3,2,4)),'trf':(trf[0]-trf[1]/2,trf[1]*size,trf[2],trf[3]-trf[5]/2,trf[4],trf[5]*size)}

# size=120
for city in ['qqhe','senyang','songyuan']:
    print(city)

    for year in ['2017','2018','2019']:
        path1=f'/ssd/hk/cen_samples/CEN_img/{city}_{year}_band2348.tif'
        path2=f'/ssd/hk/cen_samples/CEN_img/{city}_{year}_bandother.tif'
        # try:    
    
        print(path1)
        data=read_img(path1)
        data2=get_patch(data,120)
        print(data['img'].shape,data2['img'].shape)
        # d=data2['img'].astype(np.int16)
        info={'readme':'trf分别为patch左上角坐标(不是左上角像元的中心坐标)和patch的长度','trf':data2['trf'],'shape':[data['img'].shape,data2['img'].shape]}

        np.save(path1.replace('.tif','_size_120.npy'),data2['img'])
        print('max:',data2['img'].max())
        torch.save(info,path1.replace('.tif',f'_size_120_info.pth'))
        del data,data2
        gc.collect()

        data3=read_img(path2)
        data4=get_patch(data3,60)
        print(data3['img'].shape,data4['img'].shape)
        # d=data2['img'].astype(np.int16)
        info={'readme':'trf分别为patch左上角坐标(不是左上角像元的中心坐标)和patch的长度','trf':data4['trf'],'shape':[data3['img'].shape,data4['img'].shape]}

        np.save(path2.replace('.tif',f'_size_60.npy'),data4['img'])
        print('max:',data4['img'].max())
        torch.save(info,path2.replace('.tif',f'_size_60_info.pth'))
        del data3,data4
        gc.collect()
        # except:
                # print('ERROR',path)

qqhe
/ssd/hk/cen_samples/CEN_img/qqhe_2017_band2348.tif
(4, 27097, 42507) (4, 226, 355, 120, 120)
max: nan
(6, 13549, 21254) (6, 226, 355, 60, 60)
max: nan
/ssd/hk/cen_samples/CEN_img/qqhe_2018_band2348.tif
(4, 27097, 42507) (4, 226, 355, 120, 120)
max: nan
(6, 13549, 21254) (6, 226, 355, 60, 60)
max: nan
/ssd/hk/cen_samples/CEN_img/qqhe_2019_band2348.tif
(4, 27097, 42507) (4, 226, 355, 120, 120)
max: nan
(6, 13549, 21254) (6, 226, 355, 60, 60)
max: nan
senyang
/ssd/hk/cen_samples/CEN_img/senyang_2017_band2348.tif
(4, 18424, 13881) (4, 154, 116, 120, 120)
max: nan
(6, 9212, 6941) (6, 154, 116, 60, 60)
max: nan
/ssd/hk/cen_samples/CEN_img/senyang_2018_band2348.tif
(4, 18424, 13881) (4, 154, 116, 120, 120)
max: nan
(6, 9212, 6941) (6, 154, 116, 60, 60)
max: nan
/ssd/hk/cen_samples/CEN_img/senyang_2019_band2348.tif
(4, 18424, 13881) (4, 154, 116, 120, 120)
max: nan
(6, 9212, 6941) (6, 154, 116, 60, 60)
max: nan
songyuan
/ssd/hk/cen_samples/CEN_img/songyuan_2017_band2348.tif
(4, 15571, 308

In [6]:
def read_img(path):
    ds=gdal.Open(path)
    trf=ds.GetGeoTransform()
    img=ds.ReadAsArray()
    return {'trf':trf,'img':img}
def get_patch(data,size):
    img=data['img']
    trf=data['trf']
    shape=img.shape
    img=np.pad(img,((0,size-shape[0]%size),(0,size-shape[1]%size)),'constant',constant_values=0)
    shape=img.shape
    img=img.reshape([int(shape[0]/size),-1,shape[1]])
    shape=img.shape
    img=img.reshape([shape[0],shape[1],int(shape[2]/size),-1])
    return {'img':img.transpose((0,2,1,3)),'trf':(trf[0]-trf[1]/2,trf[1]*size,trf[2],trf[3]-trf[5]/2,trf[4],trf[5]*size)}

# size=120
for city in ['qqhe','senyang','songyuan']:
    print(city)

    for year in ['2017','2018','2019']:
        for c in ['maize','other','rice','soy']:
            path1=f'/ssd/hk/cen_samples/{city}_{year}_{c}_10m.tif'
            # try:    
        
            print(path1)
            data=read_img(path1)
            data2=get_patch(data,120)
            print(data['img'].shape,data2['img'].shape)
            # d=data2['img'].astype(np.int16)
            # info={'readme':'trf分别为patch左上角坐标(不是左上角像元的中心坐标)和patch的长度','trf':data2['trf'],'shape':[data['img'].shape,data2['img'].shape]}

            np.save(path1.replace('.tif','_size_120.npy'),data2['img'])
            print('max:',data2['img'].max())
            # torch.save(info,path1.replace('.tif',f'_size_120_info.pth'))
            del data,data2
            gc.collect()



qqhe
/ssd/hk/cen_samples/qqhe_2017_maize_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2017_other_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2017_rice_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2017_soy_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2018_maize_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2018_other_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2018_rice_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2018_soy_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2019_maize_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2019_other_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2019_rice_10m.tif
(27097, 42507) (226, 355, 120, 120)
max: 1
/ssd/hk/cen_samples/qqhe_2019_soy_10m.tif
(27

In [4]:

for city in ['qqhe','senyang','songyuan']:
    print(city)

    for year in ['2017','2018','2019']:
        k={}
        print(year)
        path1=f'/ssd/hk/cen_samples/CEN_img/{city}_{year}_band2348_size_120.npy'
        path2=f'/ssd/hk/cen_samples/CEN_img/{city}_{year}_bandother_size_60.npy'
        label=[]
        for c in ['maize','rice','soy','other']:
            label.append(np.load(f'/ssd/hk/cen_samples/{city}_{year}_{c}_10m_size_120.npy').astype(np.int8))
        img1=np.load(path1)
        img2=np.load(path2)
        
        for x in tqdm(range(label[1].shape[0])):
            for y in range(label[1].shape[1]):
                s1=img1[:,x,y]
                s2=img2[:,x,y]
                l=[(label[i][x,y].sum())/14400 for i in range(4)]
                if (np.isnan(s1)).any() or (np.isnan(s2)).any():
                    continue
                else:
                    coord=str(x).zfill(3)+str(y).zfill(3)
                    p=f'/ssd/hk/cen_samples/samples/{city}_{year}_{coord}.npy'
                    np.save(p,[s1,s2])
                    k[p]=l
        torch.save(k,f'/ssd/hk/cen_samples/split/samples_{city}_{year}.pth')



qqhe
2017


  return array(a, dtype, copy=False, order=order, subok=True)
100%|██████████| 226/226 [00:29<00:00,  7.71it/s]


2018


100%|██████████| 226/226 [00:28<00:00,  7.85it/s]


2019


100%|██████████| 226/226 [00:27<00:00,  8.23it/s]


senyang
2017


100%|██████████| 154/154 [00:07<00:00, 21.54it/s]


2018


100%|██████████| 154/154 [00:07<00:00, 21.67it/s]


2019


100%|██████████| 154/154 [00:07<00:00, 21.56it/s]


songyuan
2017


100%|██████████| 130/130 [00:12<00:00, 10.43it/s]


2018


100%|██████████| 130/130 [00:12<00:00, 10.47it/s]


2019


100%|██████████| 130/130 [00:12<00:00, 10.42it/s]


In [25]:
import os
n=0
k={}
for i in glob('/ssd/hk/cen_samples/split/*'):
    f=torch.load(i)
    n+=len(f.keys())
    k.update(f)
    # print(len(f.keys()))
n

179028

In [27]:
len(k.keys())

179028

In [31]:

for city in ['qqhe','senyang','songyuan']:
    v3={}
    fs=glob(f'/ssd/hk/cen_samples/samples/{city}*.npy')
    c=[i[-10:-4] for i in fs]
    c=list(set(c))
    for i in c:
        z=[]
        ll=[]
        for year in ['2017','2018','2019']:
            p=f'/ssd/hk/cen_samples/samples/{city}_{year}_{i}.npy'
            if os.path.exists(p):
                z.append(p)
                ll.append(k[p])
            v3[i]=[z,ll]
    torch.save(v3,f'/ssd/hk/cen_samples/split_v3/{city}.pth')


In [32]:
v2=[[i,k[i]] for i in k.keys()]

In [35]:
random.shuffle(v2)
torch.save(v2,'/ssd/hk/cen_samples/split_v2/sample.pth')

In [43]:
v3_=[]
for city in ['qqhe','senyang','songyuan']:
    v3=torch.load(f'/ssd/hk/cen_samples/split_v3/{city}.pth')
    v3_+=[v3[i] for i in v3.keys()]
random.shuffle(v3_)

In [46]:
torch.save(v3_,'/ssd/hk/cen_samples/split_v3/sample.pth')

In [52]:
torch.save(v2[:53708],'/ssd/hk/cen_samples/split_v2/val.pth')
torch.save(v2[53708:],'/ssd/hk/cen_samples/split_v2/train.pth')

In [55]:
v2=torch.load('/ssd/hk/cen_samples/split_v2/val.pth')+torch.load('/ssd/hk/cen_samples/split_v2/train.pth')
pathes=[i[0] for i in v2]

In [67]:
for i in pathes:
    data=np.load(i,allow_pickle=True)[:2]
    np.save(i,data)


In [78]:
for i in v3:
    l=len(i[0])
    add_l=5-l
    i[0]+=[None]*add_l
    i[1]+=[None]*add_l

In [79]:
torch.save(v3,'/ssd/hk/cen_samples/split_v3/train.pth')

In [83]:
np.load('/ssd/hk/cen_samples/samples/songyuan_2017_055061.npy',allow_pickle=True)[0].shape

(4, 120, 120)

In [84]:
len(i[0])

5

In [85]:
len(i[1])

5

In [5]:
t3=torch.load('/ssd/hk/cen_samples/split_v3/train.pth')
v3=torch.load('/ssd/hk/cen_samples/split_v3/val.pth')

In [8]:
t2=[]
v2=[]

In [9]:
for i in t3:
    for index in range(len(i[0])):
        if i[0][index] is not None:
            t2.append([i[0][index],i[1][index]])

In [10]:
for i in v3:
    for index in range(len(i[0])):
        if i[0][index] is not None:
            v2.append([i[0][index],i[1][index]])

In [11]:
print(len(t2),len(v2))

125322 53706


In [12]:
print(t2[0])
import random
random.shuffle(t2)
random.shuffle(v2)
print(t2[0])

['/ssd/hk/cen_samples/samples/songyuan_2017_055061.npy', [0.26743055555555556, 0.0, 0.021944444444444444, 0.710625]]
['/ssd/hk/cen_samples/samples/qqhe_2018_161102.npy', [0.08347222222222223, 0.2927777777777778, 0.14631944444444445, 0.4774305555555556]]


In [13]:
torch.save(t2,'/ssd/hk/cen_samples/split_v2/train.pth')
torch.save(v2,'/ssd/hk/cen_samples/split_v2/val.pth')

In [14]:
t3

[[['/ssd/hk/cen_samples/samples/songyuan_2017_055061.npy',
   '/ssd/hk/cen_samples/samples/songyuan_2018_055061.npy',
   '/ssd/hk/cen_samples/samples/songyuan_2019_055061.npy',
   None,
   None],
  [[0.26743055555555556, 0.0, 0.021944444444444444, 0.710625],
   [0.2590972222222222, 0.0, 0.028333333333333332, 0.7125694444444445],
   [0.2854861111111111, 0.0, 0.0045138888888888885, 0.71],
   None,
   None]],
 [['/ssd/hk/cen_samples/samples/qqhe_2017_074157.npy',
   '/ssd/hk/cen_samples/samples/qqhe_2018_074157.npy',
   '/ssd/hk/cen_samples/samples/qqhe_2019_074157.npy',
   None,
   None],
  [[0.5543055555555556,
    0.34430555555555553,
    0.005277777777777778,
    0.0961111111111111],
   [0.52375, 0.37833333333333335, 0.0010416666666666667, 0.096875],
   [0.5575694444444445, 0.3463888888888889, 0.0, 0.09604166666666666],
   None,
   None]],
 [['/ssd/hk/cen_samples/samples/senyang_2017_080011.npy',
   '/ssd/hk/cen_samples/samples/senyang_2018_080011.npy',
   '/ssd/hk/cen_samples/samples