# Read Copernicus Sentinel 2 data
The code was written to simplify reading satellite images from the Copernicus portal, with a focus on Sentinel-2 data. The original product folders downloaded from the portal (https://scihub.copernicus.eu/dhus/#/home) must be placed in the same folder where the code is run (in our case, the folder of the Jupyter Notebook).

In [None]:
import os, pandas as pd, geopandas as gpd, numpy as np, rasterio as rio, matplotlib.pyplot as plt

from sentinelhub import UtmZoneSplitter
from shapely.geometry import Polygon
from shapely.geometry import box
from rasterio.plot import show
from rasterio.mask import mask
from osgeo import gdal_array
from sklearn import cluster
from osgeo import gdal
from glob import glob

## Build the class to load the images

In [None]:
class Read_rasterio:

    """
    Index the Sentinel-2 ``.jp2`` rasters found below the working directory and
    give simple access to the 10 m products of one acquisition date.

    On construction every ``.jp2`` file under ``os.getcwd()`` is listed in
    ``self.df_paths`` (columns ``Date``, ``Paths``, ``Filename``), sorted by date,
    and the distinct dates are stored in ``self.list_years``.

    When ``date`` is given, the rows of that date are kept in
    ``self.df_paths_filtered``. The path, file name and date of bands B02, B03,
    B04, B08 and of the true-colour image (TCI) at 10 m are then exposed as
    attributes (``B02_10m_path``, ``B02_10m_filename``, ``B02_10m_date``, ...,
    ``true_color_path``, ``true_color_filename``, ``true_color_date``).

    The date of each file is read from the third ``_``-separated field of its
    path relative to the working directory, i.e. from the name of the product
    folder placed directly inside it (e.g. ``S2A_MSIL2A_20191223T...``), in
    ``YYYYMMDD`` form.

    Some operations, such as the NDVI, are provided as methods; further ones can
    be added in the same way.
    """

    def __init__(self, date=None):
        # get the absolute path from where the code is running
        path = os.getcwd()

        # collect [date, full path, file name without extension] for every .jp2 file
        # (the format used by Copernicus for the band rasters)
        list_date = []
        for dirpath, _dirnames, filenames in os.walk(path):
            for name in filenames:
                if os.path.splitext(name)[1] != '.jp2':
                    continue
                full_path = os.path.join(dirpath, name)
                # Parse the date from the path relative to the working directory, so
                # underscores in the working directory's own path cannot shift the fields.
                relative = os.path.relpath(full_path, path)
                date_ = relative.split('_')[2].split('T')[0]
                # basename works for both '/' and '\\' separators on Windows
                list_date.append([date_, full_path, name.split('.')[0]])

        # explicit columns also give a valid (empty) frame when no .jp2 file is found
        df_paths = pd.DataFrame(list_date, columns=['Date', 'Paths', 'Filename'])

        # transform the Date column to datetime type and sort the rows chronologically
        df_paths['Date'] = pd.to_datetime(df_paths['Date'], format='%Y%m%d')
        df_paths = df_paths.sort_values('Date', kind='mergesort').reset_index(drop=True)

        # distinct acquisition dates, in chronological order
        list_years = df_paths['Date'].unique()

        self.list_years = list_years
        self.df_paths = df_paths

        if date is not None:
            # filter just for one date
            df_paths_filtered = df_paths[df_paths['Date'] == date].reset_index(drop=True)
            self.df_paths_filtered = df_paths_filtered

            b02 = self._first_match(df_paths_filtered, 'B02_10m', date)
            b03 = self._first_match(df_paths_filtered, 'B03_10m', date)
            b04 = self._first_match(df_paths_filtered, 'B04_10m', date)
            b08 = self._first_match(df_paths_filtered, 'B08_10m', date)
            tci = self._first_match(df_paths_filtered, 'TCI_10m', date)

            self.B02_10m_path = b02['Paths']
            self.B03_10m_path = b03['Paths']
            self.B04_10m_path = b04['Paths']
            self.B08_10m_path = b08['Paths']
            self.true_color_path = tci['Paths']

            self.B02_10m_filename = b02['Filename']
            self.B03_10m_filename = b03['Filename']
            self.B04_10m_filename = b04['Filename']
            self.B08_10m_filename = b08['Filename']
            self.true_color_filename = tci['Filename']

            self.B02_10m_date = b02['Date']
            self.B03_10m_date = b03['Date']
            self.B04_10m_date = b04['Date']
            self.B08_10m_date = b08['Date']
            self.true_color_date = tci['Date']

    @staticmethod
    def _first_match(df, key, date):
        '''
        Return the first row of ``df`` whose ``Filename`` contains ``key`` literally.

        Raises KeyError (the same type the previous ``[...][0]`` lookup raised)
        with a readable message when no file matches.
        '''
        matches = df[df['Filename'].str.contains(key, regex=False)]
        if matches.empty:
            raise KeyError(f"no '{key}' raster found for date {date!r}")
        return matches.iloc[0]

    def read_rio_10m(self):
        '''
        Open all the files with 10 meters resolution for the selected date.

        Returns the rasterio datasets of B02, B03, B04, B08 and the true-colour
        image, in this order. They are returned open; the caller is responsible
        for closing them.
        '''
        B02_10m_rio = rio.open(self.B02_10m_path, driver='JP2OpenJPEG')
        B03_10m_rio = rio.open(self.B03_10m_path, driver='JP2OpenJPEG')
        B04_10m_rio = rio.open(self.B04_10m_path, driver='JP2OpenJPEG')
        B08_10m_rio = rio.open(self.B08_10m_path, driver='JP2OpenJPEG')
        true_color_rio = rio.open(self.true_color_path, driver='JP2OpenJPEG')

        return B02_10m_rio, B03_10m_rio, B04_10m_rio, B08_10m_rio, true_color_rio

    def ndvi_function(self):
        '''
        Compute NDVI = (B08 - B04) / (B08 + B04) for the selected date.

        The result is written as a float64 GeoTIFF to ``ndvi_10m.tif`` in the
        working directory, overwriting any previous file. It uses the size and
        transform of B04 and the CRS of B08. The file is returned opened for
        reading. Pixels where B08 + B04 == 0 are set to 0.
        '''
        with rio.open(self.B04_10m_path, driver='JP2OpenJPEG') as red_src, \
                rio.open(self.B08_10m_path, driver='JP2OpenJPEG') as nir_src:
            red = red_src.read(1).astype('float64')
            nir = nir_src.read(1).astype('float64')
            profile = dict(driver='GTiff',
                           width=red_src.width, height=red_src.height,
                           count=1,
                           crs=nir_src.crs,
                           transform=red_src.transform,
                           dtype='float64')

        # errstate silences the 0/0 warnings only here, leaving numpy's global settings alone
        with np.errstate(divide='ignore', invalid='ignore'):
            ndvi = np.where((nir + red) == 0., 0, (nir - red) / (nir + red))

        # 'w' always recreates the file, so a stale ndvi_10m.tif of another
        # date or extent cannot keep its old size or georeferencing
        with rio.open('ndvi_10m.tif', 'w', **profile) as ndvi_image:
            ndvi_image.write(ndvi, 1)

        return rio.open('ndvi_10m.tif')

    def clip_raster(self, path_open, path_out, var_geometry):
        '''
        Clip the raster at ``path_open`` on the geometries of ``var_geometry``.

        ``var_geometry`` is a GeoDataFrame whose ``geometry`` column is passed to
        ``rasterio.mask.mask`` with ``crop=True``. The clip is written as a
        GeoTIFF to ``path_out``, overwriting any previous file, and returned
        opened for reading.
        '''
        with rio.open(path_open) as src:
            out_image, out_transform = mask(src, var_geometry.geometry, crop=True)
            out_meta = src.meta.copy()

        # Save clipped imagery with the size and transform of the cropped window
        out_meta.update({"driver": "GTiff",
                         "height": out_image.shape[1],
                         "width": out_image.shape[2],
                         "transform": out_transform})

        # 'w' always recreates the file, so an older clip cannot keep a stale transform
        with rio.open(path_out, "w", **out_meta) as dest:
            dest.write(out_image)

        return rio.open(path_out)

In [None]:
# build the index of every .jp2 file found below the working directory (no date selected yet)
band = Read_rasterio(date=None)

In [None]:
# dates available in the folder: handy to pick a single date or a group of them
available_dates = band.list_years
available_dates

In [None]:
# example: restrict the class to a single acquisition date
selected_date = '2019-12-23'
band = Read_rasterio(date=selected_date)

### Plot the satellite images to check them

Load some additional layers to overlay on the image. In this case, municipal and provincial boundaries are added.

In [None]:
# shapefiles with the provincial and municipal boundaries
shapefiles = {
    'provinces': "ProvCM01012022/ProvCM01012022_WGS84.shp",
    'municipalities': "Com01012022/Com01012022_WGS84.shp",
}
provinces_path = shapefiles['provinces']
municipalities_path = shapefiles['municipalities']

# load both layers as GeoDataFrames, provinces first
gdf_provinces, gdf_municipalities = (gpd.read_file(p) for p in (provinces_path, municipalities_path))

In [None]:
# open every 10 m product of the selected date (other resolutions could be added to the class)
rasters_10m = band.read_rio_10m()
B02_10m_rio, B03_10m_rio, B04_10m_rio, B08_10m_rio, true_color_rio = rasters_10m

In [None]:
fig, ax = plt.subplots(figsize=(15, 15))

# boundaries drawn over the image: provinces (red, on top) and municipalities (white)
gdf_provinces.plot(ax=ax, facecolor='None', linewidth=1.6, edgecolor='red', zorder=2)
gdf_municipalities.plot(ax=ax, facecolor='None', linewidth=0.5, edgecolor='white', zorder=1)

# label every province at its centroid
for prov_name, prov_geom in zip(gdf_provinces['DEN_PROV'], gdf_provinces.geometry):
    ax.annotate(text=prov_name, xy=prov_geom.centroid.coords[0], color="white", weight='semibold')

show(true_color_rio.read(), transform=true_color_rio.transform, ax=ax)

### Example NDVI function

In [None]:
# NDVI of the selected date, computed by the class (it also writes ndvi_10m.tif)
ndvi = Read_rasterio.ndvi_function(band)

In [None]:
# plot the ndvi; fixing the colour limits to the index range [-1, 1] keeps 0 at the
# centre of the diverging colormap and makes plots of different dates comparable
fig = plt.figure(figsize=(15, 15))
show(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)

## Focus on a specific area

This section gives some examples of how to select specific areas without using external tools such as QGIS.

### Check an area by municipality
The selection can include more than one municipality, or be based on other types of geometries.

In [None]:
# keep only the municipality of interest
target_municipality = 'Alessandria'
gdf_municipality = gdf_municipalities.loc[gdf_municipalities['COMUNE'] == target_municipality].reset_index(drop=True)

# rectangle enclosing the municipality, used to cut the rasters
minx, miny, maxx, maxy = gdf_municipality.total_bounds
municipality_box = box(minx, miny, maxx, maxy)

# wrap the rectangle in a one-row GeoDataFrame sharing the municipality CRS
gdf_municipality_box = gpd.GeoDataFrame(geometry=[municipality_box], index=[0], crs=gdf_municipality.crs)

In [None]:
# table of the files of the selected date
df_paths_filtered = band.df_paths_filtered

# path of the true-colour (TCI_10m) image, already looked up by the class
file_path = band.true_color_path

In [None]:
# clip the true-colour image on the municipality bounding box (saved to clip_raster.tif)
trueColor_img = band.clip_raster(path_open=file_path, path_out='clip_raster.tif', var_geometry=gdf_municipality_box)

In [None]:
# show the clipped true-colour image with the municipality boundary on top
fig, ax = plt.subplots(figsize=(15, 15))
gdf_municipality.plot(ax=ax, zorder=2, facecolor='None', edgecolor='orange', linewidth=1.6)

tci_pixels = trueColor_img.read()
show(tci_pixels, transform=trueColor_img.transform, ax=ax)

In [None]:
# table of the files of the selected date
df_paths_filtered = band.df_paths_filtered

# path of band B02 at 10 m, already looked up by the class
file_path = band.B02_10m_path

In [None]:
# any band can be clipped the same way (saved to clip_B02_10m.tif)
clip_B02_10m = band.clip_raster(path_open=file_path, path_out='clip_B02_10m.tif', var_geometry=gdf_municipality_box)

In [None]:
# show the clipped B02 band with the municipality boundary on top
fig, ax = plt.subplots(figsize=(15, 15))
gdf_municipality.plot(ax=ax, zorder=2, facecolor='None', edgecolor='orange', linewidth=1.6)

b02_pixels = clip_B02_10m.read()
show(b02_pixels, transform=clip_B02_10m.transform, ax=ax)

### Example of NDVI calculation using the class only partially

In [None]:
# clip every wanted 10 m band on the municipality bounding box
clip_shapes = gdf_municipality_box.geometry
(out_image_B2, out_transform_B2), (out_image_B3, out_transform_B3), \
    (out_image_B4, out_transform_B4), (out_image_B8, out_transform_B8) = [
        mask(src, clip_shapes, crop=True)
        for src in (B02_10m_rio, B03_10m_rio, B04_10m_rio, B08_10m_rio)
    ]

In [None]:
# check the clips: one panel per band, titled with its Sentinel-2 band number
# (the titles previously read Band 1/2/3 for B02/B03/B04)
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(12, 12))

panels = [
    (out_image_B2, out_transform_B2, ax[0, 0], 'Band 2'),
    (out_image_B3, out_transform_B3, ax[0, 1], 'Band 3'),
    (out_image_B4, out_transform_B4, ax[1, 0], 'Band 4'),
    (out_image_B8, out_transform_B8, ax[1, 1], 'Band 8'),
]
for image, transform, panel, title in panels:
    show(image, transform=transform, ax=panel, cmap='nipy_spectral', title=title)

plt.show()

In [None]:
# Kept outside the class: here it works directly on arrays already clipped in memory.
def ndvi_function(band4, band8):
    """
    Return the NDVI (band8 - band4) / (band8 + band4) as a float64 array.

    band4 (red) and band8 (near infrared) are array-likes of the same shape,
    e.g. the clipped bands returned by rasterio.mask.mask. Pixels where
    band8 + band4 == 0 are set to 0. Division warnings are silenced only inside
    this call; numpy's global error settings are left untouched.
    """
    red = np.asarray(band4).astype('float64')
    nir = np.asarray(band8).astype('float64')
    total = nir + red
    with np.errstate(divide='ignore', invalid='ignore'):
        ndvi = np.where(total == 0., 0, (nir - red) / total)
    return ndvi

In [None]:
# NDVI of the clipped area only
ndvi_clip = ndvi_function(band4=out_image_B4, band8=out_image_B8)

In [None]:
# plot the clipped NDVI with the municipality boundary in black on top
fig, ax = plt.subplots(figsize=(15, 15))
gdf_municipality.plot(ax=ax, edgecolor='black', facecolor='None', zorder=2, linewidth=1.6)

show(ndvi_clip, ax=ax, transform=out_transform_B8)

### Select an area to focus on

If you need to focus on even more specific areas, you can divide the territory into selectable tiles using libraries available online (here, `UtmZoneSplitter` from `sentinelhub`).

In [None]:
# Get the municipality's shape in polygon format (last row of the selection made above)
country_shape = gdf_municipality.geometry.values[-1]

# Plot the municipality
gdf_municipality.plot()
plt.axis('off');

# Print the width x height of the bounding rectangle: both are lengths, so the unit
# is m (the message previously said m2)
min_x, min_y, max_x, max_y = country_shape.bounds
print(f'Dimension of the area is {max_x - min_x:.0f} x {max_y - min_y:.0f} m')

In [None]:
# Split the municipality into tiles of size 1000 (in CRS units) on the UTM grid
bbox_splitter = UtmZoneSplitter([country_shape], gdf_municipality.crs, 1000)

bbox_list = np.array(bbox_splitter.get_bbox_list())
info_list = np.array(bbox_splitter.get_info_list())

# one polygon per tile, plus its running index and its position on the grid
geometry = [Polygon(tile.get_polygon()) for tile in bbox_list]
tile_columns = {key: [meta[key] for meta in info_list] for key in ('index', 'index_x', 'index_y')}

gdf = gpd.GeoDataFrame(tile_columns, crs=gdf_municipality.crs, geometry=geometry)

In [None]:
# select a 5x5 area (id of the centre patch)
ID = 170

# grid position of the centre patch
center_x = info_list[ID]['index_x']
center_y = info_list[ID]['index_y']

# Obtain the surrounding 5x5 patches: at most 2 steps from the centre on both axes
patchIDs = [
    idx for idx, info in enumerate(info_list)
    if abs(info['index_x'] - center_x) <= 2 and abs(info['index_y'] - center_y) <= 2
]

if len(patchIDs) == 5 * 5:
    # Change the order of the patches (used for plotting later)
    patchIDs = np.transpose(np.fliplr(np.array(patchIDs).reshape(5, 5))).ravel()
else:
    # a 5x5 reshape is impossible on the border: keep the partial selection instead of crashing
    print('Warning! Use a different central patch ID, this one is on the border.')

# save the whole grid to shapefile
shapefile_name = 'grid_municipality.shp'
gdf.to_file(shapefile_name)

In [None]:
# figure
fig, ax = plt.subplots(figsize=(30, 30))

# grid in red; municipality as a faint white fill plus a white outline
gdf.plot(ax=ax, facecolor='w', edgecolor='r', alpha=0.5, linewidth=1.6)
gdf_municipality.plot(ax=ax, facecolor='w', edgecolor='none', alpha=0.1)
gdf_municipality.plot(ax=ax, facecolor='none', edgecolor='w', linewidth=1.6)

municipality_name = gdf_municipality["COMUNE"][0]
ax.set_title(f'Selected 5x5  tiles from {municipality_name}', fontsize=25)

# write each tile's index at its centre
for tile, meta in zip(bbox_list, info_list):
    centre = tile.geometry.centroid
    ax.text(centre.x, centre.y, meta['index'], ha='center', va='center')

# highlight the selected patches in green
selected_tiles = gdf[gdf.index.isin(patchIDs)]
selected_tiles.plot(ax=ax, facecolor='g', edgecolor='r', alpha=0.1)

show(trueColor_img.read(), transform=trueColor_img.transform, ax=ax)

plt.axis('off');

## Open all the data for each month

This section shows how to use the class to process the rasters sequentially. For each period it produces a single file with several bands, which can be used, for example, for a classification.

In [None]:
%%time

# list of the acquisition dates present in the folder
list_dates = band.list_years

# loop each period to get the interested bands
for i in list_dates:

    # use the class for each period
    band = Read_rasterio(i)

    # text used in the output file names, e.g. 2019-12-23T00_00_00_000000000
    date_tag = str(i).replace('.', '_').replace(':', '_')

    # read all the bands available for the selected date at 10m resolution
    B02_10m_rio, B03_10m_rio, B04_10m_rio, B08_10m_rio, true_color_rio = band.read_rio_10m()
    try:
        # cut the original rasters on the area of interest
        out_image_B2, out_transform_B2 = mask(B02_10m_rio, gdf_municipality_box.geometry, crop=True)
        out_image_B3, out_transform_B3 = mask(B03_10m_rio, gdf_municipality_box.geometry, crop=True)
        out_image_B4, out_transform_B4 = mask(B04_10m_rio, gdf_municipality_box.geometry, crop=True)
        out_image_B8, out_transform_B8 = mask(B08_10m_rio, gdf_municipality_box.geometry, crop=True)
        out_crs = B02_10m_rio.crs
    finally:
        # release the file handles opened by read_rio_10m (they used to leak on every date)
        for src in (B02_10m_rio, B03_10m_rio, B04_10m_rio, B08_10m_rio, true_color_rio):
            src.close()

    # generate the NDVI to be inserted as an additional band
    ndvi_clip = ndvi_function(out_image_B4, out_image_B8)

    # Stack in the order B4, B3, B2, B8, NDVI (the band order used before). Every band is
    # now its own clip (B2 used to be written four times), and float64 keeps the NDVI in
    # [-1, 1] instead of truncating it to the integer type of the bands.
    stack = np.concatenate([out_image_B4, out_image_B3, out_image_B2, out_image_B8, ndvi_clip]).astype('float64')
    n_bands, height, width = stack.shape

    # save all the bands in a unique file for each period, to work on one period only if needed
    with rio.open('RGB_NIR_' + date_tag + '.tif', 'w', driver='GTiff',
                  width=width, height=height,
                  count=n_bands,
                  crs=out_crs,
                  transform=out_transform_B2,
                  dtype='float64') as to_save:
        to_save.write(stack)

    # one row per pixel and one column per band, as expected by scikit-learn
    X = np.moveaxis(stack, 0, -1).reshape(-1, n_bands)

    # use the kmeans classification
    k_means = cluster.KMeans(n_clusters=6)
    k_means.fit(X)
    X_cluster = k_means.labels_.reshape(height, width)

    # save the classification as a single-band byte raster on the same grid as the stack
    with rio.open('k_means_' + date_tag + '.tif', 'w', driver='GTiff',
                  width=width, height=height,
                  count=1,
                  crs=out_crs,
                  transform=out_transform_B2,
                  dtype='uint8') as class_raster:
        class_raster.write(X_cluster.astype('uint8'), 1)

In [None]:
# plot the classification of the last processed date, with the municipality outline in white
fig, ax = plt.subplots(figsize=(15, 15))
gdf_municipality.plot(ax=ax, edgecolor='white', facecolor='None', linewidth=2)

show(X_cluster, ax=ax, cmap="hsv", transform=out_transform_B8)

## Try to reclassify

The previous code processes the bands and saves the results for each period. It would also be interesting to put all the bands in the same file, but that file becomes too heavy to process on most PCs. There are two ways to reduce the processing time. One is to reduce the area. The other is to classify each period first and then reclassify the per-period classified rasters together.

### First trial: classify the rasters already classified for each period together

In [None]:
# Read a raster as meta base for the others. Every per-date classification shares the
# grid of the clipped area, so take the earliest one found instead of hard-coding a date.
kmeans_rasters = sorted(p for p in glob("k_means_*.tif") if os.path.basename(p) != 'k_means_united.tif')
read_rio = rio.open(kmeans_rasters[0])

In [None]:
# get all the files elaborated with the k-means algorithm, in date order; the combined
# raster written by the next cell also matches the pattern on re-runs, so leave it out
list_classified = sorted(p for p in glob("k_means*.tif") if os.path.basename(p) != 'k_means_united.tif')

In [None]:
# keep only the per-date classifications (the output file itself may match the glob on re-runs)
inputs = [p for p in list_classified if os.path.basename(p) != 'k_means_united.tif']

# Raster holding one band per classified date. The band count now follows the number of
# files instead of being hard-coded to 12, and the bogus 'profile=' creation option is gone.
with rio.open('k_means_united.tif', 'w',
              driver='GTiff',
              width=read_rio.width, height=read_rio.height,
              count=len(inputs),
              crs=read_rio.crs,
              transform=read_rio.transform,
              dtype='float64') as stdImage:
    # fill the file with all the bands (rasterio band indexes start at 1)
    for count, path in enumerate(inputs, start=1):
        with rio.open(path) as src:
            stdImage.write(src.read(1).astype('float64'), count)

In [None]:
%%time

# open the file with all the k-mean bands to be classified
img_ds = gdal.Open('k_means_united.tif', gdal.GA_ReadOnly)

# (rows, cols, bands) array in the raster's own numeric type
img = np.zeros((img_ds.RasterYSize, img_ds.RasterXSize, img_ds.RasterCount),
               gdal_array.GDALTypeCodeToNumericTypeCode(img_ds.GetRasterBand(1).DataType))
for b in range(img.shape[2]):
    img[:, :, b] = img_ds.GetRasterBand(b + 1).ReadAsArray()

# dropping the last reference closes the GDAL dataset
img_ds = None

# One row per pixel and one column per band. All bands are used; the former
# [:, :, :13] slice made this reshape fail with more than 13 bands.
new_shape = (img.shape[0] * img.shape[1], img.shape[2])
X = img.reshape(new_shape)

k_means = cluster.KMeans(n_clusters=10)
k_means.fit(X)

X_cluster = k_means.labels_
X_cluster = X_cluster.reshape(img[:, :, 0].shape)

In [None]:
# check the result of the combined classification
fig = plt.figure(figsize=(20, 20))
plt.imshow(X_cluster, cmap="hsv")
plt.show()

### Classify all bands

In [None]:
# Read a per-date stack as meta base for the others. They all share the grid of the
# clipped area, so take the earliest one found instead of hard-coding a date.
rgb_rasters = sorted(p for p in glob("RGB_NIR_*.tif") if os.path.basename(p) != 'RGB_NIR_united.tif')
read_rio = rio.open(rgb_rasters[0])

In [None]:
# per-date stacks in date order; the combined raster written by the next cell also
# matches the pattern on re-runs, so leave it out
list_classified = sorted(p for p in glob("RGB_NIR*.tif") if os.path.basename(p) != 'RGB_NIR_united.tif')

In [None]:
# keep only the per-date stacks (the output file itself may match the glob on re-runs)
inputs = [p for p in list_classified if os.path.basename(p) != 'RGB_NIR_united.tif']

# each per-date stack holds 5 bands (B4, B3, B2, B8, NDVI)
bands_per_file = 5

with rio.open('RGB_NIR_united.tif', 'w',
              driver='GTiff',
              width=read_rio.width, height=read_rio.height,
              count=len(inputs) * bands_per_file,
              crs=read_rio.crs,
              transform=read_rio.transform,
              dtype='float64') as stdImage:
    # 1-based index of the next output band. It is advanced after every single write;
    # it used to advance once per file, so all 5 bands of a file overwrote the same slot.
    count = 1
    for path in inputs:
        with rio.open(path) as src:
            for n in range(1, bands_per_file + 1):
                stdImage.write(src.read(n).astype('float64'), count)
                count += 1

In [None]:
# open the combined raster with rasterio (the next cell reads it again through GDAL)
united_path = 'RGB_NIR_united.tif'
img_ds = rio.open(united_path)

In [None]:
%%time

img_ds = gdal.Open('RGB_NIR_united.tif', gdal.GA_ReadOnly)

# (rows, cols, bands) array in the raster's own numeric type
img = np.zeros((img_ds.RasterYSize, img_ds.RasterXSize, img_ds.RasterCount),
               gdal_array.GDALTypeCodeToNumericTypeCode(img_ds.GetRasterBand(1).DataType))
for b in range(img.shape[2]):
    img[:, :, b] = img_ds.GetRasterBand(b + 1).ReadAsArray()

# dropping the last reference closes the GDAL dataset
img_ds = None

# One row per pixel and one column per band. All bands are used; the former
# [:, :, :61] slice made this reshape fail with more than 61 bands.
new_shape = (img.shape[0] * img.shape[1], img.shape[2])
X = img.reshape(new_shape)

k_means = cluster.KMeans(n_clusters=4)
k_means.fit(X)

X_cluster = k_means.labels_
X_cluster = X_cluster.reshape(img[:, :, 0].shape)

In [None]:
# check the classification computed on all the bands
fig = plt.figure(figsize=(20, 20))
plt.imshow(X_cluster, cmap="hsv")
plt.show()