In [None]:
import os
import numpy as np
import h5py
import gdal
from osgeo import gdal, gdalconst, osr
import time
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
# !nvidia-smi

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration. Every name below is read as a global by later
# cells, so names and values are kept stable; values that must agree with each
# other are derived from a single source instead of being repeated as literals.
# ---------------------------------------------------------------------------
experiment_id = 'sentinel_2A_2018_T11SKA_spatio_temporal_attention_segmentation'

# Spatial layout (pixels)
quadrant_size = (5490,5490)
grid_cell_size = (1372,1372)
patch_size = (32,32)                                                            # Input patch fed to the network
input_patch_size = patch_size[0]
label_patch_size = (16,16)                                                      # Central window that is predicted / scored
output_patch_width = label_patch_size[0]
step_size = 16                                                                  # Stride between neighbouring patches
diff = (input_patch_size - output_patch_width)//2                               # Margin cropped from each side of the output

# Temporal / spectral layout
timestamps = [0,2,3,5,8,11,14,17,19,22,24,26,28,30]                             # Timestamps to consider
no_timestamps = time_steps = len(timestamps)
no_features = channels = 10                                                     # Number of Features

# Grid cells of the quadrant and their train / test split
no_of_grid_cells_x = 4
no_of_grid_cells_y = 4
total_no_of_grid_cells = no_of_grid_cells_x * no_of_grid_cells_y
grids_train = [0,2,5,7,8,10,13,15]
grids_test = [1,3,4,6,9,11,12,14]

# Classes
no_of_classes = 20
labels_list = list(range(no_of_classes))

# Optimisation
batch_size = 16
learning_rate = 0.0001
max_accuracy_test = 0
no_of_epochs = 100

# Data locations. The quadrant number appears in both the directory and the
# file prefix, so both are built from `quadrant` to keep them in sync.
quadrant = 4
data_dir = 'data/sentinel/sentinel_2A_2018_T11SKA/numpy_arrays/quadrant_wise/quadrant' + str(quadrant)
numpy_array_prefix = 'sentinel_2A_2018_T11SKA_series_quadrant' + str(quadrant) + '_patch'
label_dir = 'data/sentinel/sentinel_2A_2018_T11SKA/labels/quadrant_wise/quadrant' + str(quadrant)
label_array_prefix = 'sentinel_2A_2018_T11SKA_raw_label_quadrant' + str(quadrant) + '_patch'

# Output folder for checkpoints and figures; exist_ok avoids the check-then-create race.
model_folder = 'models_sentinel_2A_2018_T11SKA_spatio_temporal_attention_segmentation/' + experiment_id
os.makedirs(model_folder, exist_ok=True)

In [None]:
#label conversion from usda raw labels
# (class index, raw label codes merged into that class). Raw codes that appear
# in no entry are left untouched in the output array.
# Raw code 61 used to be listed under both class 16 and class 17; the class-17
# assignment ran last and therefore always won, so it is listed once, under 17.
_RAW_CODES_BY_CLASS = (
  (1,  (1, 225, 226, 237)),                          # Corn
  (2,  (2, 238)),                                    # Cotton
  (3,  (4, 236)),                                    # Sorghum
  (4,  (22, 23, 24)),                                # Wheat
  (5,  (36,)),                                       # Alfalfa
  (6,  (67,)),                                       # Peaches
  (7,  (69,)),                                       # Grapes
  (8,  (71,)),                                       # Tree crops
  (9,  (72,)),                                       # Citrus
  (10, (75,)),                                       # Almonds
  (11, (76,)),                                       # Walnut
  (12, (204,)),                                      # Pistachio
  (13, (212,)),                                      # Oranges
  (14, (218,)),                                      # Nectarines
  (15, (5, 3, 27, 28, 44, 53, 21, 33, 42, 205)),     # Misc Crops and Veg
  (16, (37, 58, 59, 152, 176, 190, 195)),            # Wetlands and Grass
  (17, (61, 131)),                                   # Barren/Idle land
  (18, (111,)),                                      # Water
  (19, (121, 122, 123, 124)),                        # Urban
)


def convert_to_label_array(raw_label,label):
  """Write model class indices into `label` according to the raw codes in `raw_label`.

  Args:
    raw_label: array of raw USDA label codes.
    label: array of the same shape as `raw_label`; it is modified in place.
      Positions whose raw code is not mapped keep their current value (the
      notebook passes an all-zero array, so unmapped pixels end up as class 0).

  Returns:
    The same `label` object that was passed in.

  Raises:
    ValueError: if `label` and `raw_label` do not have the same shape.
  """
  raw_label = np.asarray(raw_label)
  if label.shape != raw_label.shape:
    # Previously this only printed a message and then went on to index with a
    # mask of the wrong shape; fail early with an explicit error instead.
    raise ValueError("Shapes not equal: label {} vs raw_label {}".format(label.shape, raw_label.shape))
  for class_index, raw_codes in _RAW_CODES_BY_CLASS:
    label[np.isin(raw_label, raw_codes)] = class_index
  return label


In [None]:
# Model Architecture
class UNET_LSTM_BIDIRECTIONAL_ATTENTION(torch.nn.Module):
    """U-Net style encoder/decoder with a bidirectional LSTM and temporal attention.

    Every time step of a patch series is encoded independently by the
    convolutional encoder. At each encoder pixel the sequence of per-time-step
    features is passed through a bidirectional LSTM, one attention weight per
    time step is computed (scores averaged over the encoder grid, softmax over
    time), and the LSTM outputs are summed with those weights. The same
    weights (detached) aggregate the skip connections over time before the
    decoder produces one class-score map per sample.

    Args:
        in_channels: number of input features per pixel and time step.
        out_channels: number of output classes.
        n_time_steps: optional number of time steps per sample. When None the
            notebook-level global `time_steps` is read at call time.
        patch_size_in: optional input patch side length. When None the
            notebook-level global `input_patch_size` is read at call time.
        crop_margin: optional number of pixels removed from every side of the
            output. When None the notebook-level global `diff` is read at
            call time.
    """

    def __init__(self, in_channels, out_channels, n_time_steps=None, patch_size_in=None, crop_margin=None):
        super(UNET_LSTM_BIDIRECTIONAL_ATTENTION,self).__init__()

        # Optional overrides for the notebook globals used by forward();
        # plain attributes, so state_dict keys are unaffected.
        self.n_time_steps = n_time_steps
        self.patch_size_in = patch_size_in
        self.crop_margin = crop_margin

        # Encoder
        self.conv1_1 = torch.nn.Conv2d(in_channels, 64, 3, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, padding=1)
        self.conv3_1 = torch.nn.Conv2d(128, 256, 3, padding=1)
        self.conv3_2 = torch.nn.Conv2d(256, 256, 3, padding=1)

        # Temporal model: bidirectional, so its output has 2*256 = 512 features
        self.lstm = torch.nn.LSTM(256, 256, batch_first=True, bidirectional=True)
        self.attention = torch.nn.Linear(512, 1)

        # Decoder
        self.unpool2 = torch.nn.ConvTranspose2d(512 , 128, kernel_size=2, stride=2)
        self.upconv2_1 = torch.nn.Conv2d(256, 128, 3, padding=1)
        self.upconv2_2 = torch.nn.Conv2d(128, 128, 3, padding=1)
        self.unpool1 = torch.nn.ConvTranspose2d(128 , 64, kernel_size=2, stride=2)
        self.upconv1_1 = torch.nn.Conv2d(128, 64, 3, padding=1)
        self.upconv1_2 = torch.nn.Conv2d(64, 64, 3, padding=1)

        self.out = torch.nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

        self.maxpool = torch.nn.MaxPool2d(2)
        self.relu = torch.nn.ReLU(inplace=True)
        self.dropout = torch.nn.Dropout(p=0.1)  # defined but not applied in forward()

        # Xavier initialisation for Conv2d and Linear weights (other layers keep their defaults)
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def crop_and_concat(self, x1, x2):
        """Centre-crop `x1` to the spatial size of `x2` and concatenate both along channels."""
        x1_shape = x1.shape
        x2_shape = x2.shape
        offset_2, offset_3 = (x1_shape[2]-x2_shape[2])//2, (x1_shape[3]-x2_shape[3])//2
        x1_crop = x1[:, :, offset_2:offset_2+x2_shape[2], offset_3:offset_3+x2_shape[3]]
        return torch.cat([x1_crop, x2], dim=1)

    def forward(self,x):
        """Return class scores for the central window of each patch series.

        Args:
            x: tensor that is viewed as (-1, in_channels, patch, patch); the
                training loop feeds (batch, time, channels, height, width).

        Returns:
            Tensor of shape (batch, out_channels, H - 2*margin, W - 2*margin).
        """
        # Resolve sizes: explicit constructor overrides win, otherwise the notebook globals.
        n_steps = time_steps if self.n_time_steps is None else self.n_time_steps
        patch = input_patch_size if self.patch_size_in is None else self.patch_size_in
        margin = diff if self.crop_margin is None else self.crop_margin

        # Fold time into the batch axis so the encoder sees one image per time step.
        x = x.view(-1, self.conv1_1.in_channels, patch, patch)

        conv1 = self.relu(self.conv1_2(self.relu(self.conv1_1(x))))
        maxpool1 = self.maxpool(conv1)
        conv2 = self.relu(self.conv2_2(self.relu(self.conv2_1(maxpool1))))
        maxpool2 = self.maxpool(conv2)
        conv3 = self.relu(self.conv3_2(self.relu(self.conv3_1(maxpool2))))

        # Encoder grid size is read from the tensor instead of assuming 8x8.
        enc_c, enc_h, enc_w = conv3.shape[1], conv3.shape[2], conv3.shape[3]

        # (B*T, C, h, w) -> (B, T, C, h*w) -> (B, h*w, T, C) -> (B*h*w, T, C): one sequence per pixel
        seq = conv3.view(-1, n_steps, enc_c, enc_h*enc_w)
        seq = seq.permute(0,3,1,2)
        seq = seq.reshape(seq.shape[0]*seq.shape[1], n_steps, enc_c)
        lstm, _ = self.lstm(seq)
        feat = lstm.shape[-1]
        # Non in-place ReLU here, so the tensor returned by the LSTM is never overwritten.
        lstm = torch.nn.functional.relu(lstm.reshape(-1, feat))        # (B*h*w*T, feat)

        # One score per pixel and time step, averaged over the whole encoder grid -> (B, T, 1, 1)
        scores = self.attention(torch.tanh(lstm)).view(-1, enc_h, enc_w, n_steps).permute(0,3,1,2)
        scores = torch.nn.functional.avg_pool2d(scores, (enc_h, enc_w))
        # flatten(1) keeps the batch and time axes even when either has length 1
        # (a bare squeeze would drop them and break the softmax over dim=1).
        attention_weights = torch.nn.functional.softmax(scores.flatten(1), dim=1)   # (B, T)

        # Broadcast the per-time-step weights to every pixel and take the weighted sum over time.
        pixel_weights = attention_weights.view(-1, 1, 1, n_steps).repeat(1, enc_h, enc_w, 1).view(-1, 1)
        context = torch.sum((pixel_weights*lstm).view(-1, n_steps, feat), dim=1)
        context = context.view(-1, enc_h, enc_w, feat).permute(0,3,1,2)            # (B, feat, h, w)

        # Skip connections are aggregated over time with the same weights, without gradient through them.
        attention_weights_fixed = attention_weights.detach()
        unpool2 = self.unpool2(context)
        agg_conv2 = torch.sum(attention_weights_fixed.view(-1, n_steps, 1, 1, 1) * conv2.view(-1, n_steps, conv2.shape[1], conv2.shape[2], conv2.shape[3]), dim=1)
        upconv2 = self.relu(self.upconv2_2(self.relu(self.upconv2_1(self.crop_and_concat(agg_conv2, unpool2)))))
        unpool1 = self.unpool1(upconv2)
        agg_conv1 = torch.sum(attention_weights_fixed.view(-1, n_steps, 1, 1, 1) * conv1.view(-1, n_steps, conv1.shape[1], conv1.shape[2], conv1.shape[3]), dim=1)
        upconv1 = self.relu(self.upconv1_2(self.relu(self.upconv1_1(self.crop_and_concat(agg_conv1, unpool1)))))
        out = self.out(upconv1)

        # Explicit end indices so that a margin of 0 returns the full map (`0:-0` would be empty).
        return out[:, :, margin:out.shape[2]-margin, margin:out.shape[3]-margin]

In [None]:
# build model
# Use the GPU when one is visible, otherwise fall back to the CPU so the cell still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNET_LSTM_BIDIRECTIONAL_ATTENTION(in_channels=no_features, out_channels=no_of_classes)
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## train model
# Per-epoch history, consumed by the plotting cell
train_loss = []
train_accuracy = []
test_loss = []
test_accuracy = []
# Number of patch positions per grid cell along each axis (stride `step_size`)
no_of_patches_x = int((grid_cell_size[0] - (patch_size[0]))/step_size)
no_of_patches_y = int((grid_cell_size[1] - (patch_size[1]))/step_size)
no_of_patches_x_test = int((grid_cell_size[0] - (patch_size[0]))/step_size)
no_of_patches_y_test = int((grid_cell_size[1] - (patch_size[1]))/step_size)
# Offset of the label window inside the input patch
w = int((patch_size[0]-label_patch_size[0])/2)
# A batch is `batch_size` consecutive patch columns of one patch row; leftover columns are not used
no_of_batches_y = int(no_of_patches_y/batch_size)
no_of_batches_y_test = int(no_of_patches_y_test/batch_size)
# Host-side buffers reused for every batch: (batch, time, height, width, features) and (batch, height, width)
image_batch = np.zeros((batch_size, ) + (no_timestamps, ) + patch_size + (no_features, ))
label_batch = np.zeros((batch_size, ) + label_patch_size)
print(no_of_patches_x,no_of_batches_y, image_batch.shape)


def load_grid_cell(grid):
  """Load the image series of grid cell `grid` and its label map converted to class indices."""
  grid_cell = np.load(os.path.join(data_dir,numpy_array_prefix + str(grid) + '.npy'))
  grid_cell_raw_label = np.load(os.path.join(label_dir,label_array_prefix + str(grid) + '.npy'))
  grid_cell_label = convert_to_label_array(grid_cell_raw_label, np.zeros((grid_cell_size)))
  return grid_cell, grid_cell_label


def fill_batch(grid_cell, grid_cell_label, x, y):
  """Copy batch `y` of patch row `x` into the shared `image_batch` / `label_batch` buffers."""
  row = x*step_size
  for b in range(batch_size):
    col = ((y*batch_size)+b)*step_size
    image_batch[b] = grid_cell[timestamps, row:row + patch_size[0], col:col + patch_size[1], :]
    label_batch[b] = grid_cell_label[row + w:row + w + label_patch_size[0], col + w:col + w + label_patch_size[1]]


def batch_to_tensors():
  """Return the current buffers as (float images in (batch, time, features, H, W) order, long labels) on `device`."""
  image_batch_t = torch.Tensor(np.transpose(image_batch,(0,1,4,2,3))).to(device)
  label_batch_t = torch.Tensor(label_batch).type(torch.long).to(device)
  return image_batch_t, label_batch_t


for epoch in range(no_of_epochs):
  print('\n##  EPOCH ',epoch,'  ##')

  # Train
  print('\tTraining')
  model.train()
  total_loss = 0
  train_batches = 0
  accuracy_grids = 0
  start_time = time.time()

  for grid in grids_train:

    start_grid_time = time.time()
    accuracy_rows = 0
    train_grid_cell, train_grid_cell_label = load_grid_cell(grid)

    for x in range(no_of_patches_x):

      accuracy = 0

      for y in range(no_of_batches_y):

        fill_batch(train_grid_cell, train_grid_cell_label, x, y)
        image_batch_t, label_batch_t = batch_to_tensors()

        optimizer.zero_grad()
        patch_out = model(image_batch_t)
        loss = criterion(patch_out, label_batch_t)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        train_batches += 1
        # argmax of the logits; softmax is monotonic, so it is not needed for the prediction
        patch_out_pred = np.reshape(torch.argmax(patch_out, dim=1).cpu().numpy(), (-1))
        label_batch_flat = np.reshape(label_batch_t.cpu().numpy(), (-1))
        accuracy += accuracy_score(label_batch_flat, patch_out_pred)

      accuracy_rows += accuracy/no_of_batches_y

    print('\t\tGrid no: ', grid, '\tAccuracy grid: ',accuracy_rows/no_of_patches_x, '\t Time taken for Grid: ', time.time() - start_grid_time)
    accuracy_grids += accuracy_rows/no_of_patches_x

  # Mean loss per batch. (The sum used to be divided by the number of patch rows
  # only, which scaled the reported loss by grids x batches-per-row.)
  mean_train_loss = total_loss/max(train_batches, 1)
  print('\n\tTrain:\t Loss: {}\t Accuracy: {}\t Time: {}'.format(mean_train_loss, (accuracy_grids/len(grids_train)), time.time() - start_time))
  train_loss.append(mean_train_loss)
  train_accuracy.append(accuracy_grids/len(grids_train))

  # Test
  model.eval()
  print('\n\tTesting')
  total_loss_test = 0
  test_batches = 0
  accuracy_test_grids = 0
  start_time_test = time.time()
  pred_list = []
  true_list = []

  # No gradients are needed for evaluation; this avoids building autograd graphs for every test batch.
  with torch.no_grad():
    for grid in grids_test:

      start_grid_time_test = time.time()
      accuracy_test_row = 0
      test_grid_cell, test_grid_cell_label = load_grid_cell(grid)

      for x in range(no_of_patches_x_test):

        accuracy_test = 0

        for y in range(no_of_batches_y_test):

          fill_batch(test_grid_cell, test_grid_cell_label, x, y)
          image_batch_t, label_batch_t = batch_to_tensors()

          patch_out = model(image_batch_t)
          loss_test = criterion(patch_out, label_batch_t)
          total_loss_test += loss_test.item()
          test_batches += 1
          patch_out_pred = np.reshape(torch.argmax(patch_out, dim=1).cpu().numpy(), (-1))
          label_batch_flat = np.reshape(label_batch_t.cpu().numpy(), (-1))
          accuracy_test += accuracy_score(label_batch_flat, patch_out_pred)
          pred_list.append(patch_out_pred)
          true_list.append(label_batch_flat)

        accuracy_test_row += accuracy_test/no_of_batches_y_test

      print('\t\tGrid no: ', grid, '\tAccuracy grid: ',accuracy_test_row/no_of_patches_x_test, '\t Time taken for Grid: ', time.time() - start_grid_time_test)
      accuracy_test_grids += accuracy_test_row/no_of_patches_x_test

  # Per-class F1 over every test pixel of the epoch, averaged over the classes in `labels_list`
  pred_list_arr = np.concatenate(pred_list)
  true_list_arr = np.concatenate(true_list)
  mean_f1_score = np.mean(f1_score(true_list_arr,pred_list_arr,average = None,labels=labels_list))
  mean_test_loss = total_loss_test/max(test_batches, 1)
  print('\tTest:\t Loss: {}\t Accuracy: {}\t Time: {}'.format(mean_test_loss, (accuracy_test_grids/len(grids_test)), time.time() - start_time_test))
  print('\t\t\t Mean F1 Score: {}'.format(mean_f1_score))
  test_loss.append(mean_test_loss)
  test_accuracy.append(accuracy_test_grids/len(grids_test))

  # A checkpoint is written after every epoch; the file name records the test accuracy and mean F1.
  model_name = 'state_dict_epoch-'+str(epoch)+'_test_acc-' + str("{:.4f}".format(accuracy_test_grids/len(grids_test))) + '_mean_f1_score-'+ str("{:.4f}".format(mean_f1_score))+'_'+str(experiment_id)+'.pt'
  torch.save(model.state_dict(), os.path.join(model_folder,  model_name))
  print('Saved model at', str(os.path.join(model_folder,  model_name)) )


In [None]:
# Plot graphs
# Loss per epoch for the train and test grids; the figure is saved next to the checkpoints.
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.plot(train_loss, label="train loss")
plt.plot(test_loss, label="test loss")
plt.legend(loc="upper right")
plt.savefig(os.path.join(model_folder, ('loss_'+experiment_id+'_pytorch.png')))
plt.show()
plt.close()

# Accuracy per epoch. The y-axis was previously labelled 'loss' by mistake.
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.plot(train_accuracy, label="train acc")
plt.plot(test_accuracy, label="test acc")
plt.legend(loc="lower right")
plt.savefig(os.path.join(model_folder, ('accuracy_'+experiment_id+'_pytorch.png')))
plt.show()
plt.close()