In [None]:
import os
import numpy as np
import h5py
import gdal
from osgeo import gdal, gdalconst, osr
import time
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
!nvidia-smi

In [None]:
# ---- Experiment configuration -------------------------------------------
# Every tunable value lives in this cell; the model and the training loop
# read these names as globals.
experiment_id = 'landsat_2017_spatio_temporal_attention_segmentation'

# Temporal / spectral dimensions of one sample.
no_timestamps = time_steps = 10
no_features = channels = 7
no_of_classes = 12
# Class ids used while calculating the F1 score (derived so it cannot drift
# from no_of_classes).
labels_list = list(range(no_of_classes))

# Patch geometry. The network receives patch_size inputs and is trained
# against the central label_patch_size window. The aliases and the crop
# margin are derived from those two tuples so they always stay consistent.
patch_size = (32, 32)
label_patch_size = (16, 16)
input_patch_size = patch_size[0]
output_patch_width = label_patch_size[0]
step_size = 16  # stride of the sliding window used to cut patches
diff = (patch_size[0] - label_patch_size[0]) // 2  # pixels cropped from each side of the model output

# Optimisation settings.
learning_rate = 0.0001
no_of_epochs = 100
max_accuracy_test = 0

# Checkpoints and figures are written here; exist_ok makes re-running the
# cell safe.
model_folder = 'models_landsat_2017_spatio_temporal_attention_segmentation/' + experiment_id
os.makedirs(model_folder, exist_ok=True)

In [None]:
# Load the pre-processed Landsat series. Later cells index this array as
# [timestep, row, col, band].
series_path = os.path.join("data/landsat_processed_data/numpy_arrays", "landsat_2017_multicrop_series.npy")
temporal_images_array = np.load(series_path)
print(temporal_images_array.shape)

(10, 2976, 3712, 10)


In [None]:
## Read and prepare labels
label_raster_path = os.path.join("data/labels",  "usda_labels_2017_multicrop.tif")
raster = gdal.Open(label_raster_path)
if raster is None:
    # gdal.Open signals failure by returning None unless GDAL exceptions are
    # enabled; fail here with a clear message instead of an AttributeError.
    raise FileNotFoundError('Could not open label raster: ' + label_raster_path)
raw_label = raster.ReadAsArray()
print('Unique labels are : ', np.unique(raw_label))

# Map raw raster codes to the training classes. Any code that is not listed
# stays in class 0.
usda_code_to_class = {
    1: (1, 12, 13),                           # corn
    2: (5,),                                  # soybean
    3: (41,),                                 # sugarbeets
    4: (23, 24, 39),                          # wheat
    5: (42,),                                 # drybean
    6: (36,),                                 # alfalfa
    7: (53,),                                 # peas
    8: (111,),                                # open water
    9: (121, 122, 123, 124),                  # developed areas
    10: (141, 142, 143),                      # forests
    11: (190, 195, 176, 152, 59, 60, 58),     # wetlands and grass
}

# Size the class map from the raster itself rather than a hard-coded shape.
# The dtype is left at NumPy's default float, as before.
label = np.zeros(raw_label.shape)
for class_id, raw_codes in usda_code_to_class.items():
    label[np.isin(raw_label, raw_codes)] = class_id

print('label shape: ',label.shape)
print('New unique labels are : ', np.unique(label))
print('Count of each label: ', np.bincount(label.flatten().astype(int)))

Unique labels are :  [  1   4   5   6  21  22  23  24  27  28  29  30  31  36  37  39  41  42
  43  44  53  58  59  60  61  68 111 121 122 123 124 131 141 142 143 152
 176 190 195 205 229]
label shape:  (2976, 3712)
New unique labels are :  [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11.]
Count of each label:  [  56600 4314334 3461701  466977   54804   88877  201008   25160  308932
  621749  391754 1055016]


In [None]:
# Split the scene along the row axis: the top rows become the training
# region, the remaining rows the test region.
testing_split = 0.5
height = temporal_images_array.shape[1]
split_row = int(height * (1-testing_split))
train_rows = slice(None, split_row)
test_rows = slice(split_row, None)

# Only the first 7 bands (the bands of interest) are kept for the images.
train_image_segment = temporal_images_array[:, train_rows, :, :7]
test_image_segment = temporal_images_array[:, test_rows, :, :7]
train_label_segment = label[train_rows, :]
test_label_segment = label[test_rows, :]

print('Printing shapes')
print('Train images shape: ',train_image_segment.shape,'Train labels shape: ',train_label_segment.shape)
print('Test images shape: ',test_image_segment.shape,'Test labels shape: ',test_label_segment.shape)

Printing shapes
Train images shape:  (10, 1488, 3712, 7) Train labels shape:  (1488, 3712)
Test images shape:  (10, 1488, 3712, 7) Test labels shape:  (1488, 3712)


In [None]:
# Model Architecture

class UNET_LSTM_BIDIRECTIONAL_ATTENTION(torch.nn.Module):
    """U-Net style encoder/decoder with a bidirectional LSTM and temporal attention.

    Every timestep is encoded independently by a shared convolutional
    encoder. At the bottleneck, each spatial location's sequence of features
    is run through a bidirectional LSTM, and a learned attention score
    (averaged over the spatial map, softmax over time) collapses the sequence
    into a single feature map. The same attention weights, detached from the
    graph, aggregate the encoder skip connections before decoding.

    ``forward`` reads the module-level globals ``channels``,
    ``input_patch_size``, ``time_steps`` and ``diff``.

    Args:
        in_channels: number of channels of each input frame.
        out_channels: number of output channels (one logit map per class).
    """

    def __init__(self, in_channels, out_channels):
        super(UNET_LSTM_BIDIRECTIONAL_ATTENTION,self).__init__()

        # Encoder: three conv stages; max-pooling is applied between stages in forward().
        self.conv1_1 = torch.nn.Conv2d(in_channels, 64, 3, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, padding=1)
        self.conv3_1 = torch.nn.Conv2d(128, 256, 3, padding=1)
        self.conv3_2 = torch.nn.Conv2d(256, 256, 3, padding=1)

        # Temporal bottleneck: bidirectional LSTM (2 * 256 outputs) and a
        # linear layer producing one attention score per LSTM output vector.
        self.lstm = torch.nn.LSTM(256, 256, batch_first=True, bidirectional=True)
        self.attention = torch.nn.Linear(512, 1)

        # Decoder: transpose-conv upsampling followed by convs on the
        # concatenation of skip features and upsampled features.
        self.unpool2 = torch.nn.ConvTranspose2d(512 , 128, kernel_size=2, stride=2)
        self.upconv2_1 = torch.nn.Conv2d(256, 128, 3, padding=1)
        self.upconv2_2 = torch.nn.Conv2d(128, 128, 3, padding=1)
        self.unpool1 = torch.nn.ConvTranspose2d(128 , 64, kernel_size=2, stride=2)
        self.upconv1_1 = torch.nn.Conv2d(128, 64, 3, padding=1)
        self.upconv1_2 = torch.nn.Conv2d(64, 64, 3, padding=1)

        self.out = torch.nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

        self.maxpool = torch.nn.MaxPool2d(2)
        self.relu = torch.nn.ReLU(inplace=True)
        self.dropout = torch.nn.Dropout(p=0.1)  # defined but not applied in forward()

        # Xavier-initialise the weights of every Conv2d and Linear layer.
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def crop_and_concat(self, x1, x2):
        """Center-crop ``x1`` to the spatial size of ``x2`` and concatenate on the channel axis."""
        x1_shape = x1.shape
        x2_shape = x2.shape
        offset_2, offset_3 = (x1_shape[2]-x2_shape[2])//2, (x1_shape[3]-x2_shape[3])//2
        x1_crop = x1[:, :, offset_2:offset_2+x2_shape[2], offset_3:offset_3+x2_shape[3]]
        return torch.cat([x1_crop, x2], dim=1)

    def forward(self,x):
        """Compute class logits for a batch of image time series.

        Args:
            x: tensor viewable as ``(-1, channels, input_patch_size,
                input_patch_size)`` whose leading size is a multiple of
                ``time_steps``; the training loop passes
                ``(batch, time_steps, channels, H, W)``.

        Returns:
            Logits of shape ``(batch, out_channels, H - 2*diff, W - 2*diff)``.
        """
        # Fold time into the batch axis so the 2-D encoder sees one frame per sample.
        x = x.view(-1, channels, input_patch_size, input_patch_size)

        conv1 = self.relu(self.conv1_2(self.relu(self.conv1_1(x))))
        maxpool1 = self.maxpool(conv1)
        conv2 = self.relu(self.conv2_2(self.relu(self.conv2_1(maxpool1))))
        maxpool2 = self.maxpool(conv2)
        conv3 = self.relu(self.conv3_2(self.relu(self.conv3_1(maxpool2))))

        # Sizes are read from the tensors instead of being hard-coded, so the
        # bottleneck follows the configured patch size.
        enc_channels, enc_h, enc_w = conv3.shape[1], conv3.shape[2], conv3.shape[3]

        # (batch*time, C, h, w) -> (batch*h*w, time, C): one sequence per spatial location.
        conv3 = conv3.view(-1, time_steps, enc_channels, enc_h * enc_w)
        conv3 = conv3.permute(0,3,1,2)
        conv3 = conv3.reshape(-1, time_steps, enc_channels)
        lstm, _ = self.lstm(conv3)
        lstm_width = lstm.shape[-1]
        lstm = self.relu(lstm.reshape(-1, lstm_width))

        # One score per (location, timestep), averaged over the whole spatial
        # map, then softmax over time. flatten(1) keeps the batch axis even
        # for a single-sample batch (a bare squeeze() would drop it and make
        # softmax(dim=1) fail).
        scores = self.attention(torch.tanh(lstm)).view(-1, enc_h, enc_w, time_steps).permute(0,3,1,2)
        scores = torch.nn.functional.avg_pool2d(scores, kernel_size=(enc_h, enc_w)).flatten(1)
        attention_weights = torch.nn.functional.softmax(scores, dim=1)

        # Broadcast the per-sample temporal weights to every spatial location
        # and take the weighted sum of the LSTM outputs over time.
        location_weights = attention_weights.view(-1, 1, 1, time_steps).repeat(1, enc_h, enc_w, 1).view(-1, 1)
        context = torch.sum((location_weights * lstm).view(-1, time_steps, lstm_width), dim=1)
        context = context.view(-1, enc_h, enc_w, lstm_width).permute(0,3,1,2)

        # Skip connections are aggregated over time with the same weights,
        # detached so no gradient reaches the attention layer through them.
        attention_weights_fixed = attention_weights.detach()
        unpool2 = self.unpool2(context)
        agg_conv2 = torch.sum(attention_weights_fixed.view(-1, time_steps, 1, 1, 1) * conv2.view(-1, time_steps, conv2.shape[1], conv2.shape[2], conv2.shape[3]), dim=1)
        upconv2 = self.relu(self.upconv2_2(self.relu(self.upconv2_1(self.crop_and_concat(agg_conv2, unpool2)))))
        unpool1 = self.unpool1(upconv2)
        agg_conv1 = torch.sum(attention_weights_fixed.view(-1, time_steps, 1, 1, 1) * conv1.view(-1, time_steps, conv1.shape[1], conv1.shape[2], conv1.shape[3]), dim=1)
        upconv1 = self.relu(self.upconv1_2(self.relu(self.upconv1_1(self.crop_and_concat(agg_conv1, unpool1)))))
        out = self.out(upconv1)

        # Keep only the central window; a zero margin must skip the slice
        # because out[..., 0:-0] would be empty.
        if diff == 0:
            return out
        return out[:,:,diff:-diff, diff:-diff]

In [None]:
# build model
model = UNET_LSTM_BIDIRECTIONAL_ATTENTION(in_channels=no_features, out_channels=no_of_classes)
model = model.to('cuda')
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## train model
train_loss = []
train_accuracy = []
test_loss = []
test_accuracy = []
# NOTE(review): int((size - patch) / step) counts one window fewer than fits
# (no "+ 1"), so the last row/column of patches is never visited. Left as-is
# because changing it alters the training data -- confirm whether intended.
no_of_patches_x = int((train_image_segment.shape[1] - (patch_size[0]))/step_size)
no_of_patches_y = int((train_image_segment.shape[2] - (patch_size[1]))/step_size)
no_of_patches_x_test = int((test_image_segment.shape[1] - (patch_size[0]))/step_size)
no_of_patches_y_test = int((test_image_segment.shape[2] - (patch_size[1]))/step_size)
# Offset of the label window inside the input patch.
w = int((patch_size[0]-label_patch_size[0])/2)

# Reusable host buffers; one batch is a full row of patches. The test set has
# its own buffers sized from the test segment, so a differently sized test
# region cannot read stale training patches.
image_batch = np.zeros((no_of_patches_y, no_timestamps) + patch_size + (no_features, ), dtype=np.float32)
label_batch = np.zeros((no_of_patches_y, ) + label_patch_size, dtype=np.int64)
image_batch_test = np.zeros((no_of_patches_y_test, no_timestamps) + patch_size + (no_features, ), dtype=np.float32)
label_batch_test = np.zeros((no_of_patches_y_test, ) + label_patch_size, dtype=np.int64)


def _load_patch_row(image_segment, label_segment, x, image_buf, label_buf):
  """Fill the buffers with patch row `x` and return (images, labels) on the GPU.

  images: float tensor laid out (patch, time, channel, height, width).
  labels: long tensor laid out (patch, label_height, label_width), cut from
          the centre of each input patch.
  """
  row = x * step_size
  for y in range(image_buf.shape[0]):
    col = y * step_size
    image_buf[y] = image_segment[:no_timestamps, row:row + patch_size[0], col:col + patch_size[1], :]
    label_buf[y] = label_segment[row + w:row + w + label_patch_size[0], col + w:col + w + label_patch_size[1]]
  # Move channels in front of the spatial axes; the contiguous copy also
  # decouples the tensor from the reused buffer.
  images = torch.from_numpy(np.ascontiguousarray(np.transpose(image_buf, (0,1,4,2,3)))).float()
  labels = torch.from_numpy(label_buf.copy())
  return images.to('cuda'), labels.to('cuda')


for epoch in range(no_of_epochs):
  # Train
  model.train()
  total_loss = 0
  accuracy = 0
  start_time = time.time()

  for x in range(no_of_patches_x):
    images, targets = _load_patch_row(train_image_segment, train_label_segment, x, image_batch, label_batch)

    optimizer.zero_grad()
    patch_out = model(images)
    loss = criterion(patch_out, targets)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

    # argmax over the logits gives the same class as argmax over softmax.
    patch_out_pred = torch.argmax(patch_out, dim=1).cpu().numpy().reshape(-1)
    accuracy += accuracy_score(targets.cpu().numpy().reshape(-1), patch_out_pred)

  print('\nEpoch {0}:\t Loss: {1}\t Accuracy: {2}\t Time: {3}'.format(epoch, total_loss/(no_of_patches_x), (accuracy/no_of_patches_x), time.time() - start_time))
  train_loss.append(total_loss/(no_of_patches_x))
  train_accuracy.append(accuracy/(no_of_patches_x))

  # Test
  model.eval()
  total_loss_test = 0
  accuracy_test = 0
  start_time = time.time()
  pred_list = []
  true_list = []

  # No gradients are needed for evaluation.
  with torch.no_grad():
    for x in range(no_of_patches_x_test):
      images, targets = _load_patch_row(test_image_segment, test_label_segment, x, image_batch_test, label_batch_test)

      patch_out = model(images)
      # Evaluate the criterion on this test batch; previously the last
      # training loss was accumulated here instead.
      total_loss_test += criterion(patch_out, targets).item()

      patch_out_pred = torch.argmax(patch_out, dim=1).cpu().numpy().reshape(-1)
      targets_flat = targets.cpu().numpy().reshape(-1)

      accuracy_test += accuracy_score(targets_flat, patch_out_pred)
      pred_list.append(patch_out_pred)
      true_list.append(targets_flat)

  pred_list_arr = np.concatenate(pred_list)
  true_list_arr = np.concatenate(true_list)
  mean_f1_score = np.mean(f1_score(true_list_arr,pred_list_arr,average = None,labels=labels_list))
  print('Test   :\t Loss: {1}\t Accuracy: {2}\t Time: {3}'.format(epoch, total_loss_test/(no_of_patches_x_test), (accuracy_test/no_of_patches_x_test), time.time() - start_time))
  print('\t\t Mean F1 Score: {}'.format(mean_f1_score))
  test_loss.append(total_loss_test/(no_of_patches_x_test))
  test_accuracy.append(accuracy_test/(no_of_patches_x_test))

  # A checkpoint is written after every epoch, tagged with its test metrics.
  model_name = 'state_dict_epoch-'+str(epoch)+'_test_acc-' + str("{:.4f}".format(accuracy_test/(no_of_patches_x_test)))+ '_mean_f1_score-'+ str("{:.4f}".format(mean_f1_score))+'_'+str(experiment_id)+'.pt'
  torch.save(model.state_dict(), os.path.join(model_folder,  model_name))
  print('Saved model at', str(os.path.join(model_folder,  model_name)) )


In [None]:
def _plot_history(train_values, test_values, metric, short_name, legend_loc):
    """Plot one per-epoch train/test curve pair, save it to model_folder and show it.

    metric is used for the title, the y-axis label and the output file name;
    short_name is used in the legend entries.
    """
    fig, ax = plt.subplots()
    ax.set_title('model ' + metric)
    ax.set_ylabel(metric)
    ax.set_xlabel('epoch')
    ax.plot(train_values, label='train ' + short_name)
    ax.plot(test_values, label='test ' + short_name)
    ax.legend(loc=legend_loc)
    fig.savefig(os.path.join(model_folder, (metric + '_' + experiment_id + '_pytorch.png')))
    plt.show()
    plt.close(fig)


_plot_history(train_loss, test_loss, 'loss', 'loss', 'upper right')
# The accuracy figure's y-axis used to be labelled 'loss'.
_plot_history(train_accuracy, test_accuracy, 'accuracy', 'acc', 'lower right')