# CNN for Classification: Are SST images corrupted by clouds?

The purpose of this notebook is to classify SST (Sea Surface Temperature) images according to whether or not they are corrupted by clouds.

Several cloud-masking methods have been used. It is up to you to train one network per method and then compare the results.

## Import libraries

In [None]:
import numpy as np
import sys,glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import random
from tensorflow import keras
from sklearn.metrics import confusion_matrix
import itertools
from math import ceil  # only ceil is used (subplot grid sizes)

## Load the files

In [None]:
# path to be modified
path = "./"

# load npy file
sst_fill_10000 = np.load(path+'SST_clouds_fill_10000.npy')
sst_fill_nan = np.load(path+'SST_clouds_fill_nan.npy')
sst_fill_mean = np.load(path+'SST_clouds_fill_mean.npy')
labels = np.load(path+'labels.npy')
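Before looking at the images, it can help to check what was actually loaded. The sketch below is not part of the original exercise. It assumes that masked pixels are marked either by NaN or by the value 10000 (a guess based on the file names). For one array, the hypothetical `describe` helper reports its shape, how many pixels are NaN or equal to the fill value, and the range of the remaining values. The hypothetical `class_balance` helper counts the samples per class.

```python
import numpy as np

FILL_VALUE = 10000  # assumed cloud fill value, inferred from the file name

def describe(name, arr, fill_value=FILL_VALUE):
    """Print and return shape, masked-pixel counts and the range of the other pixels."""
    is_nan = np.isnan(arr)
    is_fill = arr == fill_value
    valid = arr[~is_nan & ~is_fill]  # pixels that are neither NaN nor the fill value
    info = {
        "shape": arr.shape,
        "n_nan": int(is_nan.sum()),
        "n_fill": int(is_fill.sum()),
        "valid_min": float(valid.min()) if valid.size else None,
        "valid_max": float(valid.max()) if valid.size else None,
    }
    print(name, info)
    return info

def class_balance(labels):
    """Number of samples per class label."""
    classes, counts = np.unique(labels, return_counts=True)
    return {int(c): int(n) for c, n in zip(classes, counts)}

# Usage with the arrays loaded above:
# for name, arr in [("fill_10000", sst_fill_10000), ("fill_nan", sst_fill_nan),
#                   ("fill_mean", sst_fill_mean)]:
#     describe(name, arr)
# print(class_balance(labels))
```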

## Visualisation

The cell below plots a few samples of the three datasets loaded previously. Run it.

>**QUIZ** : 
>- How do you think these datasets were built?
>- What do classes 0 and 1 correspond to?

In [None]:
for i in range(10):
    # The range of the mean-filled image sets a common color scale for the three panels
    vmin = np.min(sst_fill_mean[i, :, :])
    vmax = np.max(sst_fill_mean[i, :, :])
    plt.figure(figsize=(12, 4))
    panels = [(sst_fill_10000, 'SST fill 10000'),
              (sst_fill_mean, 'SST fill mean'),
              (sst_fill_nan, 'SST fill nan')]
    for k, (data, title) in enumerate(panels):
        plt.subplot(1, 3, k + 1)
        plt.imshow(data[i, :, :], vmin=vmin, vmax=vmax, origin='lower')
        plt.title(title)
    plt.suptitle('Class : ' + str(labels[i]))
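You can check your answer to the quiz by comparing the variants numerically. The sketch below assumes that the NaN pixels in `sst_fill_nan` are the cloud-masked ones. The hypothetical helper `masks_agree` checks whether these pixels are the same ones that equal 10000 in `sst_fill_10000`. A second hypothetical helper, `cloud_fraction_per_class`, gives the average fraction of masked pixels in the images of each class.

```python
import numpy as np

def masks_agree(sst_fill_10000, sst_fill_nan, fill_value=10000):
    """True if the pixels equal to fill_value are exactly the NaN pixels."""
    return bool(np.array_equal(sst_fill_10000 == fill_value, np.isnan(sst_fill_nan)))

def cloud_fraction_per_class(sst_fill_nan, labels):
    """Mean fraction of NaN (assumed cloudy) pixels per image, grouped by class."""
    cloudy = np.isnan(sst_fill_nan).reshape(len(labels), -1)
    frac = cloudy.mean(axis=1)  # one value per image
    return {int(c): float(frac[labels == c].mean()) for c in np.unique(labels)}

# Usage with the notebook arrays:
# print(masks_agree(sst_fill_10000, sst_fill_nan))
# print(cloud_fraction_per_class(sst_fill_nan, labels))
```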

## Data preparation

You are going to build a CNN that sorts SST images into classes 0 and 1. The first step is to prepare the data. Each dataset comes as a single block, so it is useful to shuffle the samples before splitting them into the training and test sets. Below, we suggest working with the "10000" dataset. You can experiment with the others later.

>**WORK** : 
>1. Determine what the input (X) and the output (Y) are.
>2. Shuffle the data and split it into the relevant sets (`shuffle` and `train_test_split` from scikit-learn; check the documentation).
>3. Reshape the data (`.reshape`) so that it can be passed to the network.
>4. Normalize the data by the maximum value (replace `XXXX`; please don't fall into the trap).

In [None]:
# Shuffle
X,Y=shuffle(sst_fill_10000,labels)
 
# Split in train and test data
X_train, X_test, y_train, y_test = train_test_split(X,Y, test_size=0.33, random_state=42)

# Reshape
X_train=X_train.reshape(X_train.shape[0],X_train.shape[1],X_train.shape[2],1)
X_test=X_test.reshape(X_test.shape[0],X_test.shape[1],X_test.shape[2],1)

# Normalize
normalization_factor = 1/np.max(X_train)
X_train = XXXX
X_test = XXXX
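The following self-contained sketch shows one way to carry out the four steps. It is wrapped in a hypothetical `prepare` helper. Two choices are assumptions and are not required by the exercise:

- `stratify=y` keeps the class proportions identical in both sets.
- The channel axis is added with `np.newaxis`, which is equivalent to the `reshape` calls above.

Note that `train_test_split` already shuffles by default (`shuffle=True`), so the separate `shuffle` call is harmless but not strictly needed. The normalization factor is computed on the training set only and reused unchanged on the test set. This way, both sets are scaled identically and the test data plays no part in preprocessing.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def prepare(X, y, test_size=0.33, seed=42):
    """Split, add a channel axis, and scale by a factor computed on the training set only."""
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=seed, shuffle=True, stratify=y)
    # Keras Conv2D layers expect (samples, height, width, channels)
    X_train = X_train[..., np.newaxis]
    X_test = X_test[..., np.newaxis]
    # Factor taken from the training data, then applied unchanged to the test data
    factor = 1 / np.max(X_train)
    return X_train * factor, X_test * factor, y_train, y_test, factor

# Usage in the notebook:
# X_train, X_test, y_train, y_test, normalization_factor = prepare(sst_fill_10000, labels)
```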

## Model

Time to create the model! What kind of loss function should you use?

>**WORK** :
>1. Build, compile, and fit a CNN model 
>2. Visualize the evolution of the loss function and adjust some parameters if necessary

### 1. Build a CNN model
#### Model construction with the Sequential API
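The construction cell is left for you to write. As a starting point, here is a minimal sketch of a small CNN for binary classification, written as a hypothetical `build_cnn` function. The numbers of layers, filters and units are illustrative, not tuned values. The first convolution is named `layer_conv_1` because the weight-visualization cell at the end of the notebook looks up a layer with that name. The output is a single sigmoid unit giving the probability of class 1.

```python
from tensorflow import keras

def build_cnn(input_shape):
    """Small binary-classification CNN; layer sizes are illustrative, not tuned."""
    return keras.Sequential([
        keras.Input(shape=input_shape),
        # Named so that the weight-visualization cell can retrieve it
        keras.layers.Conv2D(16, (3, 3), activation="relu", padding="same",
                            name="layer_conv_1"),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(32, (3, 3), activation="relu", padding="same",
                            name="layer_conv_2"),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(32, activation="relu"),
        # A single sigmoid unit: probability that the image belongs to class 1
        keras.layers.Dense(1, activation="sigmoid"),
    ])

# Usage in the notebook (input shape taken from the prepared data):
# model = build_cnn(X_train.shape[1:])
```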

In [None]:
model.summary()

### 2. Compile the model

In [None]:
# If the activation function of the output layer is sigmoid
model.compile(optimizer='adam',
              loss='XXXX',
              metrics=['accuracy'])
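With a single sigmoid output $p_i = P(y_i = 1)$, the usual loss for two classes is binary cross-entropy, $L = -\frac{1}{N}\sum_{i} \left[y_i \log p_i + (1-y_i)\log(1-p_i)\right]$. The NumPy sketch below computes this quantity so that you can see what is being minimized. The clipping constant `eps` is an assumption added to avoid `log(0)`. If the network instead ends with two softmax units, integer labels 0/1 pair with a sparse categorical cross-entropy.

```python
import numpy as np

def binary_crossentropy(y_true, p, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities of class 1."""
    p = np.clip(p, eps, 1 - eps)  # keep probabilities away from 0 and 1 to avoid log(0)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

# A maximally uncertain prediction (p = 0.5) costs log(2) per sample
print(binary_crossentropy(np.array([1, 0]), np.array([0.5, 0.5])))
```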

### 3. Fit the model

In [None]:
history=model.fit(X_train,y_train,epochs=10,
                 batch_size=32,
                 validation_split=0.2,
                 )

### 4. Visualize the evolution of the loss function

In [None]:
plt.plot(history.history['loss'],'b-',label='Train loss')
plt.plot(history.history['val_loss'], 'r-',label='Validation loss')
plt.legend()
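If the validation loss starts rising while the training loss keeps decreasing, the network is overfitting. One common adjustment is Keras's `EarlyStopping` callback, sketched below as an option rather than a setting prescribed by the notebook. It stops training once `val_loss` has not improved for `patience` epochs. With `restore_best_weights=True`, it then restores the weights from the best epoch. The values of `patience` and `epochs` are illustrative.

```python
from tensorflow import keras

# Stop when the validation loss has not improved for 3 consecutive epochs,
# then roll back to the weights of the best epoch
early_stop = keras.callbacks.EarlyStopping(monitor="val_loss",
                                           patience=3,
                                           restore_best_weights=True)

# Usage in the notebook: allow more epochs and let the callback decide when to stop
# history = model.fit(X_train, y_train, epochs=50, batch_size=32,
#                     validation_split=0.2, callbacks=[early_stop])
```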

## Prediction from the model 

>**WORK** : 
>1. Evaluate the performance of your model with model.evaluate
>2. Make a model prediction on test data
>3. Visualize the results of the model prediction
>4. Find an example for which the prediction is false

### 1. Model evaluation 

In [None]:
model.evaluate(X_test,y_test)

### 2. Make a model prediction on test data

In [None]:
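y_sigmoid = model.predict(X_test)  # shape (n_samples, 1): probability of class 1
# y_pred = np.argmax(y_sigmoid, axis=-1)  # use instead if the output layer is a 2-unit softmax
y_pred = np.rint(y_sigmoid.squeeze()).astype(int)  # round to 0/1 and flatten to (n_samples,)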

In [None]:
y_pred.shape, y_test.shape
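The `squeeze()` in the prediction cell matters. `model.predict` returns an array of shape `(n_samples, 1)`, while `y_test` has shape `(n_samples,)`. Comparing the two with `==` or `!=` broadcasts to an `(n_samples, n_samples)` matrix instead of giving one value per sample. The sketch below illustrates this on three made-up predictions. It also shows an explicit-threshold alternative, the hypothetical helper `to_labels`. Note that `np.rint` rounds halves to the nearest even integer, so a probability of exactly 0.5 becomes class 0.

```python
import numpy as np

def to_labels(y_sigmoid, threshold=0.5):
    """Turn (n, 1) sigmoid outputs into a flat (n,) array of 0/1 integer labels."""
    return (np.asarray(y_sigmoid).reshape(-1) > threshold).astype(int)

y_true = np.array([0, 1, 1])
probs = np.array([[0.2], [0.5], [0.9]])  # shape (3, 1), like the output of model.predict

print(np.rint(0.5))               # 0.0: round half to even
print((probs == y_true).shape)    # (3, 3): broadcasting, not an element-wise comparison
print(to_labels(probs))           # [0 0 1]
```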

### 3. Visualize the results of the model prediction

In [None]:
nb_results = 12
n_col = 4
n_row = ceil(nb_results / n_col)
random_results = random.sample(range(y_pred.shape[0]), k=nb_results)
plt.figure(figsize=(16, 9))
for i, result in enumerate(random_results):
    # Undo the normalization to get back to the original SST values
    img = X_test[result, :, :, 0] / normalization_factor
    # np.unique returns sorted values; the largest is assumed to be the 10000 cloud fill,
    # so the second largest is used as the upper bound of the color scale
    values = np.unique(img)
    min_value, max_value = values[0], values[-2]
    plt.subplot(n_row, n_col, i + 1)
    plt.imshow(img, vmin=min_value, vmax=max_value)
    plt.title('pred : %i , true : %i' % (y_pred[result], y_test[result]))
    plt.colorbar()
plt.subplots_adjust(right=1, top=1)
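Taking the second-largest unique value as `vmax` assumes that every image contains at least one pixel equal to the 10000 fill. For a cloud-free image, it silently discards the true maximum. A more explicit alternative, sketched below as a hypothetical `display_range` helper, excludes the fill value directly. `np.isclose` is used rather than `==` because undoing the normalization may not restore 10000 bit for bit.

```python
import numpy as np

def display_range(img, fill_value=10000):
    """Return (vmin, vmax) over pixels that are not the cloud fill, or None if all are filled."""
    valid = img[~np.isclose(img, fill_value)]  # drop fill pixels, tolerant to rounding
    if valid.size == 0:
        return None
    return float(valid.min()), float(valid.max())

# Usage inside the plotting loop:
# vmin, vmax = display_range(X_test[result, :, :, 0] / normalization_factor)
```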

### 4. Find an example for which the prediction is false

In [None]:
nb_false = np.where(y_pred != y_test)[0].shape[0]
nb_test = y_pred.shape[0]
print('There are %i false prediction(s) out of %i'%(nb_false,nb_test))

In [None]:
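false_idx = np.where(y_pred != y_test)[0]
if false_idx.size > 0:
    itmp = false_idx[0]
    # Fixed color range so that the 10000 cloud fill does not dominate the scale
    plt.imshow(X_test[itmp, :, :, 0] / normalization_factor, vmin=20, vmax=30)
    plt.title('pred : %i , true : %i' % (y_pred[itmp], y_test[itmp]))
    plt.colorbar()
else:
    print('No false prediction on the test set')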
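Not all errors are of the same kind. The sketch below uses a hypothetical `error_indices` helper to separate false positives (predicted 1, true 0) from false negatives (predicted 0, true 1). You can then plot examples of each and see whether one type dominates.

```python
import numpy as np

def error_indices(y_true, y_pred):
    """Indices of false positives (pred 1, true 0) and false negatives (pred 0, true 1)."""
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)
    false_pos = np.flatnonzero((y_pred == 1) & (y_true == 0))
    false_neg = np.flatnonzero((y_pred == 0) & (y_true == 1))
    return false_pos, false_neg

# Usage in the notebook:
# fp, fn = error_indices(y_test, y_pred)
# print('%i false positives, %i false negatives' % (fp.size, fn.size))
```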

### Bonus : confusion matrix

In [None]:
cm = confusion_matrix(y_test, y_pred, normalize=None)

accuracy = np.trace(cm) / float(np.sum(cm))
misclass = 1 - accuracy

# Divide each row by the number of true samples of that class:
# the diagonal then gives the fraction of each class that is correctly predicted
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

plt.figure(figsize=(10, 10))
# Colors and printed values both use the normalized matrix
plt.imshow(cm_norm, interpolation='nearest', cmap=plt.get_cmap('Blues'), vmin=0, vmax=1)
plt.colorbar()

digit_format = '{:0.2f}'
thresh = cm_norm.max() / 1.5
for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
    plt.text(j, i, digit_format.format(cm_norm[i, j]),
             horizontalalignment="center",
             color="white" if cm_norm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
plt.xticks(ticks=[0, 1])
plt.yticks(ticks=[0, 1])
plt.tight_layout()
plt.show()
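To see what the plotted numbers mean, the sketch below rebuilds a 2×2 confusion matrix with NumPy and derives precision and recall for class 1. It uses the same layout as scikit-learn's `confusion_matrix`, with rows for true classes and columns for predicted classes. The row-normalized diagonal of the matrix is exactly the per-class recall. scikit-learn can also return the row-normalized matrix directly with `confusion_matrix(y_test, y_pred, normalize='true')`. The helper names `confusion_2x2` and `precision_recall` are hypothetical.

```python
import numpy as np

def confusion_2x2(y_true, y_pred):
    """Rows = true class, columns = predicted class (same layout as sklearn)."""
    cm = np.zeros((2, 2), dtype=int)
    # Add 1 at (true, predicted) for every sample; np.add.at handles repeated indices
    np.add.at(cm, (np.asarray(y_true, dtype=int), np.asarray(y_pred, dtype=int)), 1)
    return cm

def precision_recall(cm, positive=1):
    """Precision and recall of the given class (NaN if a denominator is zero)."""
    tp = cm[positive, positive]
    precision = tp / cm[:, positive].sum()  # among samples predicted as positive
    recall = tp / cm[positive, :].sum()     # among samples that are truly positive
    return float(precision), float(recall)

cm_demo = confusion_2x2([0, 0, 1, 1, 1], [0, 1, 1, 1, 1])
print(cm_demo)                    # [[1 1]
                                  #  [0 3]]
print(precision_recall(cm_demo))  # (0.75, 1.0)
```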



### Bonus 2: Visualize the weights and bias of the model

In [None]:
kernels, biases = model.get_layer("layer_conv_1").get_weights()  # kernels: (height, width, in_channels, n_filters)
nb_filter = kernels.shape[-1]
nb_col = 4
nb_row = ceil(nb_filter / nb_col)
# Shared color scale so that the single colorbar applies to every filter
vmin, vmax = kernels.min(), kernels.max()
fig, axes = plt.subplots(nrows=nb_row, ncols=nb_col, sharex=True, sharey=True,
                         figsize=(16, 9), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i >= nb_filter:
        ax.axis('off')  # hide unused panels when nb_filter is not a multiple of nb_col
        continue
    im = ax.imshow(kernels[:, :, 0, i], vmin=vmin, vmax=vmax)
    ax.set_yticks(np.arange(kernels.shape[0]))  # rows = kernel height
    ax.set_xticks(np.arange(kernels.shape[1]))  # columns = kernel width
    ax.set_title('Bias : %f' % biases[i])
fig.subplots_adjust(right=0.8, top=0.6)

plt.colorbar(im, ax=axes.ravel().tolist())