# Intro to Digital Agriculture
## Week 6
Instructors: Maria Pukalchik, Dmitry Shadrin

TAs for this week: Svetlana Illarionova, Ivan Matvienko

In this practical we will work with satellite images: locate agricultural fields and visualize them.

# Part 0
### Data download and preparation
Let's first download the data (this can take several minutes).

In [None]:
!wget --no-check-certificate "https://onedrive.live.com/download?cid=1E2DE865E90D4259&resid=1E2DE865E90D4259%21258622&authkey=AL59NvN10qXgesk" -O Farmpins.zip
!wget --no-check-certificate "https://onedrive.live.com/download?cid=1E2DE865E90D4259&resid=1E2DE865E90D4259%21195470&authkey=AFPw5W-8uzm5vpM" -O unet_parts.py

In [None]:
!unzip ./Farmpins.zip

Then import the necessary libraries.

In [4]:
import numpy as np
import imageio as io
import pandas as pd
import skimage.transform as transforms
import skimage.util as utils
import os
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam, lr_scheduler
from osgeo import gdal  # the osgeo namespace is the supported import path for GDAL
from unet_parts import *
import torchvision
from tqdm import tnrange, tqdm
from IPython.display import clear_output

In [5]:
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score, f1_score, precision_score, classification_report

# Part 1. EDA

Let's load the images and look at what is inside.

In [6]:
data_f = gdal.Open('./20170101_mosaic_cropped.tif')
data_mask_train = gdal.Open('./train_crops.tif')

`fields` - data tensor (satellite image) with surface reflectance values

`train_mask` - crop labels for the fields (used for training)

In [7]:
fields = np.array(data_f.ReadAsArray())
train_mask = np.array(data_mask_train.ReadAsArray())

In [8]:
train_id = gdal.Open('./train_field_id.tif')
train_field_id_map = np.array(train_id.ReadAsArray()).astype(np.uint16)
train_field_id_map = np.hstack([train_field_id_map, np.zeros((train_field_id_map.shape[0], 1))])
train_field_id_map = train_field_id_map*(train_mask > 0)
# train_field_id = np.unique(train_field_id_map)
# train_field_id = train_field_id[train_field_id > 0]

Let's plot `train_field_id_map`

In [None]:
plt.figure(figsize=(15, 15))
plt.imshow(train_field_id_map)
# plt.colorbar()

As we can see, `train_field_id_map` contains masks with field ids (all pixels belonging to field #$n$ have the value $n$).

In [10]:
train_mask[train_mask > 15] = 0
np.unique(train_mask)

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8)

We can also see that it has only 10 unique values, so these are most probably (and really are) crop types (our labels for the fields). Let's visualize them.

In [None]:
plt.figure(figsize=(15, 15))
plt.imshow(train_mask)

Now it is time to look at the satellite image itself. Check its shape: it should have 13 channels.

In [12]:
fields.shape

(13, 7467, 8292)

Since we are working with Sentinel, which captures multispectral images, our image has 13 channels. To plot an RGB or any other image of the planet surface, look up the tech sheet for the Sentinel 2B satellite and find the channel descriptions: [Sentinel 2 Wiki Page](https://en.wikipedia.org/wiki/Sentinel-2#Instruments)
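As a quick reference, the sketch below maps array indices to band names. It assumes the 13 channels are stored in the standard Sentinel-2 order (B01 to B12, with B8A right after B08); that ordering is an assumption about this file, so verify it against the data description before relying on it.

```python
# Assumed channel order of the 13-band stack (standard Sentinel-2 ordering).
# This is an assumption about the file, not something read from its metadata.
SENTINEL2_BANDS = [
    "B01", "B02", "B03", "B04", "B05", "B06", "B07",
    "B08", "B8A", "B09", "B10", "B11", "B12",
]

def band_index(name):
    """Return the 0-based channel index for a band name such as 'B08'."""
    return SENTINEL2_BANDS.index(name)

# Indices of the bands needed for a true-colour (RGB) composite.
rgb_indices = [band_index(b) for b in ("B04", "B03", "B02")]
print(rgb_indices)
```

Under this assumption, index 7 corresponds to B08, the broad NIR band.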

Let's plot the NIR channel (index 7)

In [None]:
plt.figure(figsize=(15, 15))
plt.imshow(fields[7, :2000, :3000], cmap='gray')

You can plot different channels and see for yourself that they have different spatial resolutions. Sentinel has 3 native resolutions (60 m, 20 m and 10 m per pixel). All channels have been upsampled to 10 m per pixel, but the difference is still visible.
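The grouping of Sentinel-2 bands by native resolution can be written down as a small lookup table. The helper name `upsampling_factor` is hypothetical and only illustrates how coarse a band is relative to the 10 m grid.

```python
# Native ground resolution (metres per pixel) of the Sentinel-2 MSI bands.
NATIVE_RESOLUTION_M = {
    10: ["B02", "B03", "B04", "B08"],
    20: ["B05", "B06", "B07", "B8A", "B11", "B12"],
    60: ["B01", "B09", "B10"],
}

def upsampling_factor(band, target_m=10):
    """How many target pixels one native pixel spans along each axis."""
    for res, bands in NATIVE_RESOLUTION_M.items():
        if band in bands:
            return res // target_m
    raise KeyError(band)

print(upsampling_factor("B01"))  # a 60 m band resampled onto the 10 m grid
```

Bands with a larger factor look visibly blockier after upsampling.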

Let's visualize the agricultural fields on the image.

In [None]:
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

#here we define the color scheme for our fields, such that every crop has
#its own constant color
viridis = cm.get_cmap('viridis', 256)
train_colors = viridis(np.linspace(0, 1, 256))
transp = np.array([0, 0, 0, 0])
train_colors[:25, :] = transp
train_cmp = ListedColormap(train_colors)

plt.figure(figsize=(15, 15))
plt.imshow(fields[12, :2000, :3000], cmap='gray')
plt.imshow(train_mask[:2000, :3000], alpha=0.5, cmap=train_cmp)

**Task 1 (1 point for each index)**. Plot the NDVI, EVI and NDRE indices for a patch of the original image with pixel coordinates x in [0, 3000] and y in [500, 1500]. Just Google the indices if you do not know the exact formulae.

 *Hint: mind the proper order when slicing the tensor*
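To illustrate the slicing order from the hint, here is a minimal sketch on a toy array. It assumes the layout (channels, rows, columns), so y is sliced before x, and it assumes that indices 3 and 7 hold the red and NIR bands; both the toy data and the band positions are assumptions.

```python
import numpy as np

# Toy stand-in for `fields`: (channels, rows, cols) = (C, y, x).
rng = np.random.default_rng(0)
toy = rng.uniform(0.01, 1.0, size=(13, 40, 60)).astype(np.float32)

# Rows (y) come before columns (x) when slicing.
y0, y1, x0, x1 = 5, 15, 0, 30
patch = toy[:, y0:y1, x0:x1]

# NDVI = (NIR - Red) / (NIR + Red); indices 7 and 3 are assumed band positions.
nir, red = patch[7], patch[3]
ndvi = (nir - red) / (nir + red + 1e-8)  # small epsilon avoids division by zero

print(patch.shape, ndvi.shape)
```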

Since our dataset is imbalanced, training even the simplest segmentation model can be quite difficult. One of the simplest ways to deal with this issue is to assign a weight to each class in the loss function.
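One common way to obtain such weights is inverse class frequency. The sketch below uses a made-up label patch and is not necessarily how the weights used later in this notebook were derived.

```python
import numpy as np

# Hypothetical label patch: 0 is background, 1..3 are crop classes.
labels = np.array([[0, 1, 1, 1],
                   [0, 1, 1, 2],
                   [3, 3, 1, 2]])

n_classes = 4
counts = np.bincount(labels.ravel(), minlength=n_classes).astype(float)

# Inverse-frequency weights; background (class 0) and empty classes get weight 0.
weights = np.zeros(n_classes)
valid = counts > 0
valid[0] = False
weights[valid] = 1.0 / counts[valid]
weights /= weights.sum()  # normalise so the weights sum to 1

print(counts, weights)
```

Rare classes end up with larger weights, so their pixels contribute more to the loss.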

**Task 2 (2 points)**. Calculate the number of *pixels* (1 pt) and *fields* (1 pt) for each crop and print them out. This information will help us later in the training loop.

# Part 2. Download data for training

In order to train the network, we need to split our big satellite image into patches in such a way that, on the one hand, a neural network can process them and, on the other hand, they can be divided into training and validation sets with an equal contribution of all classes.

One way to do this is to crop each field from the original image and either resize the crops (for classification) or not (for segmentation). Another way is to use patches of constant size. In this work we will use the second approach, even though it is much harder to split such patches with equal class proportions. You can try to do that for additional points, but I suggest downloading the prepared, already split dataset from the link below.
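The constant-size patch approach can be sketched as follows. This is a minimal illustration on toy data, with a hypothetical helper `tile_patches`; it is not the procedure used to build the prepared dataset and it does not balance classes between splits.

```python
import numpy as np

def tile_patches(image, labels, size):
    """Cut a (C, H, W) image and its (H, W) label map into non-overlapping
    size x size patches, dropping incomplete patches at the borders."""
    _, height, width = image.shape
    patches = []
    for top in range(0, height - size + 1, size):
        for left in range(0, width - size + 1, size):
            img = image[:, top:top + size, left:left + size]
            lab = labels[top:top + size, left:left + size]
            if (lab > 0).any():  # keep only patches that contain labelled pixels
                patches.append((img, lab))
    return patches

# Toy data: 13 channels, 250 x 330 pixels, one labelled field.
image = np.zeros((13, 250, 330), dtype=np.float32)
labels = np.zeros((250, 330), dtype=np.uint8)
labels[10:60, 20:80] = 3

patches = tile_patches(image, labels, size=100)
print(len(patches), patches[0][0].shape)
```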


In [None]:
!wget --no-check-certificate "https://onedrive.live.com/download?cid=1E2DE865E90D4259&resid=1E2DE865E90D4259%21258624&authkey=AHWYbEgRMryIkKQ" -O Patches.zip
!unzip ./Patches.zip

The main task is to train a fully convolutional network, a UNet with Squeeze and Excitation blocks, to classify crops.

The downloaded data has the following structure: there are `train` and `val` folders with the patches for *training* and *validation* purposes. Each of them contains 2 folders: `images` with the multispectral image patches and `labels` with the corresponding crop labels.

Below, the structure of the network is defined. Those who are interested can look through it and investigate it. It is basically a UNet with a channel-wise attention mechanism called Squeeze and Excitation (`SE_block` here).
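The arithmetic of a Squeeze and Excitation block can be sketched in plain NumPy. The weights here are random placeholders and the function `se_block` is only an illustration of the mechanism with a squeeze rate of 1, not a replacement for the `torch` module.

```python
import numpy as np

def se_block(x, w1, b1, w2, b2):
    """Channel-wise attention on a batch of feature maps x with shape (N, C, H, W)."""
    squeezed = x.mean(axis=(-1, -2))                    # squeeze: global average pool -> (N, C)
    hidden = np.maximum(squeezed @ w1 + b1, 0.0)        # first fully connected layer + ReLU
    gates = 1.0 / (1.0 + np.exp(-(hidden @ w2 + b2)))   # second layer + sigmoid -> (N, C)
    return x * gates[:, :, None, None]                  # excite: rescale every channel

rng = np.random.default_rng(0)
n, c, h, w = 2, 8, 5, 5
x = rng.normal(size=(n, c, h, w))
w1, b1 = rng.normal(size=(c, c)), np.zeros(c)  # random placeholder weights
w2, b2 = rng.normal(size=(c, c)), np.zeros(c)

out = se_block(x, w1, b1, w2, b2)
print(out.shape)
```

Each channel is multiplied by a gate between 0 and 1, so the block can only attenuate channels relative to each other; the spatial layout of a channel is unchanged.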

In [16]:
class SE_block(nn.Module):
    def __init__(self, channels, squeese_rate=1):
        super(SE_block, self).__init__()
        self.fc1 = nn.Linear(channels, channels//squeese_rate)
        self.fc2 = nn.Linear(channels//squeese_rate, channels)
    
    def forward(self, input):
        g_avg = torch.mean(input, [-1, -2])
        x = self.fc1(g_avg).relu()
        x = self.fc2(x).sigmoid()
        return torch.unsqueeze(torch.unsqueeze(x, 2), 3)*input

In [17]:
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.se1 = SE_block(128)
        self.down2 = down(128, 256)
        self.se2 = SE_block(256)
        self.down3 = down(256, 256) #was 512
        self.se3 = SE_block(256) #was 512
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256, bilinear=False)
        self.up2 = up(512, 128, bilinear=False)
        self.up3 = up(256, 64, bilinear=False)
        self.up4 = up(128, 64, bilinear=False)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(self.se1(x2))
        x4 = self.down3(self.se2(x3))
        # x5 = self.down4(self.se3(x4))
        # x = self.up1(x5, x4)
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x

Since we have relatively little data, we should use augmentation to prevent overfitting and improve generalization a little.
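One of the augmentation steps in the dataset class is a random shift: the patch is padded and then cropped back with off-centre margins. The sketch below reproduces only that margin arithmetic on a toy array (the helper name `crop_margins` is hypothetical) and checks that the output size stays equal to `img_size` for every shift.

```python
import numpy as np

img_size = 100
pad_size = img_size // 2
shift = img_size // 5

def crop_margins(length, offset, out_size):
    """Pixels to remove before/after along one axis so that out_size pixels remain,
    with the window moved off-centre by roughly offset / 2."""
    before = (length + offset - out_size) // 2
    after = (length - offset + 1 - out_size) // 2
    return before, after

padded = np.zeros((img_size + 2 * pad_size, img_size + 2 * pad_size))
sizes = set()
for offset in range(-shift, shift):
    before, after = crop_margins(padded.shape[0], offset, img_size)
    window = padded[before:padded.shape[0] - after]
    sizes.add(window.shape[0])

print(sizes)
```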

Below are the parameters for augmentation

In [18]:
img_size = 100
pad_size = img_size//2
shift = img_size//5
rot_angle = 60

It is also important to normalize the data.


**Task 3 (1 point)** Calculate the *mean* and *standard deviation* for each channel in the input image.

In [19]:
'''
    mean, std: torch.tensor or np.ndarray of shape (channels,)
'''
# Your code here
# mean = ...
# std = ...

**Task 4 (2 points)** Below are the skeletons of functions that generate additional features in the form of vegetation indices. Fill in these functions with the feature-generation code; you can also find other features and implement them here in a similar fashion.
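For reference, commonly used definitions of these indices in terms of surface reflectance are listed here. Which array channels correspond to $NIR$, $Red$, $RedEdge$ and $Blue$ should be checked against the band description. The EVI coefficients are the standard ones and assume reflectance scaled to the range 0 to 1, and MSAVI is given in its closed-form (MSAVI2) variant.

$$NDVI = \frac{NIR - Red}{NIR + Red}$$

$$NDRE = \frac{NIR - RedEdge}{NIR + RedEdge}$$

$$EVI = 2.5 \cdot \frac{NIR - Red}{NIR + 6 \cdot Red - 7.5 \cdot Blue + 1}$$

$$MSAVI = \frac{2 \cdot NIR + 1 - \sqrt{(2 \cdot NIR + 1)^2 - 8 \cdot (NIR - Red)}}{2}$$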

In [20]:
def generate_NDVI(features):
    '''
    Arguments:
        features: torch.tensor of shape (Channels, H, W)
    Return:
        ndvi: torch.tensor of shape (H, W)
    '''
    # Your code here
    # ndvi = ...
    return ndvi

In [21]:
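def generate_EVI(features, u2b4evi=True):
    '''
    Arguments:
        features: torch.tensor of shape (Channels, H, W)
        u2b4evi: bool flag forwarded by generate_all_indices(); use it to switch EVI variants if needed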
    Return:
        evi: torch.tensor of shape (H, W)
    '''
    # Your code here
    # evi = ...
    return evi

In [22]:
def generate_NDRE(features):
    '''
    Arguments:
        features: torch.tensor of shape (Channels, H, W)
    Return:
        ndre: torch.tensor of shape (H, W)
    '''
    # Your code here
    # ndre = ...
    return ndre

In [23]:
def generate_MSAVI(features):
    '''
    Arguments:
        features: torch.tensor of shape (Channels, H, W)
    Return:
        msavi: torch.tensor of shape (H, W)
    '''
    # Your code here
    # msavi = ...
    return msavi

Here we gather all generated features in one tensor

In [24]:
def generate_all_indices(features, u2b4evi=True):
    ndvi = generate_NDVI(features)
    evi = generate_EVI(features, u2b4evi)
    ndre = generate_NDRE(features)
    msavi = generate_MSAVI(features)
    return torch.cat([features, ndvi.unsqueeze(0), evi.unsqueeze(0), ndre.unsqueeze(0), msavi.unsqueeze(0)], dim=0)

This is the `torch` `Dataset` class, which loads the data from the folders and does the preprocessing after loading.

We computed `mean` and `std` in the previous task, so now we should apply them. Keep in mind that normalization should be applied only after all the features have been generated and concatenated to the original features, so that the indices are computed from the raw reflectance values.

**Task 4 (1 point)** Use the function `generate_all_indices()` to calculate all indices, and normalize the satellite image.

*Hint: It is recommended to normalize only the satellite image part and leave the generated indices as they are*
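The hint can be illustrated on toy data: standardise only the first 13 channels of a stacked array and leave the appended index channels untouched. The helper `normalize_raw_channels` and the toy statistics are assumptions for illustration; the same broadcasting works for `torch` tensors.

```python
import numpy as np

def normalize_raw_channels(stacked, mean, std, n_raw):
    """Standardise the first n_raw channels of a (C, H, W) array and leave
    the remaining (index) channels untouched."""
    out = stacked.astype(np.float32, copy=True)
    out[:n_raw] = (out[:n_raw] - mean[:, None, None]) / std[:, None, None]
    return out

rng = np.random.default_rng(0)
raw = rng.uniform(0, 1, size=(13, 8, 8)).astype(np.float32)       # toy reflectance
indices = rng.uniform(-1, 1, size=(4, 8, 8)).astype(np.float32)   # toy vegetation indices
stacked = np.concatenate([raw, indices], axis=0)

mean = raw.mean(axis=(1, 2))  # toy statistics; use dataset-wide values in practice
std = raw.std(axis=(1, 2))

normed = normalize_raw_channels(stacked, mean, std, n_raw=13)
print(normed.shape)
```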

In [26]:
class CropFieldsDataset(Dataset):

    def __init__(self, images_dir, labels_dir, transform=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transform = transform
        i_list = os.listdir(self.images_dir)
        self.im_list = []
        for image in i_list:
            if (image[-3:] == 'npy'):
                self.im_list.append(image)
                
        self.mean = mean[:, None, None]
        self.std = std[:, None, None]

    def __len__(self):
        return len(self.im_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.images_dir, self.im_list[idx])
        image = np.load(img_name)#*255
        labels_name = os.path.join(self.labels_dir, self.im_list[idx])
        labels = np.load(labels_name)
        if self.transform is not None:

            randangle = np.random.randint(-rot_angle, rot_angle)
            sc = tuple(np.random.uniform(0.75, 1, 2))
            tf = transforms.AffineTransform(scale = sc)
            image = transforms.rotate(image, randangle, mode='reflect')
            image = utils.pad(image, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)), mode='reflect')
            image = transforms.warp(image, tf, mode='reflect')
            labels = transforms.rotate(labels, randangle, mode='reflect', order=0, preserve_range=True)
            labels = utils.pad(labels, ((pad_size, pad_size), (pad_size, pad_size)), mode='reflect')
            labels = transforms.warp(labels, tf, mode='reflect', order=0)
                        
            M, N, D = image.shape
            randshiftx = np.random.randint(-shift, shift)
            randshifty = np.random.randint(-shift, shift)
            image = utils.crop(image, (((M + randshiftx - img_size)//2, (M - randshiftx + 1 - img_size)//2), 
                                       ((N + randshifty - img_size)//2, (N - randshifty + 1 - img_size)//2), (0, 0)))
            labels= utils.crop(labels,(((M + randshiftx - img_size)//2, (M - randshiftx + 1 - img_size)//2), 
                                       ((N + randshifty - img_size)//2, (N - randshifty + 1 - img_size)//2)))           
            
            
        image = torchvision.transforms.ToTensor()(image)
        labels = torch.tensor(labels).to(torch.long)

        ## Your code here 
        # image = ...
        # norm_image = ...
        ## ------------------
        return norm_image, labels

Let's create an instance of the `Dataset` and a dataloader (try different batch sizes to find one appropriate for the GPU you use).

In [27]:
train_set = CropFieldsDataset("./train/images", "./train/labels", transform=True)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)

And let's visualize how the augmentation works.

In [None]:
image, labels = train_set[10]

_, ax = plt.subplots(ncols=2)
ax[0].imshow(image[12, :, :], cmap='gray')
ax[1].imshow(labels[:, :])
image.shape

Here is the validation dataset as well.

In [29]:
val_set = CropFieldsDataset("./val/images", "./val/labels")
val_loader = DataLoader(val_set, batch_size=1)

You can look at the expected result, which is generated from the validation dataset.

In [30]:
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

viridis = cm.get_cmap('viridis', 256)
train_colors = viridis(np.linspace(0, 1, 256))
test_colors = viridis(np.linspace(0, 1, 256))
transp = np.array([0, 0, 0, 0])
train_colors[:25, :] = transp
train_cmp = ListedColormap(train_colors)

In [None]:
# len(train_set)
image, labels = val_set[18]

print(image.shape)
_, ax = plt.subplots(ncols=2, figsize=(10, 10), dpi=100)
ax[0].imshow(image[2, :, :], cmap='gray')
ax[1].imshow(image[2, :, :], cmap='gray')
ax[1].imshow(labels[:, :], cmap=train_cmp)
ax[0].set_axis_off()
ax[1].set_axis_off()
ax[0].set_title('Band 3')
ax[1].set_title('Crop map')

Now we are going to train the network

In [33]:
if(torch.cuda.is_available()):
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Here we initialize the network: UNet(input_features, n_classes)
# If you generate additional features, you should put here number of channels
# of your final input (13 + number_of_generated_indices)
model = UNet(17, 10).to(device)

In [34]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

**Task 5 (1 point)** Complete the missing parts of the training loop (look at the comments for guidance).

**Task 6 (2 points)** Train your network. You will get full score (**2 points**) if the $accuracy$ on validation $\geq 63\%$, **1 point** if it will be in range $50\% \leq accuracy < 63\%$, and **no points** if $accuracy < 50\%$
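The accuracy here is computed only over labelled pixels, with background (label 0) ignored. A minimal NumPy sketch of that metric on a made-up batch is shown below; the function name `masked_pixel_accuracy` is hypothetical.

```python
import numpy as np

def masked_pixel_accuracy(logits, labels, ignore_index=0):
    """Share of correctly classified pixels among those whose label is not ignore_index.
    logits: (N, classes, H, W), labels: (N, H, W)."""
    pred = logits.argmax(axis=1)
    mask = labels != ignore_index
    return (pred[mask] == labels[mask]).sum() / mask.sum()

# Toy batch: 1 image, 3 classes, 2 x 2 pixels.
labels = np.array([[[0, 1],
                    [2, 2]]])
logits = np.zeros((1, 3, 2, 2))
logits[0, 1, 0, 0] = 5.0  # background pixel, ignored
logits[0, 1, 0, 1] = 5.0  # correct: class 1
logits[0, 2, 1, 0] = 5.0  # correct: class 2
logits[0, 1, 1, 1] = 5.0  # wrong: predicts 1, label is 2

print(masked_pixel_accuracy(logits, labels))
```

Two of the three labelled pixels are correct, so the result is 2/3.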

In [None]:
train_weights = torch.tensor([0, 0.14852941, 0, 0.28836312, 0.04097108, 0.07013557, 0.10067904, 0.07735483, 0.02633347, 0]).to(device)
val_weights = torch.tensor([0, 0.20184716, 0, 0.23679423, 0.03584126, 0.06965259, 0.09625471, 0.07470775, 0.02500757, 0]).to(device)
optimizer = Adam(model.parameters(), lr = 0.01, weight_decay=0.0001)

train_criterion = nn.CrossEntropyLoss(ignore_index=0, weight=train_weights, reduction='mean').to(device)
val_criterion = nn.CrossEntropyLoss(ignore_index=0, weight=val_weights, reduction='mean').to(device)

lr_sch = lr_scheduler.StepLR(optimizer, 1, gamma=0.977)
epochs_num = 350

loss_list = []
acc_list = []

best_acc = 0.0 #0.628

for epoch in range(epochs_num):  # loop over the dataset multiple times

    running_train_loss = 0.0
    train_accuracy_pix = 0
    train_sum_pix = 0
    model.train()
    for data in tqdm(train_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # transfer data to the device
        # inputs = ...
        # labels = ...
        # zero the parameter gradients
        # ...
        # make forward pass, compute the loss, make backward pass and optimization step
        # outputs = ...
        # loss = ...
        # ...
        # ...

        # print statistics
        running_train_loss += loss.item()
        
        result = F.softmax(outputs, dim=1).detach().cpu().numpy()
        pred = np.argmax(result, axis=1)
        labels = labels.cpu().numpy()
        fl_labels = labels.flatten()
        train_accuracy_pix += np.sum(pred.flatten()[fl_labels > 0] == fl_labels[fl_labels > 0])
        train_sum_pix += np.sum(labels > 0)
#     train_writer.add_scalar('Loss', running_train_loss, global_step=epoch)

    running_train_loss /= len(train_loader)  # average of the per-batch mean losses
    loss_list.append(running_train_loss)
    acc_list.append(train_accuracy_pix/train_sum_pix)
    
    model.eval()
    running_val_loss = 0.0
    val_accuracy_pix = 0
    val_sum_pix = 0
    
    for data in tqdm(val_loader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(inputs)
            loss = val_criterion(outputs, labels)
        
        running_val_loss += loss.item()
        
        result = F.softmax(outputs, dim=1).detach().cpu().numpy()
        pred = np.argmax(result, axis=1)
        labels = labels.cpu().numpy()
        fl_labels = labels.flatten()
        val_accuracy_pix += np.sum(pred.flatten()[fl_labels > 0] == fl_labels[fl_labels > 0])
        val_sum_pix += np.sum(labels > 0)
        
    running_val_loss /= len(val_set)

    # train_writer.add_scalar('Accuracy', train_accuracy_pix/train_sum_pix, global_step=epoch)
    # train_writer.add_scalar("Loss",  running_train_loss, global_step=epoch)
    
    # val_writer.add_scalar('Accuracy', val_accuracy_pix/val_sum_pix, global_step=epoch)
    # val_writer.add_scalar("Loss", running_val_loss, global_step=epoch)
    print("Epoch: {0:3d} | LR: {1}\n    Train Loss:  {2:4f} \n    Train Acc: \t {3:4.1f}%\n    Val Loss:    {4:4f} \n    Val Acc: \t {5:4.1f}%".format(
                epoch+1, get_lr(optimizer), 
                running_train_loss, 
                100*train_accuracy_pix/train_sum_pix,
                running_val_loss,
                100*val_accuracy_pix/val_sum_pix))
    
    if (val_accuracy_pix/val_sum_pix > best_acc):
        best_acc = val_accuracy_pix/val_sum_pix
        torch.save(model.state_dict(), 'best_model.pth')
    
    lr_sch.step()
        
print('Finished Training')

The next cells are for validation only.
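Besides the overall accuracy, per-class behaviour can be inspected with a confusion matrix built from flattened labels and predictions. The sketch below is a NumPy-only illustration on made-up arrays with a hypothetical helper `confusion_matrix_np`; `sklearn.metrics.confusion_matrix` would serve the same purpose.

```python
import numpy as np

def confusion_matrix_np(y_true, y_pred, n_classes, ignore_index=0):
    """Rows are true classes, columns are predicted classes; ignored pixels are skipped."""
    keep = y_true != ignore_index
    flat = y_true[keep] * n_classes + y_pred[keep]
    return np.bincount(flat, minlength=n_classes ** 2).reshape(n_classes, n_classes)

# Toy flattened labels and predictions.
y_true = np.array([0, 1, 1, 2, 2, 2])
y_pred = np.array([1, 1, 2, 2, 2, 1])

conf_mat = confusion_matrix_np(y_true, y_pred, n_classes=3)
# Guard against empty classes when dividing by the row sums.
per_class_recall = np.diag(conf_mat) / np.maximum(conf_mat.sum(axis=1), 1)
print(conf_mat)
print(per_class_recall)
```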




In [None]:
model.eval()
val_accuracy_pix = 0
val_sum_pix = 0

correct_class_pix = np.zeros((10), dtype=int)  # np.int was removed from recent NumPy versions
class_pix = np.zeros((10), dtype=int)

preds_list = []
labels_list = []

for data in tqdm(val_loader):
    inputs, labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        outputs = model(inputs)
        
    result = F.softmax(outputs, dim=1).detach().cpu().numpy()
    pred = np.argmax(result, axis=1)
    preds_list.append(pred.flatten())
    labels = labels.cpu().numpy()
    fl_labels = labels.flatten()
    labels_list.append(fl_labels)
    val_accuracy_pix += np.sum(pred.flatten()[fl_labels > 0] == fl_labels[fl_labels > 0])
    val_sum_pix += np.sum(labels > 0)
    for i in range(10):
        class_pix[i] += np.sum(labels == i)
        correct_class_pix[i] += np.sum((pred == i)*(labels == i))

class_acc = correct_class_pix/class_pix

pix_df = pd.DataFrame(data=np.vstack([correct_class_pix, class_pix]), 
                      index=['correct_pix', 'pix'], 
                      columns=np.arange(10))
acc_df = pd.DataFrame(data=np.round(class_acc[None, :], 3), index=['accuracy'], columns=np.arange(10))

print('Accuracy', val_accuracy_pix/val_sum_pix)

display(pix_df)
display(acc_df)

In [41]:
all_preds = np.hstack(preds_list)
all_labels = np.hstack(labels_list)

The cells below visualize the correctness of the model's predictions (green means correct, red means incorrect). Just run all remaining cells.

In [None]:
images = []
results = []
llabels = []
with torch.no_grad():
    for i, data in tqdm(enumerate(val_loader)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        images.append(inputs.numpy())
        llabels.append(labels.numpy())
        inputs = inputs.to(device)

        # forward
        outputs = model(inputs)
        result = F.softmax(outputs, dim=1).detach().cpu().numpy()
        pred = np.argmax(result, axis=1)
        results.append(pred)

In [45]:
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

viridis = cm.get_cmap('viridis', 256)
train_colors = viridis(np.linspace(0, 1, 256))
test_colors = viridis(np.linspace(0, 1, 256))
transp = np.array([0, 0, 0, 0])
train_colors[:24, :] = transp
train_cmp = ListedColormap(train_colors)
pink = np.array([248/256, 24/256, 148/256, 1])
test_colors[:128, :] = transp
test_colors[128:, :] = pink
test_cmp = ListedColormap(test_colors)

In [46]:
test_scale = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

In [None]:
for i in range(len(val_set)):
    right = (results[i][0, ...] == llabels[i][0, ...]).astype(int) - (results[i][0, ...] != llabels[i][0, ...]).astype(int)
    pic1 = results[i][0, ...]*(llabels[i][0, ...] > 0)
    pic2 = llabels[i][0, ...]#.numpy()
    pic1[0, :10] = test_scale
    pic2[0, :10] = test_scale

    _, ax = plt.subplots(ncols=3, figsize=(12, 4))
    ax[0].imshow(images[i][0, 3, ...], cmap='gray')
    fig1 = ax[0].imshow(pic1, alpha=1, cmap=train_cmp)
    ax[0].set_title('Model Prediction')
    fig2 = ax[1].imshow(pic2, cmap=viridis)
    ax[1].set_title('Ground Truth')
    ax[2].imshow(right*(llabels[i][0, ...] > 0), cmap='RdYlGn')
    ax[2].set_title('Correctness')
    # plt.colorbar(fig1, ax=ax[0])
    # plt.colorbar(fig2, ax=ax[1])
    ax[0].xaxis.set_visible(False)
    ax[0].yaxis.set_visible(False)
    ax[1].xaxis.set_visible(False)
    ax[1].yaxis.set_visible(False)
    ax[2].xaxis.set_visible(False)
    ax[2].yaxis.set_visible(False)
    
#     plt.savefig('Figures/fig_{}.jpg'.format(i), dpi=100)