# Import libraries

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras import Sequential,layers
from tensorflow.keras.layers import Conv2D,BatchNormalization, MaxPool2D, Activation, Flatten, Dense, GlobalAveragePooling2D, GlobalMaxPool2D, AveragePooling2D, Lambda, Reshape, UpSampling2D, Conv2DTranspose 
from tensorflow.keras import regularizers
from tensorflow.keras.models import Model
import cv2
import os
import numpy as np
from tensorflow.keras.layers import Conv3D,BatchNormalization, MaxPool3D, Activation, Flatten, Dense, GlobalAveragePooling3D, GlobalMaxPool3D, AveragePooling3D, Lambda

import random
import scipy.io as sio
import numpy as np
import datetime
import sklearn
from sklearn.decomposition import PCA
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
import matplotlib

# Dataset Preprocess

In [None]:
def applyPCA(X, numComponents=75):
    newX = np.reshape(X, (-1, X.shape[2]))
    pca = PCA(n_components=numComponents, whiten=True)
    newX = pca.fit_transform(newX)
    newX = np.reshape(newX, (X.shape[0],X.shape[1], numComponents))
    return newX, pca

def padWithZeros(X, margin=2):
    newX = np.zeros((X.shape[0] + 2 * margin, X.shape[1] + 2* margin, X.shape[2]))
    x_offset = margin
    y_offset = margin
    newX[x_offset:X.shape[0] + x_offset, y_offset:X.shape[1] + y_offset, :] = X
    return newX

def createImageCubes(X, y, windowSize, removeZeroLabels = True):
    margin = int((windowSize - 1) / 2)
    zeroPaddedX = padWithZeros(X, margin=margin)  # X :(145, 145, 30) --> (195, 195, 30) with window =25
    # split patches
    patchesData = np.zeros((X.shape[0] * X.shape[1], windowSize, windowSize, X.shape[2]))  # (21025, 25, 25, 30)   
    patchesLabels = np.zeros((X.shape[0] * X.shape[1]))  # (21025,)
    patchIndex = 0
    
    for r in range(margin, zeroPaddedX.shape[0] - margin):
        for c in range(margin, zeroPaddedX.shape[1] - margin):
            patch = zeroPaddedX[r - margin:r + margin + 1, c - margin:c + margin + 1]  
            patchesData[patchIndex, :, :, :] = patch
            patchesLabels[patchIndex] = y[r-margin, c-margin]            
            patchIndex = patchIndex + 1
  
    patchesData = np.expand_dims(patchesData, axis=-1)
    return patchesData,patchesLabels

def patches_class(X,Y,n):
    n_classes = n
    patches_list = []
    labeles_list = []
    for i in range(1,n_classes+1):   # not considering class 0
        patchesData_Ith_Label = X[Y==i,:,:,:,:]
        Ith_Label = Y[Y==i]
        patches_list.append(patchesData_Ith_Label)
        labeles_list.append(Ith_Label)
        
    return patches_list,labeles_list

In [None]:
windowSize = 11
im_height, im_width, im_depth, im_channel = windowSize, windowSize, 30, 1 

In [None]:
X = sio.loadmat('/content/drive/MyDrive/Indian_pines_corrected.mat')['indian_pines_corrected']
y = sio.loadmat('/content/drive/MyDrive/Indian_pines_gt.mat')['indian_pines_gt']
print(X.shape, y.shape)

X,pca = applyPCA(X,numComponents=im_depth)
print(X.shape, y.shape)

X, y = createImageCubes(X, y, windowSize)
print(X.shape, y.shape)

patches_class_ip,label_ip = patches_class(X,y,16) # class_wise list of patches #(16,) for class 0: (2009, 9, 9, 20, 1)
patches_class_ip[0].shape[0]

#### Dataset split

In [None]:
train_class_indices = [1,2,4,5,7,9,10,11,13,14]    # 10 classes  
train_class_labels = [2,3,5,6,8,10,11,12,14,15]

test_class_indices = [0,3,6,8,12,15]               # 6 classes
test_class_labels = [1,4,7,9,13,16]

### Data Loading

In [None]:
# replace=False - no repeat
def new_episode(patches_list,NS,NQ,CS,CQ,class_labels) :  # NS 5,NQ 15,CS 3,CQ 6
    selected_classes = list(np.random.choice(class_labels,CQ,replace=False))  # Randomly choice 6 Query Classes
    support_classes = list(np.random.choice(selected_classes,CS,replace=False))  # Randomly choice 3 Support Classes from Q
    
    tquery_patches,tsupport_patches = [],[]
    query_labels,support_labels = [],[]
    
    for x in support_classes :      #3
        sran_indices = np.random.choice(patches_list[x-1].shape[0],NS,replace=False) # K=5 img nos from 20 per class for Support
        support_patches = patches_list[x-1][sran_indices,:,:,:,:]  # for x class those 5 nos
        tsupport_patches.extend(support_patches)               # 3 class * 5 patch per class = 15 patches
        for i in range(NS) :
            support_labels.append(x)                           # 3 class *5 nos = 15 nos
        
    for x in selected_classes :     #6
        qran_indices = np.random.choice(patches_list[x-1].shape[0],NQ,replace=False) # NQ=15 img nos from 20 per class for Query
        query_patches = patches_list[x-1][qran_indices,:,:,:,:]    # for x class those 15 nos
        tquery_patches.extend(query_patches)                   # 6 class * 15 patch per class = 90 patches
        for i in range(NQ) :
            query_labels.append(x)                             # 6 class * 15 nos = 90 nos
    
    temp1 = list(zip(tquery_patches, query_labels)) 
    random.shuffle(temp1)        # By Doing Shuffling, Support, Query Same class combination got mismatched - mitigated by support index
    tquery_patches, query_labels = zip(*temp1)
    
    tquery_patches = tf.convert_to_tensor(np.reshape(np.asarray(tquery_patches),(CQ*NQ,11,11,30,1)),dtype=tf.float32)
    tsupport_patches = tf.convert_to_tensor(np.reshape(np.asarray(tsupport_patches),(CS*NS,11,11,30,1)),dtype=tf.float32)
    return tquery_patches, tsupport_patches, query_labels, support_labels, support_classes    

# Define Model

In [None]:
class Channel_Attention_3D(tf.keras.layers.Layer) :
  def __init__(self,C,ratio) :
    super(Channel_Attention_3D,self).__init__()
    self.avg_pool = GlobalAveragePooling3D()
    self.max_pool = GlobalMaxPool3D()
    self.activation = Activation('sigmoid')
    self.fc1 = Dense(C/ratio, activation = 'relu')
    self.fc2 = Dense(C)
  def call(self,x) :
    avg_out1 = self.avg_pool(x)
    avg_out2 = self.fc1(avg_out1)
    avg_out3 = self.fc2(avg_out2)
    max_out1 = self.max_pool(x)
    max_out2 = self.fc1(max_out1)
    max_out3 = self.fc2(max_out2)
    add_out = tf.math.add(max_out3,avg_out3)
    channel_att = self.activation(add_out)
    return channel_att 

In [None]:
class Spatial_Attention_3D(tf.keras.layers.Layer) :
  def __init__(self) :
    super(Spatial_Attention_3D,self).__init__()
    self.conv3d = Conv3D(1,(7,7,7),padding='same',activation='sigmoid')
    self.avg_pool_chl = Lambda(lambda x:tf.keras.backend.mean(x,axis=4,keepdims=True))
    self.max_pool_chl = Lambda(lambda x:tf.keras.backend.max(x,axis=4,keepdims=True)) 
  
  def call(self,x) :
    avg_out1 = self.avg_pool_chl(x)
    max_out1 = self.max_pool_chl(x)
    concat_out = tf.concat([avg_out1,max_out1],axis=-1)
    spatial_att = self.conv3d(concat_out)
    return spatial_att 

In [None]:
class CBAM_3D(tf.keras.layers.Layer) :
  def __init__(self,C,ratio) :
    super(CBAM_3D,self).__init__()
    self.C = C
    self.ratio = ratio
    self.channel_attention = Channel_Attention_3D(self.C,self.ratio)
    self.spatial_attention = Spatial_Attention_3D()
  def call(self,y,H,W,D,C) :
    ch_out1 = self.channel_attention(y)
    ch_out2 = tf.expand_dims(ch_out1, axis=1)
    ch_out3 = tf.expand_dims(ch_out2, axis=2)
    ch_out4 = tf.expand_dims(ch_out3, axis=3)
    ch_out5 = tf.tile(ch_out4, multiples=[1,H,W,D,1])
    ch_out5 = tf.math.multiply(ch_out5,y)
    sp_out1 = self.spatial_attention(ch_out5)
    sp_out2 = tf.tile(sp_out1, multiples = [1,1,1,1,C])
    sp_out3 = tf.math.multiply(sp_out2,ch_out5)
    return sp_out3    

In [None]:
input_layer = layers.Input(shape = (im_height, im_width, im_depth, im_channel))
out1 = layers.Conv3D(filters=8, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(input_layer)
out2 = CBAM_3D(out1.shape[4],4)(out1,out1.shape[1],out1.shape[2],out1.shape[3],out1.shape[4])
out2 = layers.Conv3D(filters=8, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out1)
out2 = CBAM_3D(out2.shape[4],4)(out2,out2.shape[1],out2.shape[2],out2.shape[3],out2.shape[4])
out3 = layers.Conv3D(filters=8, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out2)
out4 = layers.Add()([out1, out3])  #Concatenate()
out5 = layers.MaxPool3D(pool_size=(2, 2, 4), strides=None, padding='same')(out4)

out6 = layers.Conv3D(filters=16, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out5)
out6 = CBAM_3D(out6.shape[4],4)(out6,out6.shape[1],out6.shape[2],out6.shape[3],out6.shape[4])
out7 = layers.Conv3D(filters=16, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out6)
out7 = CBAM_3D(out7.shape[4],4)(out7,out7.shape[1],out7.shape[2],out7.shape[3],out7.shape[4])
out8 = layers.Conv3D(filters=16, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out7)
out9 = layers.Add()([out6, out8])  #Concatenate()
out10 = layers.MaxPool3D(pool_size=(2, 2, 2), strides=None, padding='same')(out9)
out10 = CBAM_3D(out10.shape[4],4)(out10,out10.shape[1],out10.shape[2],out10.shape[3],out10.shape[4])
out11 = layers.Conv3D(filters=32, kernel_size=(3,3,3), padding='valid',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out10)
out12 = layers.Flatten()(out11)
FE_model = Model(inputs=input_layer,outputs=out12,name='3DResCNN')
FE_model.summary()

**VAE**

In [None]:
latent_dim = 8
batch_size = 45

In [None]:
def sampling(args):
  z_mean, z_log_sigma = args
  epsilon = tf.random.normal([latent_dim,], 0, 1, tf.float32)  #batch_size
  return z_mean + tf.math.exp(z_log_sigma) * epsilon

In [None]:
input_Feature = layers.Input(shape = (64,))
encoded_L1 = layers.Dense(32, activation='relu')(input_Feature)
encoded_L2 = layers.Dense(16, activation='relu')(encoded_L1)

z_mean = layers.Dense(latent_dim)(encoded_L2)
z_log_sigma = layers.Dense(latent_dim)(encoded_L2)
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_sigma])

decoded_h1 = layers.Dense(16, activation='relu')
decoded_h2 = layers.Dense(32, activation='relu')
decoded_mean = layers.Dense(64, activation='relu')

decoded_L1 =  decoded_h1(z)
decoded_L2 =  decoded_h2(decoded_L1)
decoded_X =  decoded_mean(decoded_L2)
VAE = Model(inputs=input_Feature,outputs=[decoded_X,z_mean,z_log_sigma],name='VAE')
VAE.summary()


In [None]:
def vae_loss(x, x_decoded_mean, z_mean, z_log_sigma):
  #Recon_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(x, x_decoded_mean))
  recon_loss = tf.reduce_mean(tf.math.square(tf.math.subtract(x,x_decoded_mean)))
  kl_loss = - 0.5 * tf.reduce_mean(1 + z_log_sigma - tf.math.square(z_mean) - tf.math.exp(z_log_sigma))
  return recon_loss + kl_loss

In [None]:
CQ = 6 #C2
CS = 3 #C1
N = 15
K = 5
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00001)   #1e-1

Outlier NN

In [None]:
input_Feature = layers.Input(shape = (3,))
encoded_L1 = layers.Dense(16, activation='relu')(input_Feature)
encoded_L2 = layers.Dense(8, activation='relu')(encoded_L1)
decoded = layers.Dense(2, activation='softmax')(encoded_L2)
outlier_nn = Model(inputs=input_Feature,outputs=decoded,name='ONN')
outlier_nn.summary()

# Protototypical Network

In [None]:
def calc_euclidian_dists(x, y):
  # x : (n,d)
  # y : (m,d)
    n = x.shape[0]
    m = y.shape[0]
    x = tf.tile(tf.expand_dims(x, 1), [1, m, 1])
    y = tf.tile(tf.expand_dims(y, 0), [n, 1, 1])
    return tf.reduce_mean(tf.math.pow(x - y, 2), 2)

In [None]:
emb_dim = 64
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optim1 = tf.keras.optimizers.Adam(0.0001) 
optim2 = tf.keras.optimizers.Adam(0.0001) 
optim3 = tf.keras.optimizers.Adam(0.0001) 
checkpoint_dir = '/content/drive/MyDrive/Hyperspectral OSR/CBAM_ResCNN_VAE_CLF_new'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optim1=optim1, optim2=optim2, optim3=optim3, FE_model = FE_model, VAE = VAE, outlier_nn=outlier_nn)

# Train

In [None]:
ntimes = 10
scce = tf.keras.losses.SparseCategoricalCrossentropy()

def proto_train(ep_class_images,ep_query_images,ep_class_labels,ep_query_labels,support_classes,K,CS,CQ,N):#5,3,6,15     
    outlier = 0
    sembed = FE_model(ep_class_images)                             # [15, 64]        
    qembed = FE_model(ep_query_images)                             # [90, 64]
    y_query = np.asarray(np.zeros((len(ep_query_images),CS)),dtype=np.float32)  # (90, 3) 
    y_true = np.zeros(len(ep_query_labels)) #for storing labels of classes, 0 for unseen; 1,2,3 for the three classes
    for i in range(len(ep_query_labels)) :
      if ep_query_labels[i] in support_classes :
            x = support_classes.index(ep_query_labels[i])
            y_query[i][x] = 1.                                      # [[0., 0., 1.], [0., 0., 0.], ... (90,3)
            y_true[i] = x+1
    y_support = np.asarray(np.zeros((len(ep_class_images),CS)),dtype=np.float32)  # (15, 3) 
    for i in range(len(ep_class_labels)) :
      if ep_class_labels[i] in support_classes :
            x = support_classes.index(ep_class_labels[i])
            y_support[i][x] = 1.                                      # [[0., 0., 1.], [0., 0., 0.], ... (90,3)
    z_prototypes = tf.reshape(sembed,[CS, K, sembed.shape[-1]])           # [3, 5, 64]
    z_prototypes = tf.math.reduce_mean(z_prototypes, axis=1)        # [3, 64]   
    # Vautoencoder Loss on Query + Support
    rec_kl_loss = 0
    clf_loss = 0
    sqembedK = np.zeros((CS*(N+K),emb_dim)) #known query then support samples
    y_sqK = np.asarray(np.zeros((CS*(N+K),CS)),dtype=np.float32) # QK + SK
    j = 0
    for i in range(len(y_query)):     # 90
      if np.sum(y_query[i,:])==1:     # k query 
        y_sqK[j,:] = y_query[i,:]
        sqembedK[j,:] = qembed[i,:]
        j = j + 1
    for i in range(len(sembed)) :
      sqembedK[j,:] = sembed[i,:]
      y_sqK[j,:] = y_support[i,:]
      j = j + 1

    with tf.GradientTape() as VAEtape:
      for n in range(ntimes):
          gen_sqembedK, z_mean, z_log_sigma = VAE(sqembedK)
          rec_kl_loss = rec_kl_loss + vae_loss(sqembedK,gen_sqembedK, z_mean, z_log_sigma)
          dists_genK = calc_euclidian_dists(gen_sqembedK, z_prototypes) 
          log_p_y_genK = tf.nn.log_softmax(-dists_genK,axis=-1)
          clf_loss = clf_loss - tf.reduce_mean((tf.reduce_sum(tf.multiply(y_sqK, log_p_y_genK), axis=-1))) 
      total_loss = clf_loss+rec_kl_loss
    grads = VAEtape.gradient(total_loss, VAE.trainable_variables)
    optim1.apply_gradients(zip(grads, VAE.trainable_variables))
    
    # Query set Augmentation((S + QK)(Original + Gen) + QU) [CEC loss for FE]
    sqembed_gen_K, mean, sigma = VAE(sqembedK)
    y_sq_Aug = y_sqK            # (Q+S)K + (Q+S)K(Gen) + QU
    sqembed_Aug = sqembedK              # query + support knowns
    sqembed_Aug = tf.concat((sqembed_Aug,sqembed_gen_K),axis=0)  # stacking known [k class]>AE o/p  
    y_sq_Aug = tf.concat((y_sq_Aug,y_sqK),axis=0)
    for i in range(len(y_query)):     # 90
      if np.sum(y_query[i,:])==0:     # u query 
        y_sq_Aug = tf.concat((y_sq_Aug,tf.expand_dims(y_query[i,:],axis=0)),axis=0)
        sqembed_Aug = tf.concat((sqembed_Aug,tf.expand_dims(qembed[i,:],axis=0)),axis=0)
    dists_Aug = calc_euclidian_dists(sqembed_Aug, z_prototypes) 
    log_p_y_Aug = tf.nn.log_softmax(-dists_Aug,axis=-1)
    cec_loss = -tf.reduce_mean((tf.reduce_sum(tf.multiply(y_sq_Aug, log_p_y_Aug), axis=-1))) 
    
    #outlier detection [outlier network update]
    with tf.GradientTape() as outlier_tape:
      outlier_pred = outlier_nn(dists_Aug) #0 for unseen, 1 for seen
      y_outlier = np.zeros(len(y_sq_Aug)) #labels for outliers, 0 - unseen, 1 - seen
      for i in range(len(y_sq_Aug)) :
        if i < 2*(CS*(K+N)):
          y_outlier[i] = 1
      outlier_loss = 10*scce(y_outlier,outlier_pred)
    grads = outlier_tape.gradient(outlier_loss, outlier_nn.trainable_variables)
    optim2.apply_gradients(zip(grads, outlier_nn.trainable_variables))
    
    # outlier loss calculation for FE update
    outlier_pred = outlier_nn(dists_Aug) #0 for unseen, 1 for seen
    y_outlier = np.zeros(len(y_sq_Aug)) #labels for outliers, 0 - unseen, 1 - seen
    for i in range(len(y_sq_Aug)) :
      if i < 2*(CS*(K+N)):
        y_outlier[i] = 1
    outlier_loss = 10*scce(y_outlier,outlier_pred)

    #accuracy calculation
    correct_pred = 0
    dists = calc_euclidian_dists(qembed, z_prototypes)               # [90, 3]  
    outlier_prob = outlier_nn(dists)
    outlier_index = tf.argmax(outlier_prob,axis=-1)
    predictions = tf.nn.softmax(-dists, axis=-1)
    pred = predictions
    pred2 = predictions
    pred_index = tf.argmax(predictions,axis=-1)
    for i in range(len(ep_query_labels)) :
      if ep_query_labels[i] in support_classes :
        if outlier_index[i] == 1 :
          x = support_classes.index(ep_query_labels[i])
          if x == pred_index[i] :
            correct_pred += 1  
      else :
          if outlier_index[i] == 0 :
            outlier = outlier + 1  
    accuracy = correct_pred/(CS*N)     # scalar      
    outlier_det_acc = outlier/((CQ-CS)*N)

    #open oa
    y_pred = np.zeros((len(ep_query_labels))) 
    for i in range(len(ep_query_labels)) :
        if outlier_index[i] == 1 :
          y_pred[i] = pred_index[i]+1
        else :
          y_pred[i] = 0
    cm = confusion_matrix(y_true,y_pred)
    FP = cm.sum(axis=0) - np.diag(cm)  
    FN = cm.sum(axis=1) - np.diag(cm)
    TP = np.diag(cm)
    TN = cm.sum() - (FP + FN + TP)
    open_oa = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))
      
    loss = cec_loss + outlier_loss
    return loss, accuracy, outlier_det_acc, open_oa    # scalar, scalar

# Metrics to gather
train_loss = tf.metrics.Mean(name='train_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')
train_openoa = tf.metrics.Mean(name='train_openoa')
train_outlier_acc = tf.metrics.Mean(name='train_outlier_acc')

def train_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,K,CS,CQ,N):
    # Forward & update gradients
    with tf.GradientTape() as tape:
        loss, accuracy, outlier_det_acc, openoa = proto_train(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,K,CS,CQ,N)
    gradients = tape.gradient(loss, FE_model.trainable_variables)
    optim3.apply_gradients(zip(gradients, FE_model.trainable_variables))
    train_loss(loss)
    train_acc(accuracy)
    train_openoa(openoa)
    train_outlier_acc(outlier_det_acc)
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)

for epoch in range(5001): # 80 train + 80 tune + 100 train + 160 tune + 40 train
    train_loss.reset_states()  
    train_acc.reset_states()
    train_openoa.reset_states()
    train_outlier_acc.reset_states()
    for epi in range(10): 
        tquery_patches, tsupport_patches, query_labels, support_labels, support_classes = new_episode(patches_class_ip,K,N,CS,CQ,train_class_labels)   
        train_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,K,CS,CQ,N)
        
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', train_acc.result(), step=epoch)
        tf.summary.scalar('openoa',train_openoa.result(), step=epoch)
        tf.summary.scalar('outlier_det_acc',train_outlier_acc.result(), step=epoch)

    template = 'Epoch {}, Train Loss: {:.2f}, Train Accuracy: {:.2f}, Train Open OA: {:.2f}, Train Outlier Det. Acc: {:.2f}'
    print(template.format(epoch+1,train_loss.result(),train_acc.result()*100,train_openoa.result()*100,train_outlier_acc.result()*100))

    if epoch % 500 == 0 and epoch != 0 :
      checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs/gradient_tape

# Fine-tuning





In [None]:
checkpoint_dir_tune = '/content/drive/MyDrive/Hyperspectral OSR/New_IP_ckpts'
checkpoint_prefix_tune = os.path.join(checkpoint_dir_tune, "ckpt")
checkpoint_tune = tf.train.Checkpoint(optimizer=optimizer,FE_model = FE_model, VAE = VAE, outlier_nn=outlier_nn)

In [None]:
train_class_indices = [1,2,4,5,7,9,10,11,13,14]
test_class_indices = [0,3,6,8,12,15]
train_patches_class = [patches_class_ip[i] for i in train_class_indices]        #(10)
test_patches_class = [patches_class_ip[i] for i in test_class_indices]        #(6) 
train_class_labels = [2,3,5,6,8,10,11,12,14,15]   
test_class_labels = [1,4,7,9,13,16]     #[11...16]
test_support_labels = [16,4,13]
ft_labels = [2,3,4,5,6,8,10,11,12,13,14,15,16]

In [None]:
tune_set_5 = [[] for i in range(16)]
for j in range(1,17) :
  if j in train_class_labels :
    tune_set_5[j-1] = patches_class_ip[j-1] 
  elif j in test_support_labels :
    tune_set_5[j-1] = patches_class_ip[j-1][:5,:,:,:,:] # for each class first 5 samples taken

In [None]:
def tune_episode(patches_list,NS,NQ,CS,CQ,class_labels) :  # NS 5,NQ 15,CS 3,CQ 6
    selected_classes = list(np.random.choice(class_labels,CQ,replace=False))  # Randomly choice 6 Query Classes
    support_classes = list(np.random.choice(selected_classes,CS,replace=False))  # Randomly choice 3 Support Classes from Q
    
    tquery_patches,tsupport_patches = [],[]
    query_labels,support_labels = [],[]
    
    for x in support_classes :      #3
        sran_indices = np.random.choice(patches_list[x-1].shape[0],NS,replace=False) # K=5 img nos from 20 per class for Support
        support_patches = patches_list[x-1][sran_indices,:,:,:,:]  # for x class those 5 nos
        tsupport_patches.extend(support_patches)               # 3 class * 5 patch per class = 15 patches
        for i in range(NS) :
            support_labels.append(x)                           # 3 class *5 nos = 15 nos
        
    for x in selected_classes :     #6
        qran_indices = np.random.choice(patches_list[x-1].shape[0],NQ,replace=False) # NQ=15 img nos from 20 per class for Query
        query_patches = patches_list[x-1][qran_indices,:,:,:,:]    # for x class those 15 nos
        tquery_patches.extend(query_patches)                   # 6 class * 15 patch per class = 90 patches
        for i in range(NQ) :
            query_labels.append(x)                             # 6 class * 15 nos = 90 nos
    
    temp1 = list(zip(tquery_patches, query_labels)) 
    random.shuffle(temp1)        # By Doing Shuffling, Support, Query Same class combination got mismatched - mitigated by support index
    tquery_patches, query_labels = zip(*temp1)
    
    tquery_patches = tf.convert_to_tensor(np.reshape(np.asarray(tquery_patches),(CQ*NQ,11,11,30,1)),dtype=tf.float32)
    tsupport_patches = tf.convert_to_tensor(np.reshape(np.asarray(tsupport_patches),(CS*NS,11,11,30,1)),dtype=tf.float32)
    return tquery_patches, tsupport_patches, query_labels, support_labels, support_classes    

In [None]:
tquery_patches, tsupport_patches, query_labels, support_labels, support_classes = tune_episode(tune_set_5,1,4,3,6,ft_labels)

In [None]:
emb_dim = 64
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
ntimes = 10
tK = 1
tN = 4
optim1 = tf.keras.optimizers.Adam(0.00001) 
optim2 = tf.keras.optimizers.Adam(0.00001) 
optim3 = tf.keras.optimizers.Adam(0.00001)
scce = tf.keras.losses.SparseCategoricalCrossentropy()

def proto_tune(ep_class_images,ep_query_images,ep_class_labels,ep_query_labels,support_classes,tK,CS,CQ,tN):#5,3,6,15     
    outlier = 0
    sembed = FE_model(ep_class_images)                             # [15, 64]        
    qembed = FE_model(ep_query_images)                             # [90, 64]
    y_query = np.asarray(np.zeros((len(ep_query_images),CS)),dtype=np.float32)  # (90, 3) 
    y_true = np.zeros(len(ep_query_labels)) #for storing labels of classes, 0 for unseen; 1,2,3 for the three classes
    for i in range(len(ep_query_labels)) :
      if ep_query_labels[i] in support_classes :
            x = support_classes.index(ep_query_labels[i])
            y_query[i][x] = 1.                                      # [[0., 0., 1.], [0., 0., 0.], ... (90,3)
            y_true[i] = x+1
    y_support = np.asarray(np.zeros((len(ep_class_images),CS)),dtype=np.float32)  # (15, 3) 
    for i in range(len(ep_class_labels)) :
      if ep_class_labels[i] in support_classes :
            x = support_classes.index(ep_class_labels[i])
            y_support[i][x] = 1.                                      # [[0., 0., 1.], [0., 0., 0.], ... (90,3)
    z_prototypes = tf.reshape(sembed,[CS, tK, sembed.shape[-1]])           # [3, 5, 64]
    z_prototypes = tf.math.reduce_mean(z_prototypes, axis=1)        # [3, 64]   
    # Vautoencoder Loss on Query + Support
    rec_kl_loss = 0
    clf_loss = 0
    sqembedK = np.zeros((CS*(tN+tK),emb_dim)) #known query then support samples
    y_sqK = np.asarray(np.zeros((CS*(tN+tK),CS)),dtype=np.float32) # QK + SK
    j = 0
    for i in range(len(y_query)):     # 90
      if np.sum(y_query[i,:])==1:     # k query 
        y_sqK[j,:] = y_query[i,:]
        sqembedK[j,:] = qembed[i,:]
        j = j + 1
    for i in range(len(sembed)) :
      sqembedK[j,:] = sembed[i,:]
      y_sqK[j,:] = y_support[i,:]
      j = j + 1

    with tf.GradientTape() as VAEtape:
      for n in range(ntimes):
          gen_sqembedK, z_mean, z_log_sigma = VAE(sqembedK)
          rec_kl_loss = rec_kl_loss + vae_loss(sqembedK,gen_sqembedK, z_mean, z_log_sigma)
          dists_genK = calc_euclidian_dists(gen_sqembedK, z_prototypes) 
          log_p_y_genK = tf.nn.log_softmax(-dists_genK,axis=-1)
          clf_loss = clf_loss - tf.reduce_mean((tf.reduce_sum(tf.multiply(y_sqK, log_p_y_genK), axis=-1))) 
      total_loss = clf_loss+rec_kl_loss
    grads = VAEtape.gradient(total_loss, VAE.trainable_variables)
    optim1.apply_gradients(zip(grads, VAE.trainable_variables))
    
    # Query set Augmentation((S + QK)(Original + Gen)) [CEC loss for FE]
    sqembed_gen_K, mean, sigma = VAE(sqembedK)
    y_sq_Aug = y_sqK            # (Q+S)K + (Q+S)K(Gen) + QU
    sqembed_Aug = sqembedK              # query + support knowns
    sqembed_Aug = tf.concat((sqembed_Aug,sqembed_gen_K),axis=0)  # stacking known [k class]>AE o/p  
    y_sq_Aug = tf.concat((y_sq_Aug,y_sqK),axis=0)
    for i in range(len(y_query)):     # 90
      if np.sum(y_query[i,:])==0:     # u query 
        y_sq_Aug = tf.concat((y_sq_Aug,tf.expand_dims(y_query[i,:],axis=0)),axis=0)
        sqembed_Aug = tf.concat((sqembed_Aug,tf.expand_dims(qembed[i,:],axis=0)),axis=0)
    dists_Aug = calc_euclidian_dists(sqembed_Aug, z_prototypes) 
    log_p_y_Aug = tf.nn.log_softmax(-dists_Aug,axis=-1)
    cec_loss = -tf.reduce_mean((tf.reduce_sum(tf.multiply(y_sq_Aug, log_p_y_Aug), axis=-1))) 
    #print(len(y_sq_Aug))
    
    #outlier detection [outlier network update]
    with tf.GradientTape() as outlier_tape:
      outlier_pred = outlier_nn(dists_Aug) #0 for unseen, 1 for seen
      y_outlier = np.zeros(len(y_sq_Aug)) #labels for outliers, 0 - unseen, 1 - seen
      for i in range(len(y_sq_Aug)) :
        if i < 2*(CS*(tK+tN)):
          y_outlier[i] = 1
      outlier_loss = 10*scce(y_outlier,outlier_pred)
    grads = outlier_tape.gradient(outlier_loss, outlier_nn.trainable_variables)
    optim2.apply_gradients(zip(grads, outlier_nn.trainable_variables))
    
    # outlier loss calculation for FE update
    outlier_pred = outlier_nn(dists_Aug) #0 for unseen, 1 for seen
    y_outlier = np.zeros(len(y_sq_Aug)) #labels for outliers, 0 - unseen, 1 - seen
    for i in range(len(y_sq_Aug)) :
      if i < 2*(CS*(tK+tN)):
        y_outlier[i] = 1
    outlier_loss = 10*scce(y_outlier,outlier_pred)

    #accuracy calculation
    correct_pred = 0
    dists = calc_euclidian_dists(qembed, z_prototypes)               # [90, 3]  
    outlier_prob = outlier_nn(dists)
    outlier_index = tf.argmax(outlier_prob,axis=-1)
    predictions = tf.nn.softmax(-dists, axis=-1)
    pred = predictions
    pred2 = predictions
    pred_index = tf.argmax(predictions,axis=-1)
    for i in range(len(ep_query_labels)) :
      if ep_query_labels[i] in support_classes :
        if outlier_index[i] == 1 :
          x = support_classes.index(ep_query_labels[i])
          if x == pred_index[i] :
            correct_pred += 1  
      else :
          if outlier_index[i] == 0 :
            outlier = outlier + 1  
    accuracy = correct_pred/(CS*tN)     # scalar      
    outlier_det_acc = outlier/((CQ-CS)*tN)


    #open oa
    y_pred = np.zeros((len(ep_query_labels))) 
    for i in range(len(ep_query_labels)) :
        if outlier_index[i] == 1 :
          y_pred[i] = pred_index[i]+1
        else :
          y_pred[i] = 0
    cm = confusion_matrix(y_true,y_pred)
    FP = cm.sum(axis=0) - np.diag(cm)  
    FN = cm.sum(axis=1) - np.diag(cm)
    TP = np.diag(cm)
    TN = cm.sum() - (FP + FN + TP)
    open_oa = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))
      
    loss = cec_loss + outlier_loss
    return loss, accuracy, outlier_det_acc, open_oa    # scalar, scalar

optimizer = tf.keras.optimizers.Adam(0.00001) 
# Metrics to gather
tune_loss = tf.metrics.Mean(name='tune_loss')
tune_acc = tf.metrics.Mean(name='tune_accuracy')
tune_open_acc = tf.metrics.Mean(name='tune_open_accuracy')
tune_outlier_acc = tf.metrics.Mean(name='tune_outlier_accuracy')

def tune_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,tK,CS,CQ,tN):
    # Forward & update gradients
    with tf.GradientTape() as tape:
        loss, accuracy, outlier_det_acc, open_oa = proto_tune(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,tK,CS,CQ,tN)
    gradients = tape.gradient(loss, FE_model.trainable_variables)
    optim3.apply_gradients(zip(gradients, FE_model.trainable_variables))
    # Log loss and accuracy for step
    tune_loss(loss)
    tune_acc(accuracy)
    tune_open_acc(open_oa)
    tune_outlier_acc(outlier_det_acc)
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tune_log_dir = 'logs/gradient_tape/' + current_time + '/tune'
tune_summary_writer = tf.summary.create_file_writer(tune_log_dir)
        
for epoch in range(1001): 
    tune_loss.reset_states()  
    tune_acc.reset_states()
    tune_open_acc.reset_states()
    tune_outlier_acc.reset_states()
    for epi in range(10): 
        tquery_patches, tsupport_patches, query_labels, support_labels, support_classes = tune_episode(tune_set_5,1,4,3,6,ft_labels)   
        tune_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,tK,CS,CQ,tN)
    
    with tune_summary_writer.as_default():
        tf.summary.scalar('loss', tune_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', tune_acc.result(), step=epoch)
        tf.summary.scalar('Open_accuracy', tune_open_acc.result(), step=epoch)
        tf.summary.scalar('Outlier_accuracy', tune_outlier_acc.result(), step=epoch)

    template = 'Epoch {}, Tune Loss: {:.2f}, Tune Accuracy: {:.2f}, Open Accuracy: {:.2f},Outlier Accuracy: {:.2f}'
    print(template.format(epoch+1,tune_loss.result(),tune_acc.result()*100,tune_open_acc.result()*100,tune_outlier_acc.result()*100))

    if epoch % 100 == 0 and epoch != 0 :
      checkpoint_tune.save(file_prefix = checkpoint_prefix_tune)

# Testing after finetuning

In [None]:
def proto_test(ep_class_images,ep_query_images,ep_class_labels,ep_query_labels,support_classes,K,CS,CQ,N):#5,3,6,15     
    outlier = 0
    sembed = FE_model(ep_class_images)                             # [15, 64]        
    qembed = FE_model(ep_query_images)                             # [90, 64]
    y_query = np.asarray(np.zeros((len(ep_query_images),CS)),dtype=np.float32)  # (90, 3) 
    y_true = np.zeros(len(ep_query_labels)) #for storing labels of classes, 0 for unseen; 1,2,3 for the three classes
    y_auc = np.zeros((len(ep_query_labels))) #for storing labels, 1 for seen, and 0 for unseen
    for i in range(len(ep_query_labels)) :
      if ep_query_labels[i] in support_classes :
            x = support_classes.index(ep_query_labels[i])
            y_query[i][x] = 1.                                      # [[0., 0., 1.], [0., 0., 0.], ... (90,3)
            y_true[i] = x+1
            y_auc[i] = 1
    y_support = np.asarray(np.zeros((len(ep_class_images),CS)),dtype=np.float32)  # (15, 3) 
    for i in range(len(ep_class_labels)) :
      if ep_class_labels[i] in support_classes :
            x = support_classes.index(ep_class_labels[i])
            y_support[i][x] = 1.                                      # [[0., 0., 1.], [0., 0., 0.], ... (90,3)
    z_prototypes = tf.reshape(sembed,[CS, K, sembed.shape[-1]])           # [3, 5, 64]
    z_prototypes = tf.math.reduce_mean(z_prototypes, axis=1)        # [3, 64]   

    sqembedK = np.zeros((CS*(N+K),emb_dim)) #known query then support samples
    y_sqK = np.asarray(np.zeros((CS*(N+K),CS)),dtype=np.float32) # QK + SK
    j = 0
    for i in range(len(y_query)):     # 90
      if np.sum(y_query[i,:])==1:     # k query 
        y_sqK[j,:] = y_query[i,:]
        sqembedK[j,:] = qembed[i,:]
        j = j + 1
    for i in range(len(sembed)) :
      sqembedK[j,:] = sembed[i,:]
      y_sqK[j,:] = y_support[i,:]
      j = j + 1
    
    #RE calculation
    qembedK = np.zeros((CS*N,emb_dim))
    qembedU = np.zeros(((CQ-CS)*N,emb_dim))
    y_queryK = np.zeros((CS*N,CS))
    j = 0
    k = 0
    for i in range(len(y_query)):     # 90
      if np.sum(y_query[i,:])==1:     # k query 
        y_queryK[j,:] = y_query[i,:]
        qembedK[j,:] = qembed[i,:]
        j = j + 1
      else :
        qembedU[k,:] = qembed[i,:]
        k = k + 1
    gen_qembedK, z_mean, z_log_sigma = VAE(qembedK)
    gen_qembedU, z_mean, z_log_sigma = VAE(qembedU)
    rec_loss_k = abs(tf.reduce_mean(gen_qembedK-qembedK))
    rec_loss_u = abs(tf.reduce_mean(gen_qembedU-qembedU))  

    #accuracy calculation
    correct_pred = 0
    dists = calc_euclidian_dists(qembed, z_prototypes)               # [90, 3]  
    outlier_prob = outlier_nn(dists)
    outlier_prob1 = outlier_prob
    outlier_index = tf.argmax(outlier_prob,axis=-1)
    predictions = tf.nn.softmax(-dists, axis=-1)
    pred = predictions
    pred2 = predictions
    pred_index = tf.argmax(predictions,axis=-1)
    for i in range(len(ep_query_labels)) :
      if ep_query_labels[i] in support_classes :
        if outlier_index[i] == 1 :
          x = support_classes.index(ep_query_labels[i])
          if x == pred_index[i] :
            correct_pred += 1  
      else :
          if outlier_index[i] == 0 :
            outlier = outlier + 1  
    accuracy = correct_pred/(CS*N)     # scalar      
    outlier_det_acc = outlier/((CQ-CS)*N)
    y_score = np.zeros((len(ep_query_labels)))
    for i in range(len(ep_query_labels)) :
      y_score[i] = outlier_prob1[i,1]
    auc = sklearn.metrics.roc_auc_score(y_auc, y_score)


    #open oa
    y_pred = np.zeros((len(ep_query_labels))) 
    for i in range(len(ep_query_labels)) :
        if outlier_index[i] == 1 :
          y_pred[i] = pred_index[i]+1
        else :
          y_pred[i] = 0
    cm = confusion_matrix(y_true,y_pred)
    FP = cm.sum(axis=0) - np.diag(cm)  
    FN = cm.sum(axis=1) - np.diag(cm)
    TP = np.diag(cm)
    TN = cm.sum() - (FP + FN + TP)
    open_oa = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))    
    return accuracy, open_oa, outlier_det_acc, auc    # scalar, scalar

In [None]:
def test_episode(patches_list,NS,NQ,CS,CQ) :  # NS 5,NQ 15,CS 3,CQ 6
    selected_classes = test_class_labels  # 6 Query Classes
    support_classes = test_support_labels  # 3 Support Classes from Q
    
    tquery_patches,tsupport_patches = [],[]
    query_labels,support_labels = [],[]
    
    for x in support_classes :      #3
        sran_indices = np.random.choice(patches_list[x-1].shape[0],NS,replace=False) # K=5 img nos from 20 per class for Support
        support_patches = patches_list[x-1][sran_indices,:,:,:,:]  # for x class those 5 nos
        tsupport_patches.extend(support_patches)               # 3 class * 5 patch per class = 15 patches
        for i in range(NS) :
            support_labels.append(x)                           # 3 class *5 nos = 15 nos
        
    for x in selected_classes :     #6
        qran_indices = np.random.choice(patches_list[x-1].shape[0],NQ,replace=False) # NQ=15 img nos from 20 per class for Query
        query_patches = patches_list[x-1][qran_indices,:,:,:,:]    # for x class those 15 nos
        tquery_patches.extend(query_patches)                   # 6 class * 15 patch per class = 90 patches
        for i in range(NQ) :
            query_labels.append(x)                             # 6 class * 15 nos = 90 nos
    
    temp1 = list(zip(tquery_patches, query_labels)) 
    random.shuffle(temp1)        # By Doing Shuffling, Support, Query Same class combination got mismatched - mitigated by support index
    tquery_patches, query_labels = zip(*temp1)
    
    tquery_patches = tf.convert_to_tensor(np.reshape(np.asarray(tquery_patches),(CQ*NQ,11,11,30,1)),dtype=tf.float32)
    tsupport_patches = tf.convert_to_tensor(np.reshape(np.asarray(tsupport_patches),(CS*NS,11,11,30,1)),dtype=tf.float32)
    return tquery_patches, tsupport_patches, query_labels, support_labels, support_classes    

In [None]:
total_acc = 0 
total_open_oa = 0
total_outlier_acc = 0 
rec_k = 0
rec_u = 0
f1score = 0 
total_auc = 0
emb_dim = 64
tepochs = 100
K = 5
N = 15
for i in range(tepochs) :
  tquery_patches, tsupport_patches, query_labels, support_labels, support_classes = test_episode(patches_class_ip,5,15,3,6)   
  accuracy, open_oa, outlier_det_acc, auc = proto_test(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,5,3,6,15)
  total_acc = total_acc + accuracy
  total_open_oa = total_open_oa + open_oa
  total_outlier_acc = total_outlier_acc + outlier_det_acc
  total_auc = total_auc + auc
print('accuracy',total_acc*100/tepochs)
print('Outlier detection accuracy', (total_outlier_acc*100/tepochs))
print('open oa',total_open_oa*100/tepochs)
print('auc',total_auc/tepochs)