In [1]:
import os

# Work around an environment-specific crash. KMP_DUPLICATE_LIB_OK is read by
# the Intel OpenMP runtime; presumably this silences the "libiomp5 already
# initialized" abort that occurs when two libraries each bundle a copy.
# It is set before the numeric/GDAL imports below so it is in place when they load.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

from pathlib import Path

from osgeo import gdal
import numpy as np

from utils.data_io import get_value_from_raster

from utils.data_io import get_csv
from utils.property import ICESAT2Properties


def get_value_from_raster_gdal(
    raster_path, longitudes, latitudes, band_indexs=None, nodata_to_nan=False
):
    """Sample raster band values at a set of points.

    Each (x, y) pair is mapped to fractional pixel coordinates with the
    inverse of the dataset geotransform. The value of the pixel that
    *contains* the point is then read.

    Parameters
    ----------
    raster_path : str or Path
        Raster file readable by GDAL. The point coordinates must be in the
        raster's CRS (e.g. lon/lat requires an EPSG:4326 raster).
    longitudes, latitudes : sequence of float
        X and Y coordinates of the sample points. They must have equal length.
    band_indexs : int, sequence of int, or None
        1-based band number(s) to sample. ``None`` samples every band.
    nodata_to_nan : bool, default False
        If True, values equal to a band's NoData value are returned as NaN,
        so that a later ``dropna`` removes them.

    Returns
    -------
    numpy.ndarray of float, shape (n_bands, n_points)
        Row ``i`` holds band ``band_indexs[i]``. Points outside the raster
        extent are NaN.

    Raises
    ------
    OSError
        If GDAL cannot open ``raster_path``.
    ValueError
        If the coordinate arrays differ in length, the geotransform is not
        invertible, or a band index is out of range.
    """
    xs = np.asarray(longitudes, dtype=float).ravel()
    ys = np.asarray(latitudes, dtype=float).ravel()
    if xs.shape != ys.shape:
        raise ValueError(
            f"longitudes and latitudes differ in length: {xs.size} vs {ys.size}"
        )

    ds = gdal.Open(str(raster_path))
    if ds is None:  # GDAL returns None instead of raising when exceptions are off
        raise OSError(f"GDAL could not open raster: {raster_path}")
    try:
        inv_gt = gdal.InvGeoTransform(ds.GetGeoTransform())
        if inv_gt is None:
            raise ValueError(f"Geotransform of {raster_path} is not invertible")

        if band_indexs is None:
            bands = list(range(1, ds.RasterCount + 1))
        elif isinstance(band_indexs, (int, np.integer)):
            bands = [int(band_indexs)]
        else:
            bands = [int(b) for b in band_indexs]
        bad_bands = [b for b in bands if not 1 <= b <= ds.RasterCount]
        if bad_bands:
            raise ValueError(
                f"Band index out of range 1..{ds.RasterCount}: {bad_bands}"
            )

        # Pixel positions do not depend on the band, so compute them once.
        pixels = []
        for x, y in zip(xs, ys):
            px, py = gdal.ApplyGeoTransform(inv_gt, float(x), float(y))
            if not (np.isfinite(px) and np.isfinite(py)):
                pixels.append(None)
                continue
            # Integer pixel coordinates are pixel *edges*, so the pixel that
            # contains the point is floor(). round() would pick the neighbour
            # whenever the point lies in the right/lower half of a pixel.
            col, row = int(np.floor(px)), int(np.floor(py))
            if 0 <= col < ds.RasterXSize and 0 <= row < ds.RasterYSize:
                pixels.append((col, row))
            else:
                pixels.append(None)  # outside the raster extent -> NaN

        values = np.full((len(bands), xs.size), np.nan, dtype=float)
        for i, band_index in enumerate(bands):
            band = ds.GetRasterBand(band_index)
            nodata = band.GetNoDataValue() if nodata_to_nan else None
            for j, pixel in enumerate(pixels):
                if pixel is None:
                    continue
                value = band.ReadAsArray(pixel[0], pixel[1], 1, 1)[0, 0]
                if nodata is not None and value == nodata:
                    continue
                values[i, j] = value
        return values
    finally:
        ds = None  # drop the reference so GDAL closes the dataset

In [None]:
# Load the ICESat-2 keypoint table and keep only the location/height columns
# used to build the dataset below.
data = get_csv(Path("keypoints_center.csv"))
ds = data.loc[
    :,
    [
        "Latitude (deg)",
        "Longitude (deg)",
        "UTM Easting (m)",
        "UTM Northing (m)",
        "Height (m MSL)",
        "Height (m HAE)",
        ICESAT2Properties.AlongTrack.value,
    ],
]

# ICESat-2 MSL heights as an array; later stored in the `real_height` column.
real_heights = ds["Height (m MSL)"].values

In [None]:
from utils.data_io import reproject2

# Sentinel-2 subset and its EPSG:4326 copy. Forward slashes are used because
# pathlib accepts them on every OS. The original backslash literals contained
# invalid "\s" escape sequences and only resolved on Windows.
S2_DIR = Path("data/sentinel-2")
s2a_path = (
    S2_DIR
    / "subset_1_of_S2A_MSIL2A_20250106T031121_N0511_R075_T49QCD_20250106T061847_s2resampled.tif"
)
s2a_wgs84_path = s2a_path.with_name(f"{s2a_path.stem}_wgs84.tif")

# The keypoints are lon/lat, so sample from the EPSG:4326 copy. Reproject
# only once and reuse the file on later runs.
if not s2a_wgs84_path.exists():
    reproject2(s2a_path, s2a_wgs84_path, epsg=4326)

# Shape (n_bands, n_keypoints): one row per band, one column per keypoint.
raster_data = get_value_from_raster_gdal(
    s2a_wgs84_path,
    ds["Longitude (deg)"].values,
    ds["Latitude (deg)"].values,
)
assert raster_data.shape[1] == len(ds), "expected one sample per keypoint"

raster_data.shape

(12, 119)




In [4]:
# Sample the GMRT DEM at every keypoint.
dem_path = Path("data/dem/GMRT_resample.tif")

# get_value_from_raster_gdal returns shape (n_bands, n_points). Request
# band 1 and take its row, so `real_height` is a flat array aligned with
# the rows of `ds`. This assumes the DEM is single-band; confirm if not.
real_height = get_value_from_raster_gdal(
    dem_path,
    ds["Longitude (deg)"].values,
    ds["Latitude (deg)"].values,
    band_indexs=1,
)[0]
# NOTE(review): the dataset-assembly cell stores `real_heights` (the ICESat-2
# "Height (m MSL)" column), not this DEM sample. Confirm which is intended.

# Next: assemble the dataset.

In [5]:
from utils.data_io import save_csv

# Band order of the resampled Sentinel-2 raster: the 9th band is B8A and
# there is no B10, so names cannot be derived from the band index alone.
S2_BAND_NAMES = [
    "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12",
]
if raster_data.shape[0] != len(S2_BAND_NAMES):
    # Fail loudly rather than silently mislabel columns if the layout differs.
    raise ValueError(
        f"Expected {len(S2_BAND_NAMES)} Sentinel-2 bands, got {raster_data.shape[0]}"
    )

# Iterating a 2-D array yields its rows, i.e. one per-keypoint vector per band.
add_data = dict(zip(S2_BAND_NAMES, raster_data))
# NOTE(review): `real_heights` is the ICESat-2 "Height (m MSL)" column, which
# is already in `ds`. The DEM sample `real_height` is not used here. Confirm
# which one the target should be.
add_data["real_height"] = real_heights

print(f"add_data index:{list(add_data.keys())}")

ml_data = ds.assign(**add_data)
print(ml_data.columns)

# Drop keypoints with any missing value (e.g. the NaN the sampler returns for
# points outside the raster).
ml_data = ml_data.dropna()

dp = Path("dataset.csv")

print(f"shape of ml_data: {ml_data.shape}")
# NOTE(review): the recorded output of this call is "Overwrite operation
# cancelled by user." followed by False, even with overwrite=True, so
# dataset.csv may not have been written. Check save_csv's return value.
save_csv(ml_data, data_path=dp, backup=True, overwrite=True)

add_data index:['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12', 'real_height']
Index(['Latitude (deg)', 'Longitude (deg)', 'UTM Easting (m)',
       'UTM Northing (m)', 'Height (m MSL)', 'Height (m HAE)',
       'Along-Track (m)', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8',
       'B8A', 'B9', 'B11', 'B12', 'real_height'],
      dtype='object')
shape of ml_data: (119, 20)
Overwrite operation cancelled by user.


False