In [1]:
import keras
import tensorflow as tf

2023-08-30 12:17:17.721852: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-30 12:17:17.821631: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Set Keras backend for image data format

In [2]:
# Switch Keras to channels-first (C, H, W) tensors before any model is built;
# the image arrays used in this notebook are laid out as (3, height, width).
keras.backend.set_image_data_format('channels_first')

### Load ResNet50V2 model from Keras and freeze weights

In [3]:
# ImageNet-pretrained ResNet50V2 used as a feature extractor: the
# classification head is left out (include_top=False) and no pooling is
# applied to the backbone output.
base_model = tf.keras.applications.ResNet50V2(
    weights="imagenet",
    include_top=False,
    pooling=None,
    input_shape=None,
    input_tensor=None,
)

In [4]:
# Print the layer-by-layer architecture of the backbone.
base_model.summary()

Model: "resnet50v2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 3, None, No  0           []                               
                                ne)]                                                              
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 3, None, Non  0           ['input_1[0][0]']                
                                e)                                                                
                                                                                                  
 conv1_conv (Conv2D)            (None, 64, None, No  9472        ['conv1_pad[0][0]']              
                                ne)                                                      

In [5]:
# Freeze the backbone so that its pretrained weights are not updated during training.
base_model.trainable = False

In [6]:
# Print the summary again now that the backbone has been frozen
# (presumably to check the trainable-parameter totals).
base_model.summary()

Model: "resnet50v2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 3, None, No  0           []                               
                                ne)]                                                              
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 3, None, Non  0           ['input_1[0][0]']                
                                e)                                                                
                                                                                                  
 conv1_conv (Conv2D)            (None, 64, None, No  9472        ['conv1_pad[0][0]']              
                                ne)                                                      

 conv2_block2_1_conv (Conv2D)   (None, 64, None, No  16384       ['conv2_block2_preact_relu[0][0]'
                                ne)                              ]                                
                                                                                                  
 conv2_block2_1_bn (BatchNormal  (None, 64, None, No  256        ['conv2_block2_1_conv[0][0]']    
 ization)                       ne)                                                               
                                                                                                  
 conv2_block2_1_relu (Activatio  (None, 64, None, No  0          ['conv2_block2_1_bn[0][0]']      
 n)                             ne)                                                               
                                                                                                  
 conv2_block2_2_pad (ZeroPaddin  (None, 64, None, No  0          ['conv2_block2_1_relu[0][0]']    
 g2D)     

### Load data and set model parameters

In [7]:
# Number of classes in training data
nclasses = 33

In [8]:
import glob

import geopandas
import numpy as np
import pandas as pd
import rioxarray
from rioxarray.exceptions import NoDataInBounds

In [9]:
# Input locations, relative to the notebook's working directory.
# glob.glob() returns its matches in arbitrary (filesystem-dependent) order,
# so the lists are sorted to keep the order of the generated samples
# reproducible between runs and machines.
rgb_paths = sorted(glob.glob('./train/RemoteSensing/RGB/*.tif'))
bboxes_paths = sorted(glob.glob('./train/ITC/train_*.shp'))
classes_path = './train/Field/train_data.csv'

# Load training and test data and corresponding labels
class IDTreeSDataset:
    """Generate labelled image cutouts from RGB rasters and polygon outlines.

    Parameters
    ----------
    rgb_paths : list of str
        Paths of the RGB raster files, opened with ``rioxarray.open_rasterio``.
    bboxes_paths : list of str
        Paths of the vector files holding the polygons, read with
        ``geopandas.read_file``. Every record must carry an ``indvdID`` field.
    classes_path : str
        Path of the CSV file with the ``indvdID`` and ``taxonID`` columns.
    """

    def __init__(self, rgb_paths, bboxes_paths, classes_path):
        self.rgb_paths = rgb_paths
        self.bboxes_paths = bboxes_paths
        self.classes_path = classes_path

    def generate_cutouts(self):
        """Yield one ``(indvdID, label, data)`` tuple per usable polygon.

        For every raster, the polygons that fall within its bounds are used to
        clip a cutout, which is then passed through ``_image_preprocessor``.
        A polygon is skipped when:

        * its ``indvdID`` is missing from the classes table,
        * the classes table does not give a single string ``taxonID`` for it
          (e.g. several rows share the same ``indvdID``),
        * clipping raises ``NoDataInBounds``,
        * ``_image_preprocessor`` flags the cutout as too small or too large.

        Yields
        ------
        tuple
            ``(indvdID, label, data)`` where ``label`` is the ``taxonID``
            string and ``data`` is the preprocessed cutout array.

        Raises
        ------
        AssertionError
            If the CRS of the polygons differs from the CRS of a raster.
        """
        # load bboxes
        bboxes = pd.concat([
            geopandas.read_file(p)
            for p in self.bboxes_paths
        ])

        # load classes, indexed by individual ID for direct lookups
        classes = pd.read_csv(self.classes_path)
        classes = classes.set_index('indvdID')

        # load rgb data and make cutouts
        for rgb_path in self.rgb_paths:
            # The context manager closes the raster file handle once all of
            # its cutouts have been produced (or the generator is discarded).
            with rioxarray.open_rasterio(rgb_path, masked=True) as rgb:

                assert bboxes.crs == rgb.rio.crs

                # select relevant bboxes
                xmin, ymin, xmax, ymax = rgb.rio.bounds()
                bboxes_clip = bboxes.cx[xmin:xmax, ymin:ymax]
                bboxes_clip = bboxes_clip[~bboxes_clip.is_empty]

                for _, bbox in bboxes_clip.iterrows():

                    indvdID = bbox['indvdID']

                    if indvdID not in classes.index:
                        # issue: for some IDs, the class is not specified
                        continue

                    # Single .loc lookup instead of chained indexing; a
                    # duplicated ID returns a Series rather than a string.
                    label = classes.loc[indvdID, 'taxonID']
                    if not isinstance(label, str):
                        # issue: there are multiple entries for the same ID
                        continue

                    try:
                        cutout = rgb.rio.clip([bbox.geometry], drop=True)
                    except NoDataInBounds:
                        # issue: some polygons have very small intersections with the images
                        continue

                    data, flag = _image_preprocessor(cutout.data)
                    if flag:
                        # If image is too small or too large
                        continue

                    yield indvdID, label, data
    
def _image_preprocessor(data):
    """Filter a cutout by size and pad the ones that are kept.

    Returns a ``(data, flag)`` pair. ``flag`` is True when the size checks
    reject the cutout, in which case ``data`` is handed back untouched;
    otherwise ``flag`` is False and ``data`` is the padded array.
    """
    rejected = _remove_large_image(data) or _remove_small_images(data)
    if rejected:
        return data, True
    return _image_padding(data), False

def _remove_large_image(data):
    if max(data.shape[1:]) > 100:
        return True

def _remove_small_images(data):
    if min(data.shape[1:]) <= 32:
        return True

def _image_padding(data):
    #Pad each image 
    pad_width_x1 = np.floor((100 - data.shape[1])/2).astype(int)
    pad_width_x2 = 100 - data.shape[1] - pad_width_x1 
    pad_width_y1 = np.floor((100 - data.shape[2])/2).astype(int)
    pad_width_y2 = 100 - data.shape[2] - pad_width_y1
    data = np.pad(data, pad_width=[(0, 0),(pad_width_x1, pad_width_x2),(pad_width_y1, pad_width_y2)], mode='constant')
    return data
        

# Plain-Python source of (indvdID, label, image) samples. It gets its own name
# so that `ds` only ever refers to the tf.data pipeline built from it.
cutout_source = IDTreeSDataset(rgb_paths, bboxes_paths, classes_path)

# Wrap the generator in a tf.data.Dataset; the signature mirrors the tuples
# yielded by IDTreeSDataset.generate_cutouts.
ds = tf.data.Dataset.from_generator(
    cutout_source.generate_cutouts,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.string),  # indvdID
        tf.TensorSpec(shape=(), dtype=tf.string),  # label (taxonID)
        tf.TensorSpec(shape=(3, None, None), dtype=tf.float32)  # image, channels first
    )
)

### Create new model using the base model plus additional layers

In [11]:
def create_model(num_classes=None, input_shape=(3, 100, 100)):
    """Build the classifier: ResNet50V2 backbone followed by a small dense head.

    Parameters
    ----------
    num_classes : int, optional
        Number of output units. Defaults to the module-level ``nclasses``.
    input_shape : tuple of int, optional
        Shape of one input image, channels first (default ``(3, 100, 100)``).

    Returns
    -------
    keras.Model
        Model mapping a batch of images to ``num_classes`` raw scores
        (logits; no softmax is applied).
    """
    if num_classes is None:
        num_classes = nclasses

    # Inputs
    inputs = keras.Input(shape=input_shape)

    # Run input image through our base model (ResNet), in inference mode.
    # NOTE(review): ResNet50V2 is documented to expect inputs scaled with
    # `resnet_v2.preprocess_input`; confirm that the data pipeline applies it.
    x = base_model(inputs, training=False)

    # Convert features from base model to a vector
    x = keras.layers.GlobalAveragePooling2D()(x)

    # Additional dense layers forming the trainable head. Without a
    # non-linearity, consecutive Dense layers collapse into a single linear
    # map, so each hidden layer gets a ReLU activation.
    x = keras.layers.Dense(256, activation='relu')(x)
    x = keras.layers.Dense(64, activation='relu')(x)

    # Create output layer (logits)
    outputs = keras.layers.Dense(num_classes)(x)

    return keras.Model(inputs, outputs)

# Build the classifier used in the compile / fit / predict cells below.
model = create_model()

In [None]:
# Configure training: Adam optimizer with its default settings and a
# categorical cross-entropy loss computed on raw logits (from_logits=True).
# NOTE(review): this loss expects one-hot encoded targets, whereas the data
# generator yields taxonID strings; confirm the labels are encoded first.
model.compile(
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(),
)

In [None]:
# Train the model
# TODO: this cell is still a template. `dataset` is not defined in the visible
# part of the notebook (the tf.data pipeline is bound to `ds`, which yields
# (indvdID, label, image) triples), and `epochs` / `validation_data` are
# Ellipsis placeholders that must be replaced with real values.
model.fit(dataset, epochs=..., validation_data=...)

In [None]:
# Predict on the test set
# TODO: pass the test inputs to predict(); the call currently has no input data.
model.predict()