# MORGAN: Meta-Learning-based Few-Shot Open-Set Recognition via Generative Adversarial Network

## Import python libraries

In [None]:
import os
import cv2 # 4.1.2
import numpy as np # 1.21.5
import matplotlib # 3.2.2
import datetime
import random

import sklearn # 1.0.2
from sklearn.decomposition import PCA
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix

import scipy # 1.4.1
import scipy.io as sio

import tensorflow as tf # 2.8.0
import tensorflow.keras.backend as K
from tensorflow.keras.models import load_model,Model
from tensorflow.keras import Sequential,layers,regularizers
from tensorflow.keras.layers import Conv2D,BatchNormalization, MaxPool2D, Activation, Flatten, Dense, GlobalAveragePooling2D, GlobalMaxPool2D, AveragePooling2D, Lambda, Reshape, UpSampling2D, Conv2DTranspose 
from tensorflow.keras.layers import Conv3D,BatchNormalization, MaxPool3D, Activation, Flatten, Dense, GlobalAveragePooling3D, GlobalMaxPool3D, AveragePooling3D, Lambda
import keras # 2.8.0

## Data load and pre-process

In [None]:
def applyPCA(X, numComponents=75):
    """Reduce the last (spectral) axis of a 3-D cube with whitened PCA.

    Args:
        X: array indexed as (rows, cols, bands).
        numComponents: number of principal components to keep.

    Returns:
        A tuple ``(reduced, pca)`` where ``reduced`` has shape
        (rows, cols, numComponents) and ``pca`` is the fitted PCA object.
    """
    rows, cols, bands = X.shape[0], X.shape[1], X.shape[2]
    # Every pixel becomes one sample whose features are its bands.
    pixels = np.reshape(X, (-1, bands))
    pca = PCA(n_components=numComponents, whiten=True)
    reduced = pca.fit_transform(pixels)
    return np.reshape(reduced, (rows, cols, numComponents)), pca

def padWithZeros(X, margin=2):
    """Return a copy of a 3-D cube with a zero border on its first two axes.

    Args:
        X: array indexed as (rows, cols, bands).
        margin: number of zero rows/columns added on every side.

    Returns:
        A new float64 array of shape (rows + 2*margin, cols + 2*margin, bands)
        with ``X`` copied into its centre.
    """
    rows, cols, bands = X.shape[0], X.shape[1], X.shape[2]
    padded = np.zeros((rows + 2 * margin, cols + 2 * margin, bands))
    padded[margin:margin + rows, margin:margin + cols, :] = X
    return padded

def createImageCubes(X, y, windowSize, removeZeroLabels = True):
    """Cut one windowSize x windowSize patch around every pixel of a cube.

    Args:
        X: array indexed as (rows, cols, bands).
        y: 2-D label map indexed as (rows, cols).
        windowSize: spatial side length of each patch. The margin is
            ``int((windowSize - 1) / 2)``, so an even value produces patches
            of side windowSize - 1 and the assignment below fails.
        removeZeroLabels: accepted but never read by this function; patches
            for label 0 are always kept.

    Returns:
        ``(patchesData, patchesLabels)``: patches of shape
        (rows*cols, windowSize, windowSize, bands, 1) in row-major pixel
        order, and the matching flat label vector of length rows*cols.
    """
    margin = int((windowSize - 1) / 2)
    padded = padWithZeros(X, margin=margin)
    height, width, bands = X.shape[0], X.shape[1], X.shape[2]

    patchesData = np.zeros((height * width, windowSize, windowSize, bands))
    patchesLabels = np.zeros((height * width))

    span = 2 * margin + 1
    for row in range(height):
        for col in range(width):
            flat = row * width + col
            # In padded coordinates the pixel sits at (row+margin, col+margin),
            # so its window starts at (row, col).
            patchesData[flat, :, :, :] = padded[row:row + span, col:col + span]
            patchesLabels[flat] = y[row, col]

    # Trailing singleton axis is the channel dimension expected by Conv3D.
    return np.expand_dims(patchesData, axis=-1), patchesLabels

def patches_class(X,Y,n):
    """Group patches by label for labels 1..n (label 0 is skipped).

    Args:
        X: 5-D patch array whose first axis lines up with ``Y``.
        Y: 1-D label vector.
        n: highest label to collect.

    Returns:
        ``(patches_list, labeles_list)``: two lists of length ``n``; entry
        ``i`` holds the patches / labels whose label equals ``i + 1``.
    """
    patches_list = []
    labeles_list = []
    for label in range(1, n + 1):
        mask = (Y == label)
        patches_list.append(X[mask, :, :, :, :])
        labeles_list.append(Y[mask])
    return patches_list, labeles_list

##### Load Indian Pines (or any other Hyperspectral) Dataset

In [None]:
# Spatial patch size and network input dimensions: windowSize x windowSize
# pixels, 30 PCA components along the depth axis, one channel.
windowSize = 11
im_height, im_width, im_depth, im_channel = windowSize, windowSize, 30, 1

# Hyperspectral cube and its ground-truth label map (Colab / Google Drive paths).
X = sio.loadmat('/content/drive/MyDrive/data/Indian_pines_corrected.mat')['indian_pines_corrected']
y = sio.loadmat('/content/drive/MyDrive/data/Indian_pines_gt.mat')['indian_pines_gt']
print(X.shape, y.shape)

# Reduce the spectral axis to im_depth whitened principal components.
X,pca = applyPCA(X,numComponents=im_depth)
print(X.shape, y.shape)

# One patch per pixel; from here on X holds patches and y the flat label vector.
X, y = createImageCubes(X, y, windowSize)
print(X.shape, y.shape)

# Per-class lists for labels 1..16: list index i holds the patches / labels of class label i+1.
patches_class_ip,label_ip = patches_class(X,y,16)

## Define Meta-training and Meta-testing classes

In [None]:
# Meta-training classes. "indices" are 0-based positions in patches_class_ip,
# "labels" are the corresponding 1-based class labels (label = index + 1).
train_class_indices = [1,2,4,5,7,9,10,11,13,14]    # 10 classes  
train_class_labels = [2,3,5,6,8,10,11,12,14,15]

# Meta-testing classes, disjoint from the training classes above.
test_class_indices = [0,3,6,8,12,15]               # 6 classes
test_class_labels = [1,4,7,9,13,16]

## A. Define Model

#### A.1.1 CBAM3D layer for feature extractor  
Source: "Few-Shot Open-Set Recognition of Hyperspectral Images with Outlier Calibration Network"

In [None]:
class Channel_Attention_3D(tf.keras.layers.Layer) :
    """Channel-attention branch of a 3-D CBAM block.

    Globally average- and max-pools the input, passes both pooled vectors
    through the same two-layer MLP (fc1 -> fc2), adds the results and applies
    a sigmoid.

    Args:
        C: number of channels of the input (width of the output vector).
        ratio: reduction ratio of the MLP bottleneck.
    """
    def __init__(self,C,ratio) :
        super(Channel_Attention_3D,self).__init__()
        self.avg_pool = GlobalAveragePooling3D()
        self.max_pool = GlobalMaxPool3D()
        self.activation = Activation('sigmoid')
        # Dense expects an integer width: use floor division instead of the
        # float C/ratio, and keep at least one unit when C < ratio.
        self.fc1 = Dense(max(1, int(C // ratio)), activation = 'relu')
        self.fc2 = Dense(C)
    def call(self,x) :
        """Return the per-channel attention weights, one row of C values per sample."""
        # The same fc1/fc2 weights are shared by both pooled descriptors.
        avg_out = self.fc2(self.fc1(self.avg_pool(x)))
        max_out = self.fc2(self.fc1(self.max_pool(x)))
        return self.activation(tf.math.add(max_out, avg_out))

class Spatial_Attention_3D(tf.keras.layers.Layer) :
    """Spatial-attention branch of a 3-D CBAM block.

    Collapses axis 4 of the input with a mean and a max, stacks the two
    single-channel maps and convolves them with one 7x7x7 sigmoid filter.
    """
    def __init__(self) :
        super().__init__()
        self.conv3d = Conv3D(1,(7,7,7),padding='same',activation='sigmoid')
        # Channel-wise pooling over axis 4, keeping that axis with size 1.
        self.avg_pool_chl = Lambda(lambda x:tf.keras.backend.mean(x,axis=4,keepdims=True))
        self.max_pool_chl = Lambda(lambda x:tf.keras.backend.max(x,axis=4,keepdims=True))

    def call(self,x) :
        """Return a single-channel attention map over the first three non-batch axes."""
        pooled_maps = [self.avg_pool_chl(x), self.max_pool_chl(x)]
        return self.conv3d(tf.concat(pooled_maps, axis=-1))

class CBAM_3D(tf.keras.layers.Layer) :
    """3-D CBAM block: channel attention followed by spatial attention.

    Args:
        C: number of channels of the tensors this block will process.
        ratio: reduction ratio forwarded to the channel-attention MLP.
    """
    def __init__(self,C,ratio) :
        super().__init__()
        self.C = C
        self.ratio = ratio
        self.channel_attention = Channel_Attention_3D(self.C,self.ratio)
        self.spatial_attention = Spatial_Attention_3D()
    def call(self,y,H,W,D,C) :
        """Re-weight ``y`` by channel attention, then by spatial attention.

        Args:
            y: 5-D input tensor.
            H, W, D, C: sizes of axes 1-4 of ``y``, used as tile multiples.

        Returns:
            Tensor with the same shape as ``y``.
        """
        # (batch, C) -> (batch, 1, 1, 1, C) so it can be tiled over H, W, D.
        channel_weights = self.channel_attention(y)
        for axis in (1, 2, 3):
            channel_weights = tf.expand_dims(channel_weights, axis=axis)
        channel_map = tf.tile(channel_weights, multiples=[1,H,W,D,1])
        channel_refined = tf.math.multiply(channel_map,y)

        # Single-channel spatial map, repeated across the C channels.
        spatial_weights = self.spatial_attention(channel_refined)
        spatial_map = tf.tile(spatial_weights, multiples = [1,1,1,1,C])
        return tf.math.multiply(spatial_map,channel_refined)

#### A.1 Feature Extractor

In [None]:
# Feature extractor: two residual stages of 3-D convolutions with CBAM
# attention, each followed by max-pooling, then a final 'valid' convolution
# and flatten. With the 11x11x30x1 input used here the flattened output has
# 1*1*2*32 = 64 features.
input_layer = layers.Input(shape = (im_height, im_width, im_depth, im_channel))
out1 = layers.Conv3D(filters=8, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(input_layer)
# NOTE(review): the CBAM output assigned on the next line is overwritten by the
# following line (which convolves out1 again), so this attention layer never
# reaches the model output. The second stage below applies CBAM to out6 in
# place instead; confirm which wiring is intended before changing it, since
# it alters the architecture.
out2 = CBAM_3D(out1.shape[4],4)(out1,out1.shape[1],out1.shape[2],out1.shape[3],out1.shape[4])
out2 = layers.Conv3D(filters=8, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out1)
out2 = CBAM_3D(out2.shape[4],4)(out2,out2.shape[1],out2.shape[2],out2.shape[3],out2.shape[4])
out3 = layers.Conv3D(filters=8, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out2)
# Residual connection around the first stage (element-wise add, not concatenation).
out4 = layers.Add()([out1, out3])  #Concatenate()
out5 = layers.MaxPool3D(pool_size=(2, 2, 4), strides=None, padding='same')(out4)

# Second residual stage with 16 filters.
out6 = layers.Conv3D(filters=16, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out5)
out6 = CBAM_3D(out6.shape[4],4)(out6,out6.shape[1],out6.shape[2],out6.shape[3],out6.shape[4])
out7 = layers.Conv3D(filters=16, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out6)
out7 = CBAM_3D(out7.shape[4],4)(out7,out7.shape[1],out7.shape[2],out7.shape[3],out7.shape[4])
out8 = layers.Conv3D(filters=16, kernel_size=(3,3,3), padding='same',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out7)
out9 = layers.Add()([out6, out8])  #Concatenate()
out10 = layers.MaxPool3D(pool_size=(2, 2, 2), strides=None, padding='same')(out9)
out10 = CBAM_3D(out10.shape[4],4)(out10,out10.shape[1],out10.shape[2],out10.shape[3],out10.shape[4])
# Unpadded convolution shrinks every non-batch axis by 2 before flattening.
out11 = layers.Conv3D(filters=32, kernel_size=(3,3,3), padding='valid',activation='relu',input_shape=(im_height, im_width, im_depth, im_channel))(out10)
out12 = layers.Flatten()(out11)
FE_model = Model(inputs=input_layer,outputs=out12,name='R3CBAM')
FE_model.summary()

#### A.2 Outlier detector

In [None]:
# Outlier detector: a small MLP mapping a 3-value input to a 2-way softmax.
# The training code feeds it the distances of an embedding to the episode
# prototypes and treats class 1 as "seen" and class 0 as "unseen".
input_Feature = layers.Input(shape = (3,))
encoded_L1 = layers.Dense(16, activation='relu')(input_Feature)
encoded_L2 = layers.Dense(8, activation='relu')(encoded_L1)
decoded = layers.Dense(2, activation='softmax')(encoded_L2)
outlier_nn = Model(inputs=input_Feature,outputs=decoded,name='ONN')
outlier_nn.summary()

#### A.3 GAN1 : to generate pseudo-known samples: Input low noise variance

##### A.3.1 Generator (Low)

In [None]:
# Conditional generator: its input is a 16-dim one-hot class code plus an
# 8-dim noise vector (concatenated by gan_reptile); its output is a 64-dim
# embedding.
generator_input_size=16+8
input_feature=layers.Input(shape=(generator_input_size,))
layer_low_g1=layers.Dense(32,activation='relu')(input_feature) 
layer_low_g2=layers.Dense(48,activation='relu')(layer_low_g1)
# Final ReLU keeps generated embeddings non-negative.
layer_low_g3=layers.Dense(64,activation='relu')(layer_low_g2)
generator_nn_low=Model(inputs=input_feature,outputs=layer_low_g3, name='generator_low')
generator_nn_low.summary()

##### A.3.2 Discriminator (Low)

In [None]:
# Conditional discriminator: its 80-dim input is a 64-dim embedding
# concatenated with the 16-dim one-hot class code (see gan_reptile).
input_feature=layers.Input(shape=(80,))
layer_l_d1=layers.Dense(48,activation='relu')(input_feature) 
layer_l_d2=layers.Dense(32,activation='relu')(layer_l_d1)
layer_l_d3=layers.Dense(16,activation='relu')(layer_l_d2)
layer_l_d4=layers.Dense(8,activation='relu')(layer_l_d3)
# Single raw logit; the loss is BinaryCrossentropy(from_logits=True).
layer_l_d5=layers.Dense(1)(layer_l_d4)
dis_nn_low=Model(inputs=input_feature,outputs=layer_l_d5, name='discriminator_low')
dis_nn_low.summary()

#### A.3.3 Function to optimize GAN 1

In [None]:
def gan_reptile(sembed,ep_class_labels,generator,dis,optim_d,optim_g,stdev,alpha1):
    """Run one damped discriminator step and one damped generator step.

    Each network takes one optimizer step on its adversarial loss; its weights
    are then interpolated back towards their pre-step values, so the net
    update is ``alpha1`` times the optimizer step (the Reptile-style update
    named in the original comments).

    Args:
        sembed: real support embeddings, one row per support sample.
        ep_class_labels: class label of every row of ``sembed``. Labels are
            passed unchanged to ``tf.one_hot(depth=16)``, so a label of 16
            encodes as an all-zero row.
        generator: conditional generator (input: 8 noise + 16 one-hot values).
        dis: conditional discriminator (input: embedding + 16 one-hot values).
        optim_d: optimizer used for the discriminator step.
        optim_g: optimizer used for the generator step.
        stdev: standard deviation of the generator's input noise.
        alpha1: fraction of each optimizer step that is kept.

    Returns:
        ``(sembed_gen, vector)``: embeddings generated after both updates and
        the noise vectors they were generated from.
    """
    batch_size=sembed.shape[0]
    one_hot_labels=tf.one_hot(ep_class_labels, depth=16, axis=-1)
    one_hot_labels=tf.reshape(one_hot_labels,(batch_size,16))

    vector=tf.random.normal(shape=(batch_size,8), stddev=stdev)
    latent=tf.concat([vector,one_hot_labels], axis=1)

    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True)

    # Fake and real embeddings, each conditioned on its class code.
    generated_emb=generator(latent)
    fake_embs=tf.concat([generated_emb, one_hot_labels], axis=1)
    real_embs=tf.concat([sembed, one_hot_labels], axis=1)
    combined_embs=tf.concat([fake_embs, real_embs], axis=0)
    # Discriminator targets: 1 for generated rows, 0 for real rows.
    labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)

    # --- Discriminator: optimizer step, then keep only alpha1 of the move ---
    old_vars_dis=dis.get_weights()
    with tf.GradientTape() as dis_tape:
        predictions=dis(combined_embs)
        d_loss=loss_fn(labels,predictions)
    grads=dis_tape.gradient(d_loss,dis.trainable_variables)
    optim_d.apply_gradients(zip(grads, dis.trainable_variables))
    dis.set_weights([old + (new - old) * alpha1
                     for old, new in zip(old_vars_dis, dis.get_weights())])

    # --- Generator: wants its samples scored as real, i.e. target 0 ---
    mis_labels=tf.zeros((batch_size,1))
    old_gen=generator.get_weights()
    with tf.GradientTape() as gen_tape:
        # Train on the same noise scale that is used for sampling; the
        # previous hard-coded 0.2 ignored the `stdev` argument.
        vector=tf.random.normal(shape=(batch_size,8), stddev=stdev)
        latent=tf.concat([vector,one_hot_labels], axis=1)
        fake_emb=generator(latent)
        fake_emb_and_labels=tf.concat([fake_emb, one_hot_labels], axis=-1)
        predictions=dis(fake_emb_and_labels)
        g_loss=loss_fn(mis_labels,predictions)
    g_grads=gen_tape.gradient(g_loss, generator.trainable_variables)
    optim_g.apply_gradients(zip(g_grads, generator.trainable_variables))
    generator.set_weights([old + (new - old) * alpha1
                           for old, new in zip(old_gen, generator.get_weights())])

    # Sample fresh embeddings from the updated generator.
    vector=tf.random.normal(shape=(batch_size,8), stddev=stdev)
    latent=tf.concat([vector,one_hot_labels], axis=1)
    sembed_gen=generator(latent)
    return sembed_gen,vector

#### A.4 GAN2 : to generate pseudo-unknown samples: Input high noise variance

##### A.4.1 Generator (High)

In [None]:
# Second conditional generator with the same layout as the first one:
# input is a 16-dim one-hot class code plus an 8-dim noise vector, output is
# a 64-dim embedding. It is trained by gan_reptile_high.
generator_input_size=16+8
input_feature=layers.Input(shape=(generator_input_size,))
layer_h_g1=layers.Dense(32,activation='relu')(input_feature) 
layer_h_g2=layers.Dense(48,activation='relu')(layer_h_g1)
layer_h_g3=layers.Dense(64,activation='relu')(layer_h_g2)
generator_nn_high=Model(inputs=input_feature,outputs=layer_h_g3, name='generator_high')
generator_nn_high.summary()

##### A.4.2 Discriminator (High)

In [None]:
# Second conditional discriminator: 80-dim input (64-dim embedding plus
# 16-dim one-hot class code), single raw logit as output.
input_feature=layers.Input(shape=(80,))
layer_h_d1=layers.Dense(48,activation='relu')(input_feature) 
layer_h_d2=layers.Dense(32,activation='relu')(layer_h_d1)
layer_h_d3=layers.Dense(16,activation='relu')(layer_h_d2)
layer_h_d4=layers.Dense(8,activation='relu')(layer_h_d3)
layer_h_d5=layers.Dense(1)(layer_h_d4)
dis_nn_high=Model(inputs=input_feature,outputs=layer_h_d5, name='discriminator_high')
dis_nn_high.summary()

##### A.4.3 AOL Regularizer

In [None]:
def cosine_loss(s1,s2,z1,z2):
    """Cosine-similarity regularizer between two sets of samples and their noise.

    Computes, per row, ``(1 + cos(z1, z2)) * max(1e-7, cos(s1, s2))`` and
    returns the mean over rows.

    Args:
        s1, s2: two batches of embeddings with matching shapes.
        z1, z2: the noise vectors the two batches were generated from.

    Returns:
        Scalar tensor with the mean regularizer value.
    """
    # `axis` replaces the deprecated `dim` keyword of tf.nn.l2_normalize.
    s1=tf.nn.l2_normalize(s1,axis=1)
    s2=tf.nn.l2_normalize(s2,axis=1)
    z1=tf.nn.l2_normalize(z1,axis=1)
    z2=tf.nn.l2_normalize(z2,axis=1)
    # Row-wise dot products of unit vectors are the cosine similarities.
    cos_s=tf.reduce_sum(s1*s2,axis=1)
    cos_z=tf.reduce_sum(z1*z2,axis=1)
    # The embedding similarity is floored at 1e-7 so the product never goes negative.
    loss = (1+cos_z)*(tf.math.maximum(0.0000001, cos_s))
    return tf.reduce_mean(loss)

#### A.4.4 Function to optimize GAN 2

In [None]:
def gan_reptile_high(sembed,sembed_low,z1,ep_class_labels,generator,dis,optim_d,optim_g,stdev,alpha1):
    """Damped GAN step for the second generator/discriminator pair.

    Same scheme as ``gan_reptile`` (one optimizer step per network, then keep
    only ``alpha1`` of the move), with ``cosine_loss`` against the first
    GAN's samples added to both losses.

    Args:
        sembed: real support embeddings, one row per support sample.
        sembed_low: embeddings produced by the first GAN for the same rows.
        z1: noise vectors that produced ``sembed_low``.
        ep_class_labels: class label of every row of ``sembed``.
        generator, dis: the networks to update.
        optim_d, optim_g: their optimizers.
        stdev: noise standard deviation for the discriminator step and for
            the returned samples.
        alpha1: fraction of each optimizer step that is kept.

    Returns:
        Embeddings generated by the updated generator.
    """
    batch_size = sembed.shape[0]
    one_hot_labels = tf.reshape(tf.one_hot(ep_class_labels, depth=16, axis=-1),
                                (batch_size, 16))
    bce_logits = keras.losses.BinaryCrossentropy(from_logits=True)

    # ----- discriminator step -----
    noise = tf.random.normal(shape=(batch_size,8), stddev=stdev)
    generated_emb = generator(tf.concat([noise, one_hot_labels], axis=1))
    # Computed outside the tape and independent of the discriminator weights,
    # so it shifts the loss value but not the discriminator gradients.
    c_lossd = cosine_loss(generated_emb, sembed_low, noise, z1)

    fake_rows = tf.concat([generated_emb, one_hot_labels], axis=1)
    real_rows = tf.concat([sembed, one_hot_labels], axis=1)
    combined_embs = tf.concat([fake_rows, real_rows], axis=0)
    # Targets: 1 for generated rows, 0 for real rows.
    labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)

    dis_before = dis.get_weights()
    with tf.GradientTape() as dis_tape:
        d_loss = bce_logits(labels, dis(combined_embs)) + c_lossd
    dis_grads = dis_tape.gradient(d_loss, dis.trainable_variables)
    optim_d.apply_gradients(zip(dis_grads, dis.trainable_variables))
    dis.set_weights([before + ((after - before) * alpha1)
                     for before, after in zip(dis_before, dis.get_weights())])

    # ----- generator step -----
    mis_labels = tf.zeros((batch_size,1))
    gen_before = generator.get_weights()
    with tf.GradientTape() as gen_tape:
        # NOTE(review): the noise scale here is the literal 0.2, not `stdev`
        # (callers pass 1.0). Kept as is because changing it would alter
        # training; confirm whether this is intended.
        noise = tf.random.normal(shape=(batch_size,8), stddev=0.2)
        fake_emb = generator(tf.concat([noise, one_hot_labels], axis=1))
        c_lossg = cosine_loss(fake_emb, sembed_low, noise, z1)
        predictions = dis(tf.concat([fake_emb, one_hot_labels], axis=-1))
        g_loss = bce_logits(mis_labels, predictions) + c_lossg
    gen_grads = gen_tape.gradient(g_loss, generator.trainable_variables)
    optim_g.apply_gradients(zip(gen_grads, generator.trainable_variables))
    generator.set_weights([before + ((after - before) * alpha1)
                           for before, after in zip(gen_before, generator.get_weights())])

    # Sample from the updated generator at the requested noise scale.
    noise = tf.random.normal(shape=(batch_size,8), stddev=stdev)
    return generator(tf.concat([noise, one_hot_labels], axis=1))

## B. Define Optimizer and Distance functions

In [None]:
# Losses and optimizers shared by meta-training and meta-tuning.
# bce and ce_loss are not referenced in the visible part of this file.
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optim1 = tf.keras.optimizers.Adam(0.0001)  # updates outlier_nn
optim2 = tf.keras.optimizers.Adam(0.0001)  # updates FE_model
# One optimizer per GAN network (low / high pair, discriminator / generator).
optim_d_low=tf.keras.optimizers.Adam(learning_rate=0.0001)
optim_g_low=tf.keras.optimizers.Adam(learning_rate=0.0001)
optim_d_high=tf.keras.optimizers.Adam(learning_rate=0.0001)
optim_g_high=tf.keras.optimizers.Adam(learning_rate=0.0001)
# Sparse cross-entropy on probabilities; used for the outlier network's 2-way softmax.
scce = tf.keras.losses.SparseCategoricalCrossentropy()
ce_loss = tf.keras.losses.CategoricalCrossentropy()

In [None]:
# Checkpoint object for meta-training: tracks every optimizer and every
# network so that a single save/restore covers the whole training state.
checkpoint_dir = '/content/Rp'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optim1=optim1, optim2=optim2, optim_d_low=optim_d_low, optim_g_low=optim_g_low, optim_d_high=optim_d_high,
                                 optim_g_high=optim_g_high, FE_model = FE_model,
                                 generator_nn_low=generator_nn_low, generator_nn_high=generator_nn_high,
                                 dis_nn_low=dis_nn_low, dis_nn_high=dis_nn_high,
                                 outlier_nn=outlier_nn)

#### Distance functions

In [None]:
def calc_euclidian_dists(x, y):
    """Pairwise mean squared difference between the rows of ``x`` and ``y``.

    Args:
        x: tensor of shape (n, d).
        y: tensor of shape (m, d).

    Returns:
        Tensor of shape (n, m); entry (i, j) is the mean over the d features
        of ``(x[i] - y[j]) ** 2`` (a squared distance divided by d, no square
        root is taken).
    """
    # (n, 1, d) - (1, m, d) broadcasts to every (row of x, row of y) pair.
    pairwise_diff = tf.expand_dims(x, 1) - tf.expand_dims(y, 0)
    return tf.reduce_mean(tf.math.pow(pairwise_diff, 2), 2)

def rp_distance(embed,rp,num_rec_points=1,num_class=3):
    """Feature-averaged squared distance from embeddings to reciprocal points.

    Reciprocal points: https://arxiv.org/pdf/2103.00953.pdf

    Args:
        embed: embeddings of shape (n, d).
        rp: reciprocal points of shape (num_class * num_rec_points, d).
        num_rec_points: number of reciprocal points per class.
        num_class: number of classes.

    Returns:
        float64 tensor of shape (n, num_class): squared Euclidean distance
        divided by d, averaged over the reciprocal points of each class.
    """
    # Work in float64 throughout. Previously only `embed` was cast inside the
    # matmul, so float32 inputs failed with a dtype mismatch.
    embed=tf.cast(embed,tf.float64)
    rp=tf.cast(rp,tf.float64)
    # ||f - c||^2 = ||f||^2 - 2 f.c + ||c||^2, evaluated for all pairs at once.
    f_2=tf.reduce_sum(tf.math.pow(embed,2), axis=1, keepdims=True)
    c_2=tf.reduce_sum(tf.math.pow(rp,2), axis=1, keepdims=True)
    dist=f_2 - 2*tf.linalg.matmul(embed,tf.transpose(rp)) + tf.transpose(c_2)
    dist=dist/float(embed.shape[1])
    dist=tf.reshape(dist, [-1, num_class, num_rec_points])
    dist=tf.reduce_mean(dist, axis=2)
    return dist

In [None]:
# Width of the embedding buffers allocated in meta_train; must equal the
# length of the flattened FE_model output.
emb_dim = 64

## C. Meta-training

#### Episode for Meta-training

In [None]:
def new_episode(patches_list,NS,NQ,CS,CQ,class_labels) :  # NS 5,NQ 15,CS 3,CQ 6
    """Sample one open-set few-shot episode.

    ``CQ`` query classes are drawn from ``class_labels``; ``CS`` of them also
    act as support (known) classes, the rest are unknown for this episode.

    Args:
        patches_list: per-class patch arrays; label ``x`` lives at index ``x - 1``.
        NS: support patches drawn per support class.
        NQ: query patches drawn per query class.
        CS: number of support classes.
        CQ: number of query classes (CS <= CQ).
        class_labels: 1-based labels to sample classes from.

    Returns:
        ``(tquery_patches, tsupport_patches, query_labels, support_labels,
        support_classes)``: float32 tensors with CQ*NQ and CS*NS patches, the
        shuffled query labels (tuple), the support labels in class order
        (list) and the list of support classes.
    """
    selected_classes = list(np.random.choice(class_labels,CQ,replace=False))
    support_classes = list(np.random.choice(selected_classes,CS,replace=False))

    tquery_patches,tsupport_patches = [],[]
    query_labels,support_labels = [],[]

    # Support set: NS distinct patches per support class, kept grouped by class.
    for x in support_classes :
        sran_indices = np.random.choice(patches_list[x-1].shape[0],NS,replace=False)
        tsupport_patches.extend(patches_list[x-1][sran_indices])
        support_labels.extend([x] * NS)

    # Query set: NQ distinct patches per selected class. Drawn independently
    # of the support draw, so a patch can appear in both sets.
    for x in selected_classes :
        qran_indices = np.random.choice(patches_list[x-1].shape[0],NQ,replace=False)
        tquery_patches.extend(patches_list[x-1][qran_indices])
        query_labels.extend([x] * NQ)

    # Shuffle queries and their labels together; the support order is untouched.
    temp1 = list(zip(tquery_patches, query_labels))
    random.shuffle(temp1)
    tquery_patches, query_labels = zip(*temp1)

    # Take the patch shape from the data instead of hard-coding (11,11,30,1),
    # so other window sizes / band counts work; the reshape still checks the
    # expected number of patches.
    query_array = np.asarray(tquery_patches)
    support_array = np.asarray(tsupport_patches)
    tquery_patches = tf.convert_to_tensor(np.reshape(query_array,(CQ*NQ,) + query_array.shape[1:]),dtype=tf.float32)
    tsupport_patches = tf.convert_to_tensor(np.reshape(support_array,(CS*NS,) + support_array.shape[1:]),dtype=tf.float32)
    return tquery_patches, tsupport_patches, query_labels, support_labels, support_classes

In [None]:
def meta_train(ep_class_images,ep_query_images,ep_class_labels,ep_query_labels,support_classes,K,CS,CQ,N): #5,3,6,15
    """Run one open-set meta-training episode.

    Embeds the support and query patches, updates both GANs and the outlier
    network in place, and returns the loss used to update the feature
    extractor together with episode metrics.

    Args:
        ep_class_images: support patches, CS*K rows grouped by class.
        ep_query_images: query patches, CQ*N rows.
        ep_class_labels: label of every support patch.
        ep_query_labels: label of every query patch.
        support_classes: list of the CS known classes of this episode.
        K: support patches per support class.
        CS: number of support (known) classes.
        CQ: number of query classes.
        N: query patches per query class.

    Returns:
        ``(loss, accuracy, outlier_det_acc, open_oa)``. ``loss`` has one
        element per support/query row because the reciprocal-point term is
        not reduced over rows. ``accuracy`` is the fraction of known queries
        that are accepted and correctly classified, ``outlier_det_acc`` the
        fraction of unknown queries flagged as outliers, and ``open_oa`` is
        derived from the (known classes + unknown) confusion matrix.
    """
    sembed=FE_model(ep_class_images)
    qembed=FE_model(ep_query_images)

    # One-hot labels relative to support_classes; unknown queries keep an
    # all-zero row. y_true uses 0 for unknown and 1..CS for known classes.
    y_query = np.asarray(np.zeros((len(ep_query_images),CS)),dtype=np.float32)
    y_true = np.zeros(len(ep_query_labels))
    for i in range(len(ep_query_labels)) :
        if ep_query_labels[i] in support_classes :
            x = support_classes.index(ep_query_labels[i])
            y_query[i][x] = 1.
            y_true[i] = x+1
    y_support = np.asarray(np.zeros((len(ep_class_images),CS)),dtype=np.float32)
    for i in range(len(ep_class_labels)) :
        if ep_class_labels[i] in support_classes :
            x = support_classes.index(ep_class_labels[i])
            y_support[i][x] = 1.

    # gan --> reptile: a few damped updates of both GANs; only the samples of
    # the last iteration are used below.
    ngan = 5
    for _ in range(ngan):
        sembed_low,vector=gan_reptile(sembed,ep_class_labels,generator_nn_low,dis_nn_low,optim_d_low,optim_g_low,stdev=0.2,alpha1=0.003)
        sembed_high=gan_reptile_high(sembed,sembed_low,vector,ep_class_labels,generator_nn_high,dis_nn_high,optim_d_high,optim_g_high,stdev=1.0,alpha1=0.003)

    # Class prototypes: mean over the K shots of (real + low-GAN embedding).
    z_proto_low = tf.reshape(sembed_low,[CS, K, sembed_low.shape[-1]])      # [CS, K, emb]
    z_prototypes = tf.reshape(sembed,[CS, K, sembed.shape[-1]])             # [CS, K, emb]
    z_prototypes = tf.math.reduce_mean((z_prototypes + z_proto_low), axis=1) # [CS, emb]

    # Reciprocal points: one per known class (mean support embedding) plus
    # one global point (mean of all query embeddings).
    rp_proto_low = tf.reshape(sembed,[CS, K, sembed.shape[-1]])
    rp_proto_low = tf.reduce_mean(rp_proto_low, axis=1)

    rp_proto_high = tf.reshape(qembed, [CQ, N, qembed.shape[-1]])
    rp_proto_high = tf.reduce_mean(rp_proto_high, axis=1)
    rp_proto_high = tf.reduce_mean(rp_proto_high, axis=0, keepdims=True)

    rpoints = tf.concat([rp_proto_low,rp_proto_high], 0)                    # [CS+1, emb]

    # Known queries followed by the support samples.
    # NOTE(review): these rows are copied into a NumPy buffer, so they enter
    # the loss as constants and carry no gradient back to FE_model - confirm
    # this is intended.
    sqembedK = np.zeros((CS*(N+K),emb_dim))
    y_sqK = np.asarray(np.zeros((CS*(N+K),CS)),dtype=np.float32)
    j = 0
    for i in range(len(y_query)):
        if np.sum(y_query[i,:])==1:     # known query
            y_sqK[j,:] = y_query[i,:]
            sqembedK[j,:] = qembed[i,:]
            j = j + 1

    for i in range(len(sembed)) :
        sqembedK[j,:] = sembed[i,:]
        y_sqK[j,:] = y_support[i,:]
        j = j + 1

    # Reciprocal-point loss over all support and query embeddings.
    # NOTE(review): the class one-hot occupies columns 1..CS while rpoints
    # rows 0..CS-1 are the class means, so the columns are shifted by one
    # relative to rpoints - confirm this pairing is intended.
    sqembedKU = tf.concat([sembed, qembed], 0)
    y_sqku = np.concatenate((y_support, y_query), 0)
    y_sqKU = np.zeros((y_sqku.shape[0],CS+1))
    y_sqKU[:,1:CS+1]=y_sqku
    dist_rp = rp_distance(tf.cast(sqembedKU,dtype=tf.float64) ,tf.cast(rpoints,tf.float64),1,CS+1)  # [rows, CS+1]
    dist_rp=tf.nn.log_softmax(dist_rp)
    loss_rp=tf.multiply(tf.convert_to_tensor(y_sqKU, dtype=tf.float32),tf.cast(dist_rp,tf.float32))
    loss_rp = -tf.reduce_mean(loss_rp, axis=-1)                             # one value per row

    # Augmented set: known queries + supports, then high-GAN samples (given
    # the support labels), then the unknown queries (all-zero labels).
    y_sq_Aug = y_sqK
    sqembed_Aug = sqembedK
    sqembed_Aug = tf.concat((sqembed_Aug,sembed_high),axis=0)
    y_sq_Aug = tf.concat((y_sq_Aug,y_support),axis=0)
    for i in range(len(y_query)):
        if np.sum(y_query[i,:])==0:     # unknown query
            y_sq_Aug = tf.concat((y_sq_Aug,tf.expand_dims(y_query[i,:],axis=0)),axis=0)
            sqembed_Aug = tf.concat((sqembed_Aug,tf.expand_dims(qembed[i,:],axis=0)),axis=0)
    dists_Aug = calc_euclidian_dists(sqembed_Aug, z_prototypes)
    log_p_y_Aug = tf.nn.log_softmax(-dists_Aug,axis=-1)
    cec_loss = -tf.reduce_mean((tf.reduce_sum(tf.multiply(y_sq_Aug, log_p_y_Aug), axis=-1)))

    # Outlier targets: the first CS*(2K+N) rows (known queries, supports and
    # generated samples) are "seen" (1), the remaining rows are "unseen" (0).
    y_outlier = np.zeros(len(y_sq_Aug))
    y_outlier[:CS*(K*2+N)] = 1

    # Update the outlier network on the current distances.
    with tf.GradientTape() as outlier_tape:
        outlier_pred = outlier_nn(dists_Aug)
        outlier_loss = 10*scce(y_outlier,outlier_pred)
    grads = outlier_tape.gradient(outlier_loss, outlier_nn.trainable_variables)
    optim1.apply_gradients(zip(grads, outlier_nn.trainable_variables))

    # Re-evaluate with the updated outlier network; this value enters the
    # loss that is returned for the feature-extractor update.
    outlier_pred = outlier_nn(dists_Aug)
    outlier_loss = 10*scce(y_outlier,outlier_pred)

    # Episode metrics on the query set.
    correct_pred = 0
    outlier = 0
    dists = calc_euclidian_dists(qembed, z_prototypes)               # [CQ*N, CS]
    outlier_index = tf.argmax(outlier_nn(dists),axis=-1).numpy()     # 1 = seen, 0 = unseen
    pred_index = tf.argmax(tf.nn.softmax(-dists, axis=-1),axis=-1).numpy()
    for i in range(len(ep_query_labels)) :
        if ep_query_labels[i] in support_classes :
            # Known query: correct only if accepted AND assigned to its class.
            if outlier_index[i] == 1 and support_classes.index(ep_query_labels[i]) == pred_index[i] :
                correct_pred += 1
        elif outlier_index[i] == 0 :
            # Unknown query correctly flagged as an outlier. This branch used
            # to hang off the inner `if`, so it counted rejected *known*
            # queries while the denominator counts unknown queries.
            outlier = outlier + 1
    accuracy = correct_pred/(CS*N)
    outlier_det_acc = outlier/((CQ-CS)*N)

    # Open-set prediction: 0 for rejected queries, class index + 1 otherwise.
    y_pred = np.zeros((len(ep_query_labels)))
    accepted = outlier_index == 1
    y_pred[accepted] = pred_index[accepted] + 1

    cm = confusion_matrix(y_true,y_pred)
    FP = cm.sum(axis=0) - np.diag(cm)
    FN = cm.sum(axis=1) - np.diag(cm)
    TP = np.diag(cm)
    TN = cm.sum() - (FP + FN + TP)
    open_oa = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))
    loss = cec_loss + outlier_loss + loss_rp
    return loss, accuracy, outlier_det_acc, open_oa

In [None]:
# Running means of the per-episode values; they are reset at the start of
# every epoch and logged at its end.
train_loss = tf.metrics.Mean(name='train_loss')
train_acc = tf.metrics.Mean(name='train_accuracy')
train_openoa = tf.metrics.Mean(name='train_openoa')
train_outlier_acc = tf.metrics.Mean(name='train_outlier_acc')

def train_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,K,CS,CQ,N):
    """Run one episode through meta_train and update the feature extractor.

    The GANs and the outlier network are updated inside ``meta_train``; this
    function applies the returned loss to ``FE_model`` with ``optim2`` and
    feeds the episode values into the running metrics.
    """
    with tf.GradientTape() as tape:
        episode_out = meta_train(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,K,CS,CQ,N)
    loss, accuracy, outlier_det_acc, openoa = episode_out

    fe_variables = FE_model.trainable_variables
    fe_gradients = tape.gradient(loss, fe_variables)
    optim2.apply_gradients(zip(fe_gradients, fe_variables))

    # Accumulate the episode values into the epoch-level running means.
    for metric, value in ((train_loss, loss),
                          (train_acc, accuracy),
                          (train_openoa, openoa),
                          (train_outlier_acc, outlier_det_acc)):
        metric(value)
# TensorBoard writer for the meta-training run, one directory per start time.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)

template = 'Epoch {}, Train Loss: {:.2f}, Train Accuracy: {:.2f}, Train Open OA: {:.2f}, Train Outlier Det. Acc: {:.2f}'

# Best combined score seen so far. It has to live outside the epoch loop:
# when it was reset to 0.0 every epoch the comparison below always passed and
# a checkpoint was written after every single epoch.
max_all = 0.0

for epoch in range(5001):
    train_loss.reset_states()
    train_acc.reset_states()
    train_openoa.reset_states()
    train_outlier_acc.reset_states()

    # Ten 3-way 5-shot episodes (6 query classes, 15 queries each) per epoch.
    for epi in range(10):
        tquery_patches, tsupport_patches, query_labels, support_labels, support_classes = new_episode(patches_class_ip,5,15,3,6,train_class_labels)
        train_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,5,3,6,15)

    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', train_acc.result(), step=epoch)
        tf.summary.scalar('openoa',train_openoa.result(), step=epoch)
        tf.summary.scalar('outlier_det_acc',train_outlier_acc.result(), step=epoch)

    print(template.format(epoch+1,train_loss.result(),train_acc.result()*100,train_openoa.result()*100,train_outlier_acc.result()*100))

    # Save only when the sum of the three percentage metrics improves.
    combined_score = float(train_acc.result()*100+train_openoa.result()*100+train_outlier_acc.result()*100)
    if max_all < combined_score:
        max_all = combined_score
        checkpoint.save(file_prefix = checkpoint_prefix)

## D. Meta-tuning

In [None]:
# Class splits repeated here for reference (defined earlier in the file):
#train_class_indices = [1,2,4,5,7,9,10,11,13,14]    # 10 classes  
#train_class_labels = [2,3,5,6,8,10,11,12,14,15]
#test_class_indices = [0,3,6,8,12,15]               # 6 classes
#test_class_labels = [1,4,7,9,13,16]

# Per-class patch arrays of the two splits, in the order of the index lists.
train_patches_class = [patches_class_ip[i] for i in train_class_indices]        #(10)
test_patches_class = [patches_class_ip[i] for i in test_class_indices]        #(6) 
 
# Test classes that contribute samples to fine-tuning, and the full label
# pool for tuning episodes: the ten training labels plus these three.
test_support_labels = [16,4,13]
ft_labels = [2,3,4,5,6,8,10,11,12,13,14,15,16]

In [None]:
# Fine-tuning pool indexed by label-1: all patches for the training classes,
# only the first five patches for each test support class. Entries for the
# remaining labels (1, 7, 9) stay empty lists; those labels are not in ft_labels.
tune_set_5 = [[] for i in range(16)]
for j in range(1,17) :
    if j in train_class_labels :
        tune_set_5[j-1] = patches_class_ip[j-1] 
    elif j in test_support_labels :
        tune_set_5[j-1] = patches_class_ip[j-1][:5,:,:,:,:] # for each class first 5 samples taken

#### Episode for fine-tuning

In [None]:
def tune_episode(patches_list,NS,NQ,CS,CQ,class_labels) :  # NS 5,NQ 15,CS 3,CQ 6
    """Sample one open-set few-shot episode for fine-tuning.

    ``CQ`` query classes are drawn from ``class_labels``; ``CS`` of them also
    act as support (known) classes, the rest are unknown for this episode.

    Args:
        patches_list: per-class patch arrays; label ``x`` lives at index
            ``x - 1``. Every class in ``class_labels`` must hold at least
            ``max(NS, NQ)`` patches, because sampling is without replacement.
        NS: support patches drawn per support class.
        NQ: query patches drawn per query class.
        CS: number of support classes.
        CQ: number of query classes (CS <= CQ).
        class_labels: 1-based labels to sample classes from.

    Returns:
        ``(tquery_patches, tsupport_patches, query_labels, support_labels,
        support_classes)``: float32 tensors with CQ*NQ and CS*NS patches, the
        shuffled query labels (tuple), the support labels in class order
        (list) and the list of support classes.
    """
    selected_classes = list(np.random.choice(class_labels,CQ,replace=False))
    support_classes = list(np.random.choice(selected_classes,CS,replace=False))

    tquery_patches,tsupport_patches = [],[]
    query_labels,support_labels = [],[]

    # Support set: NS distinct patches per support class, kept grouped by class.
    for x in support_classes :
        sran_indices = np.random.choice(patches_list[x-1].shape[0],NS,replace=False)
        tsupport_patches.extend(patches_list[x-1][sran_indices])
        support_labels.extend([x] * NS)

    # Query set: NQ distinct patches per selected class, drawn independently
    # of the support draw.
    for x in selected_classes :
        qran_indices = np.random.choice(patches_list[x-1].shape[0],NQ,replace=False)
        tquery_patches.extend(patches_list[x-1][qran_indices])
        query_labels.extend([x] * NQ)

    # Shuffle queries and their labels together; the support order is untouched.
    temp1 = list(zip(tquery_patches, query_labels))
    random.shuffle(temp1)
    tquery_patches, query_labels = zip(*temp1)

    # Take the patch shape from the data instead of hard-coding (11,11,30,1),
    # so other window sizes / band counts work; the reshape still checks the
    # expected number of patches.
    query_array = np.asarray(tquery_patches)
    support_array = np.asarray(tsupport_patches)
    tquery_patches = tf.convert_to_tensor(np.reshape(query_array,(CQ*NQ,) + query_array.shape[1:]),dtype=tf.float32)
    tsupport_patches = tf.convert_to_tensor(np.reshape(support_array,(CS*NS,) + support_array.shape[1:]),dtype=tf.float32)
    return tquery_patches, tsupport_patches, query_labels, support_labels, support_classes

In [None]:
# Draw one sample tuning episode (1 support and 4 query patches per class,
# 3 support classes out of 6 query classes) and print its shapes and labels.
tquery_patches, tsupport_patches, query_labels, support_labels, support_classes = tune_episode(tune_set_5,1,4,3,6,ft_labels)
print(tsupport_patches.shape,tquery_patches.shape,support_labels,query_labels)

In [None]:
# Checkpoint object for meta-tuning: tracks the same optimizers and networks
# as the meta-training checkpoint but writes to its own directory.
checkpoint_dir_tune = '/content/Tune_Rp'
checkpoint_prefix_tune = os.path.join(checkpoint_dir_tune, "ckpt")
checkpoint_tune = tf.train.Checkpoint(optim1=optim1, optim2=optim2, optim_d_low=optim_d_low, optim_g_low=optim_g_low, optim_d_high=optim_d_high,
                                 optim_g_high=optim_g_high, FE_model = FE_model,
                                 generator_nn_low=generator_nn_low, generator_nn_high=generator_nn_high,
                                 dis_nn_low=dis_nn_low, dis_nn_high=dis_nn_high,
                                 outlier_nn=outlier_nn)

# Module-level settings for meta-tuning.
ngan = 5       # GAN update iterations per episode (read as a global by meta_tune)
emb_dim = 64   # embedding width, same value as in meta-training
tK = 1         # support patches per class
tN = 4         # query patches per class

In [None]:
def meta_tune(ep_class_images,ep_query_images,ep_class_labels,ep_query_labels,support_classes,tK,CS,CQ,tN):
    """Run one open-set meta-tuning episode and return its loss and metrics.

    The support and query patches are embedded with ``FE_model``; the GAN
    helpers are run ``ngan`` times on the support embeddings; ``outlier_nn`` is
    updated in place through ``optim1``; finally the loss that ``tune_step``
    differentiates with respect to ``FE_model`` is assembled.

    Args:
        ep_class_images: support patches, fed to ``FE_model``.
        ep_query_images: query patches, fed to ``FE_model``.
        ep_class_labels: class label of every support patch.
        ep_query_labels: class label of every query patch.
        support_classes: the known classes; must support ``in`` and
            ``.index``. A label's position is its one-hot column.
        tK: support shots per known class (support embeddings are reshaped to
            ``[CS, tK, dim]``).
        CS: number of known (support) classes.
        CQ: number of query classes (query embeddings are reshaped to
            ``[CQ, tN, dim]``).
        tN: query shots per class.

    Returns:
        ``(loss, accuracy, outlier_det_acc, open_oa)`` where ``loss`` is
        ``cec_loss + outlier_loss + loss_rp``, ``accuracy`` is the fraction of
        the ``CS*tN`` known queries that are accepted and assigned to the right
        prototype, ``outlier_det_acc`` is the fraction of the ``(CQ-CS)*tN``
        unknown queries that are rejected, and ``open_oa`` is derived from the
        confusion matrix over labels 0 (unknown) and 1..CS.

    NOTE(review): ``loss_rp`` is only reduced over its last axis, so ``loss``
    looks like one value per row rather than a scalar — confirm this is
    intended.
    """
    outlier = 0                                                    # unknown-class queries rejected so far
    sembed = FE_model(ep_class_images)                             # support embeddings, CS*tK rows (see reshape below)
    qembed = FE_model(ep_query_images)                             # query embeddings, CQ*tN rows (see reshape below)
    # One-hot targets over the CS known classes; an all-zero row marks a query of an unknown class.
    y_query = np.asarray(np.zeros((len(ep_query_images),CS)),dtype=np.float32)
    y_true = np.zeros(len(ep_query_labels)) # 0 for unknown; otherwise position in support_classes + 1
    for i in range(len(ep_query_labels)) :
        if ep_query_labels[i] in support_classes :
            x = support_classes.index(ep_query_labels[i])
            y_query[i][x] = 1.
            y_true[i] = x+1
    y_support = np.asarray(np.zeros((len(ep_class_images),CS)),dtype=np.float32)  # one-hot targets of the support rows
    for i in range(len(ep_class_labels)) :
        if ep_class_labels[i] in support_classes :
            x = support_classes.index(ep_class_labels[i])
            y_support[i][x] = 1.      
    
    # GAN rounds on the support embeddings. The helpers are defined elsewhere
    # and receive the generator/discriminator and their optimizers, so they
    # presumably train them in place — confirm. Only the outputs of the last
    # round are used below.
    for _ in range(ngan):    
        sembed_low,vector=gan_reptile(sembed,ep_class_labels,generator_nn_low,dis_nn_low,optim_d_low,optim_g_low,stdev=0.2,alpha1=0.003)
        sembed_high=gan_reptile_high(sembed,sembed_low,vector,ep_class_labels,generator_nn_high,dis_nn_high,optim_d_high,optim_g_high,stdev=1.0,alpha1=0.003)
    
    # Class prototypes: real and "low" generated support embeddings are added
    # (a sum, not an average) and then averaged over the tK shots.
    z_proto_low = tf.reshape(sembed_low,[CS, tK, sembed_low.shape[-1]])      # [CS, tK, dim]
    z_prototypes = tf.reshape(sembed,[CS, tK, sembed.shape[-1]])           # [CS, tK, dim]
    z_prototypes = tf.math.reduce_mean((z_prototypes + z_proto_low), axis=1) # [CS, dim]

    # Reference points: one per known class (mean of its real support shots) ...
    rp_proto_low = tf.reshape(sembed,[CS, tK, sembed.shape[-1]])
    rp_proto_low = tf.reduce_mean(rp_proto_low, axis=1)

    # ... plus a single extra point averaged over all query embeddings.
    rp_proto_high = tf.reshape(qembed, [CQ, tN, qembed.shape[-1]])
    rp_proto_high = tf.reduce_mean(rp_proto_high, axis=1)
    rp_proto_high = tf.reduce_mean(rp_proto_high, axis=0, keepdims=True)

    rpoints = tf.concat([rp_proto_low,rp_proto_high], 0)                    # CS+1 rows
    

    # Reference-point loss on Query + Support
    rec_kl_loss = 0   # never used below
    clf_loss = 0      # never used below
    # Pack the known-class query rows first, then every support row.
    # NOTE(review): sqembedK is a NumPy buffer, so the rows copied into it are
    # plain values; gradients presumably cannot flow back to FE_model through
    # them — confirm this is intended.
    sqembedK = np.zeros((CS*(tN+tK),emb_dim)) #known query then support samples
    y_sqK = np.asarray(np.zeros((CS*(tN+tK),CS)),dtype=np.float32) # QK + SK
    j = 0
    for i in range(len(y_query)):
        if np.sum(y_query[i,:])==1:     # known-class query
            y_sqK[j,:] = y_query[i,:]
            sqembedK[j,:] = qembed[i,:]
            j = j + 1
    for i in range(len(sembed)) :
        sqembedK[j,:] = sembed[i,:]
        y_sqK[j,:] = y_support[i,:]
        j = j + 1
    sqembedKU = tf.concat([sembed, qembed], 0)          # support rows followed by all query rows
    y_sqku = np.concatenate((y_support, y_query), 0)
    # The target gets 4 columns: columns 1..3 receive the one-hot block, which
    # hard-codes CS == 3. Column 0 is never set, so unknown-class rows stay
    # all-zero and contribute 0 to loss_rp.
    # NOTE(review): confirm the column order matches rp_distance's output.
    y_sqKU = np.zeros((y_sqku.shape[0],4))
    y_sqKU[:,1:4]=y_sqku
    # rp_distance is defined elsewhere; its result is multiplied element-wise
    # with y_sqKU below, so it is presumably [rows, 4] — confirm.
    dist_rp = rp_distance(tf.cast(sqembedKU,dtype=tf.float64) ,tf.cast(rpoints,tf.float64),1,4)
    dist_rp=tf.nn.log_softmax(dist_rp)
    loss_rp=tf.multiply(tf.convert_to_tensor(y_sqKU, dtype=tf.float32),tf.cast(dist_rp,tf.float32))
    loss_rp = -tf.reduce_mean(loss_rp, axis=-1)         # one value per row

    
    # Query set Augmentation((S + QK)(Original + Gen)) [CEC loss for FE]
    # Row order matters for the outlier labels below:
    # known query + support, then generated support, then unknown queries.
    y_sq_Aug = y_sqK            # (Q+S)K + (S)K(Gen) + QU
    sqembed_Aug = sqembedK              # query + support knowns
    sqembed_Aug = tf.concat((sqembed_Aug,sembed_high),axis=0)  # append the "high" generated embeddings
    y_sq_Aug = tf.concat((y_sq_Aug,y_support),axis=0)          # generated rows reuse the support targets
    for i in range(len(y_query)):
        if np.sum(y_query[i,:])==0:     # unknown-class query
            y_sq_Aug = tf.concat((y_sq_Aug,tf.expand_dims(y_query[i,:],axis=0)),axis=0)
            sqembed_Aug = tf.concat((sqembed_Aug,tf.expand_dims(qembed[i,:],axis=0)),axis=0)
    dists_Aug = calc_euclidian_dists(sqembed_Aug, z_prototypes) 
    log_p_y_Aug = tf.nn.log_softmax(-dists_Aug,axis=-1)
    # Unknown rows have all-zero targets and therefore add nothing to the sum.
    cec_loss = -tf.reduce_mean((tf.reduce_sum(tf.multiply(y_sq_Aug, log_p_y_Aug), axis=-1))) 
    
    #outlier detection [outlier network update]
    # scce is defined elsewhere (presumably sparse categorical cross-entropy — confirm).
    with tf.GradientTape() as outlier_tape:
        outlier_pred = outlier_nn(dists_Aug) #0 for unseen, 1 for seen
        y_outlier = np.zeros(len(y_sq_Aug)) #labels for outliers, 0 - unseen, 1 - seen
        for i in range(len(y_sq_Aug)) :
            # the first CS*(tK*2+tN) rows are the known and generated samples
            if i < (CS*(tK*2+tN)):
                y_outlier[i] = 1
        outlier_loss = 10*scce(y_outlier,outlier_pred)
    grads = outlier_tape.gradient(outlier_loss, outlier_nn.trainable_variables)
    optim1.apply_gradients(zip(grads, outlier_nn.trainable_variables))

    # outlier loss calculation for FE update
    # Recomputed after the step above, so it reflects the updated outlier_nn.
    outlier_pred = outlier_nn(dists_Aug) #0 for unseen, 1 for seen
    y_outlier = np.zeros(len(y_sq_Aug)) #labels for outliers, 0 - unseen, 1 - seen
    for i in range(len(y_sq_Aug)) :
        if i < (CS*(tK*2+tN)):
            y_outlier[i] = 1
    outlier_loss = 10*scce(y_outlier,outlier_pred)  

    #accuracy calculation
    correct_pred = 0
    dists = calc_euclidian_dists(qembed, z_prototypes)               # distance of every query to every prototype
    outlier_prob = outlier_nn(dists)
    outlier_index = tf.argmax(outlier_prob,axis=-1)                  # 1 = accepted as known, 0 = rejected
    predictions = tf.nn.softmax(-dists, axis=-1)
    pred = predictions    # unused alias
    pred2 = predictions   # unused alias
    pred_index = tf.argmax(predictions,axis=-1)
    for i in range(len(ep_query_labels)) :
        if ep_query_labels[i] in support_classes :
            # a known query only counts when it is accepted AND lands on its own prototype
            if outlier_index[i] == 1 :
                x = support_classes.index(ep_query_labels[i])
                if x == pred_index[i] :
                    correct_pred += 1  
        else :
            if outlier_index[i] == 0 :
                outlier = outlier + 1  
    accuracy = correct_pred/(CS*tN)     # scalar      
    outlier_det_acc = outlier/((CQ-CS)*tN)   # raises ZeroDivisionError when CQ == CS


    #open oa
    y_pred = np.zeros((len(ep_query_labels))) 
    for i in range(len(ep_query_labels)) :
        if outlier_index[i] == 1 :
            y_pred[i] = pred_index[i]+1
        else :
            y_pred[i] = 0
    # Per-class TP/FP/FN/TN taken from the confusion matrix and summed over classes.
    cm = confusion_matrix(y_true,y_pred)
    FP = cm.sum(axis=0) - np.diag(cm)  
    FN = cm.sum(axis=1) - np.diag(cm)
    TP = np.diag(cm)
    TN = cm.sum() - (FP + FN + TP)
    open_oa = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))
    
   
    loss = cec_loss + outlier_loss + loss_rp
    return loss, accuracy, outlier_det_acc, open_oa

In [None]:
# Metrics to gather
# Running means: fed one value per episode by tune_step and reset at the start
# of every epoch by the tuning loop.
tune_loss = tf.metrics.Mean(name='tune_loss')
tune_acc = tf.metrics.Mean(name='tune_accuracy')
tune_open_acc = tf.metrics.Mean(name='tune_open_accuracy')
tune_outlier_acc = tf.metrics.Mean(name='tune_outlier_accuracy')

def tune_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,tK,CS,CQ,tN):
    """Run one tuning episode, update ``FE_model`` and record the episode metrics.

    The episode is evaluated by ``meta_tune`` under a gradient tape; the
    resulting loss is differentiated with respect to the feature extractor's
    trainable variables and applied through ``optim2``. The loss and the three
    accuracies are then pushed into the module-level running means.
    Returns ``None``.
    """
    # Forward pass, recorded so the loss can be differentiated afterwards.
    with tf.GradientTape() as fe_tape:
        episode_stats = meta_tune(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,tK,CS,CQ,tN)
    ep_loss, ep_accuracy, ep_outlier_acc, ep_open_oa = episode_stats

    # Backward pass: only the feature extractor is updated here.
    fe_weights = FE_model.trainable_variables
    fe_grads = fe_tape.gradient(ep_loss, fe_weights)
    optim2.apply_gradients(zip(fe_grads, fe_weights))

    # Feed the running means, in the same order as they are declared.
    for tracker, value in (
        (tune_loss, ep_loss),
        (tune_acc, ep_accuracy),
        (tune_open_acc, ep_open_oa),
        (tune_outlier_acc, ep_outlier_acc),
    ):
        tracker(value)
# Summary writer for the tuning run; the timestamp (taken when this cell runs)
# gives every run its own log directory under logs/gradient_tape/.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tune_log_dir = 'logs/gradient_tape/' + current_time + '/tune'
tune_summary_writer = tf.summary.create_file_writer(tune_log_dir)
        
# Best combined score (accuracy + open accuracy + outlier accuracy, in percent)
# seen so far. It is kept outside the epoch loop so that a checkpoint is only
# written when an epoch actually beats every earlier one.
max_all = 0.0

for epoch in range(1001):
    # Start every epoch with fresh running means.
    tune_loss.reset_states()
    tune_acc.reset_states()
    tune_open_acc.reset_states()
    tune_outlier_acc.reset_states()

    # Ten freshly sampled episodes per epoch; each one updates the networks
    # and feeds the running means through tune_step.
    for epi in range(10):
        tquery_patches, tsupport_patches, query_labels, support_labels, support_classes = tune_episode(tune_set_5,1,4,3,6,ft_labels)
        tune_step(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,tK,3,6,tN)

    # Log the epoch averages to the summary writer.
    with tune_summary_writer.as_default():
        tf.summary.scalar('loss', tune_loss.result(), step=epoch)
        tf.summary.scalar('accuracy', tune_acc.result(), step=epoch)
        tf.summary.scalar('Open_accuracy', tune_open_acc.result(), step=epoch)
        tf.summary.scalar('Outlier_accuracy', tune_outlier_acc.result(), step=epoch)

    template = 'Epoch {}, Tune Loss: {:.2f}, Tune Accuracy: {:.2f}, Open Accuracy: {:.2f},Outlier Accuracy: {:.2f}'
    print(template.format(epoch+1,tune_loss.result(),tune_acc.result()*100,tune_open_acc.result()*100,tune_outlier_acc.result()*100))

    # Save a checkpoint only when the combined score improves on the best so far.
    combined_score = tune_acc.result()*100+tune_open_acc.result()*100+tune_outlier_acc.result()*100
    if max_all < combined_score:
        max_all = combined_score
        checkpoint_tune.save(file_prefix = checkpoint_prefix_tune)

## E. Meta-testing

In [None]:
def meta_test(ep_class_images,ep_query_images,ep_class_labels,ep_query_labels,support_classes,K,CS,CQ,N):
    """Evaluate one open-set few-shot episode; no weights are updated.

    Support and query patches are embedded with ``FE_model``. Each known class
    gets a prototype (mean of its ``K`` support embeddings). Every query is
    assigned to its nearest prototype, and ``outlier_nn`` decides from the
    query-to-prototype distances whether the query is accepted as known.

    Args:
        ep_class_images: support patches. Their embeddings are reshaped to
            ``[CS, K, dim]``, so they must be grouped class by class in the
            order of ``support_classes``.
        ep_query_images: query patches.
        ep_class_labels: support labels. Not needed for evaluation; kept so
            the signature stays compatible with existing callers.
        ep_query_labels: class label of every query patch.
        support_classes: the known classes; must support ``in`` and ``.index``.
        K: support shots per known class.
        CS: number of known (support) classes.
        CQ: number of query classes.
        N: query shots per class.

    Returns:
        ``(accuracy, open_oa, outlier_det_acc, auc)``:
        ``accuracy`` is the fraction of the ``CS*N`` known queries that are
        accepted and assigned to the right prototype; ``open_oa`` is derived
        from the confusion matrix over labels 0 (unknown) and 1..CS;
        ``outlier_det_acc`` is the fraction of the ``(CQ-CS)*N`` unknown
        queries that are rejected; ``auc`` is the ROC AUC of the "known"
        score (column 1 of the ``outlier_nn`` output).

    Raises:
        ZeroDivisionError: if ``CQ == CS`` (no unknown classes) or ``CS*N == 0``.
    """
    sembed = FE_model(ep_class_images)                             # CS*K support embeddings
    qembed = FE_model(ep_query_images)                             # one embedding per query

    # Ground truth per query: 0 = unknown class, otherwise position in
    # support_classes + 1.
    y_true = np.zeros(len(ep_query_labels))
    for i, label in enumerate(ep_query_labels):
        if label in support_classes:
            y_true[i] = support_classes.index(label) + 1
    is_known = y_true > 0
    y_auc = is_known.astype(np.float64)     # 1 for seen, 0 for unseen

    # One prototype per known class: mean over its K support shots.
    z_prototypes = tf.reshape(sembed,[CS, K, sembed.shape[-1]])     # [CS, K, dim]
    z_prototypes = tf.math.reduce_mean(z_prototypes, axis=1)        # [CS, dim]

    # Nearest-prototype prediction and accept/reject decision for every query.
    dists = calc_euclidian_dists(qembed, z_prototypes)
    outlier_prob = outlier_nn(dists)
    outlier_index = tf.argmax(outlier_prob,axis=-1).numpy()         # 1 = accepted as known
    predictions = tf.nn.softmax(-dists, axis=-1)
    pred_index = tf.argmax(predictions,axis=-1).numpy()

    # Open-set prediction: 0 when rejected, prototype index + 1 when accepted.
    y_pred = np.where(outlier_index == 1, pred_index + 1, 0).astype(np.float64)

    # A known query is correct only if it is accepted AND lands on its own
    # prototype, which is exactly y_pred == y_true for rows with y_true > 0.
    correct_pred = int(np.count_nonzero(is_known & (y_pred == y_true)))
    outlier = int(np.count_nonzero(~is_known & (outlier_index == 0)))
    accuracy = correct_pred/(CS*N)
    outlier_det_acc = outlier/((CQ-CS)*N)

    # AUC of the known-vs-unknown decision, scored by the "known" probability.
    y_score = np.asarray(outlier_prob)[:, 1].astype(np.float64)
    auc = sklearn.metrics.roc_auc_score(y_auc, y_score)

    #open oa: per-class TP/FP/FN/TN from the confusion matrix, summed over classes
    cm = confusion_matrix(y_true,y_pred)
    FP = cm.sum(axis=0) - np.diag(cm)
    FN = cm.sum(axis=1) - np.diag(cm)
    TP = np.diag(cm)
    TN = cm.sum() - (FP + FN + TP)
    open_oa = (sum(TP)+sum(TN))/(sum(TP)+sum(TN)+sum(FP)+sum(FN))
    return accuracy, open_oa, outlier_det_acc, auc

In [None]:
def test_episode(patches_list,NS,NQ,CS,CQ) :
    """Sample one meta-test episode (support set + shuffled query set).

    The query classes come from the global ``test_class_labels`` and the
    support classes from the global ``test_support_labels``; class ``x`` is
    read from ``patches_list[x-1]``.

    Args:
        patches_list: per-class patch arrays, indexed by ``label - 1``.
        NS: patches drawn (without replacement) per support class.
        NQ: patches drawn (without replacement) per query class.
        CS: number of support classes; only used for the final reshape.
        CQ: number of query classes; only used for the final reshape.

    Returns:
        ``(tquery_patches, tsupport_patches, query_labels, support_labels,
        support_classes)`` — two float32 tensors of shape
        ``(CQ*NQ, 11, 11, 30, 1)`` and ``(CS*NS, 11, 11, 30, 1)``, the shuffled
        query labels (tuple), the support labels (list) and the support
        classes.
    """
    selected_classes = test_class_labels   # query classes
    support_classes = test_support_labels  # support classes

    def _draw(class_ids, per_class):
        # Pick `per_class` random patches of every class, class by class.
        picked_patches, picked_labels = [], []
        for cls in class_ids:
            pool = patches_list[cls-1]
            chosen = np.random.choice(pool.shape[0],per_class,replace=False)
            picked_patches.extend(pool[chosen,:,:,:,:])
            picked_labels.extend([cls] * per_class)
        return picked_patches, picked_labels

    # Support is sampled before query, keeping the random draws in that order.
    tsupport_patches, support_labels = _draw(support_classes, NS)
    tquery_patches, query_labels = _draw(selected_classes, NQ)

    # Shuffle the queries while keeping every patch paired with its label.
    # The support set stays grouped class by class.
    paired = list(zip(tquery_patches, query_labels))
    random.shuffle(paired)
    tquery_patches, query_labels = zip(*paired)

    query_array = np.reshape(np.asarray(tquery_patches),(CQ*NQ,11,11,30,1))
    support_array = np.reshape(np.asarray(tsupport_patches),(CS*NS,11,11,30,1))
    tquery_patches = tf.convert_to_tensor(query_array,dtype=tf.float32)
    tsupport_patches = tf.convert_to_tensor(support_array,dtype=tf.float32)
    return tquery_patches, tsupport_patches, query_labels, support_labels, support_classes

In [None]:
# Evaluation settings.
tepochs = 100   # number of test episodes to average over
emb_dim = 64
K = 5
N = 15

# Running sums of the per-episode metrics.
total_acc = 0
total_open_oa = 0
total_outlier_acc = 0
total_auc = 0
# Declared but not updated in this cell.
rec_k = 0
rec_u = 0
f1score = 0

for i in range(tepochs) :
    # Draw a fresh episode, evaluate it and accumulate its four metrics.
    tquery_patches, tsupport_patches, query_labels, support_labels, support_classes = test_episode(patches_class_ip,5,15,3,6)
    accuracy, open_oa, outlier_det_acc, auc = meta_test(tsupport_patches,tquery_patches,support_labels,query_labels,support_classes,5,3,6,15)
    total_acc += accuracy
    total_open_oa += open_oa
    total_outlier_acc += outlier_det_acc
    total_auc += auc

# Report the episode averages; the three accuracies are shown in percent.
for caption, value in (
    ('accuracy', total_acc*100/tepochs),
    ('Outlier detection accuracy', (total_outlier_acc*100/tepochs)),
    ('open oa', total_open_oa*100/tepochs),
    ('auc', total_auc/tepochs),
):
    print(caption, value)