In [0]:
from google.colab import drive
drive.mount('/content/drive')

In [0]:
import tensorflow.compat.v2 as tf
import cv2
import os
import numpy as np
from tensorflow.keras.layers import Conv2D,BatchNormalization, MaxPool2D, Activation, Flatten, Dense
from tensorflow.keras import regularizers
import random

In [0]:
!unzip /content/drive/My\ Drive/mini-imagenet.zip

In [0]:
images_path = '/content/images'

In [0]:
classes=[]   # list of 100 class labels
class_images = [[] for x in range(100)]  # list of lists(600 images for each class label)  
for file in os.listdir(images_path) :
   if file.endswith('.jpg') :
      if (file[:9] not in classes) :
        classes.append(file[:9])
for file in os.listdir(images_path) :
  if file.endswith('.jpg') :
    i = classes.index(file[:9])
    class_images[i].append(file)        
  

In [0]:
print(len(classes))
print(classes[0])
print(class_images[0])

In [0]:
train_classes = classes[:64]
validation_classes = classes[64:81]
test_classes = classes[81:101]

In [0]:
K=1
C=5
N=15

In [0]:
def new_episode(train_classes1,class_images,K,C,N) :
  ep_class_images=[]  #total CK images with every K images having a class C
  ep_query_images=[]  #total CN images
  ep_class_labels = []   #C labels
  ep_query_labels = []   #C*N labels for each query image
  selected_classes = np.random.choice(train_classes1, size=C)
  for x in selected_classes :
    ep_images=[]
    ep_classes=[]
    i = classes.index(x)
    train_y = np.random.choice(class_images[i],K)
    query_y = np.random.choice(class_images[i],N)
    for img in train_y : 
      ep_images.append(cv2.resize(cv2.imread(os.path.join(images_path, img)),(84,84)))
      ep_classes.append(x)
    ep_class_images.extend(ep_images)
    ep_class_labels.extend(ep_classes)
    ep_images=[]
    ep_classes=[]
    for img in query_y :
      ep_images.append(cv2.resize(cv2.imread(os.path.join(images_path,img)),(84,84)))
      ep_classes.append(x)
    ep_query_images.extend(ep_images)
    ep_query_labels.extend(ep_classes) 
    
  temp1 = list(zip(ep_class_images, ep_class_labels)) 
  random.shuffle(temp1) 
  ep_class_images, ep_class_labels = zip(*temp1)
  temp2 = list(zip(ep_query_images, ep_query_labels)) 
  random.shuffle(temp2) 
  ep_query_images, ep_query_labels = zip(*temp2)  
  ep_class_images = np.asarray(ep_class_images, dtype=np.double)/255
  ep_class_images = tf.convert_to_tensor(np.reshape(ep_class_images,(K*C,84,84,3)),dtype=tf.float32)
  ep_query_images = np.asarray(ep_query_images, dtype=np.double)/255
  ep_query_images = tf.convert_to_tensor(np.reshape(ep_query_images,(N*C,84,84,3)),dtype=tf.float32)
  return ep_class_images, ep_class_labels, ep_query_images, ep_query_labels
  


In [0]:
ep_class_images, ep_class_labels, ep_query_images, ep_query_labels = new_episode(train_classes,class_images,1,5,15)

In [0]:
print(len(ep_class_images))

5


In [0]:
class Embedding_module(tf.keras.Model) :
  def __init__(self) :
    super(Embedding_module,self).__init__()
    self.convi = Conv2D(64, (3,3),strides=(1,1),input_shape=(84,84,3))
    self.conv = Conv2D(64, (3,3), strides=(1,1))
    self.bn = BatchNormalization()
    self.max_pool = MaxPool2D()
    self.activation = Activation(activation = 'relu')
  def call(self, x,training=True) :
    out1 = self.convi(x)
    out1 = self.bn(out1)
    out1 = self.activation(out1)
    out2 = self.max_pool(out1)
    out2 = self.conv(out2)
    out2 = self.bn(out2)
    out2 = self.activation(out2)
    out3 = self.max_pool(out2)
    out3 = self.conv(out3)
    out3 = self.bn(out3)
    out4 = self.activation(out3)
    out4 = self.conv(out4)
    out4 = self.bn(out4)
    out4 = self.activation(out4)
    return out4




In [0]:
embedding_out = Embedding_module()

In [0]:
'''out_class_fm = embedding_out(ep_class_images)
out_query_fm = embedding_out(ep_query_images)'''

'out_class_fm = embedding_out(ep_class_images)\nout_query_fm = embedding_out(ep_query_images)'

In [0]:
def concatenation(out_class_fm,out_query_fm,K,C,N) : #concatenation of the embeddings
  emb_classes=[]
  concat_out=[]
  for i in range(C) :
    emb_classes.append(tf.tile(tf.expand_dims(out_class_fm[i,:,:,:],0),multiples=[C*N,1,1,1]))
  for i in range(C) :
    concat_out.append(tf.concat([emb_classes[i],out_query_fm],axis=-1))
  return emb_classes, concat_out

In [0]:
class Relation_score(tf.keras.Model) :
  def __init__(self) :
    super(Relation_score, self).__init__()
    self.convi = Conv2D(64, (3,3),strides=(1,1),input_shape=(15,15,128))
    self.conv = Conv2D(64, (3,3), strides=(1,1),padding = 'same')
    self.bn = BatchNormalization()
    self.max_pool = MaxPool2D()
    self.flatten = Flatten()
    self.dense1 = Dense(8,activation='relu')
    self.dense2 = Dense(1,activation='sigmoid')
    self.activation = Activation('relu')  
  def call(self, x, training=True) :
    out1 = self.convi(x)
    out1 = self.bn(out1)
    out1 = self.activation(out1)
    out2 = self.max_pool(out1)
    out2 = self.conv(out2)
    out2 = self.bn(out2)
    out2 = self.activation(out2)
    out3 = self.max_pool(out2)
    out3 = self.flatten(out3)
    out3 = self.dense1(out3)
    out4 = self.dense2(out3)
    return out4



In [0]:
rs = Relation_score()

In [0]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

In [0]:
def train_step(ep_class_images,ep_query_images,ep_class_labels,ep_query_labels,K,C,N) :
  loss=0
  with tf.GradientTape() as tape :
    out_class_fm = embedding_out(ep_class_images)
    out_query_fm = embedding_out(ep_query_images)
    emb_classes, econcat_out = concatenation(out_class_fm,out_query_fm,K,C,N)
    relation_scores = tf.concat(values=[rs(econcat_out[0]),rs(econcat_out[1]), rs(econcat_out[2]), rs(econcat_out[3]), rs(econcat_out[4])],axis=1)
    real_one_hot = np.zeros([75,5],dtype=np.float32)
    for i in range(C*N) :  
      x = ep_class_labels.index(ep_query_labels[i])
      real_one_hot[i][x]=1.
    real_one_hot = tf.convert_to_tensor(real_one_hot,dtype=tf.float32)
    loss = tf.reduce_sum(tf.square(tf.subtract(real_one_hot,relation_scores)))
  variables = rs.trainable_variables + embedding_out.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

In [0]:
checkpoint_dir = '/content/drive/My Drive/training_checkpoints_ones_updated_random'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 embedding_out=embedding_out,
                                 rs=rs)

In [0]:
epochs = 80000
for i in range(epochs) :
  ep_class_images, ep_class_labels, ep_query_images, ep_query_labels = new_episode(train_classes,class_images,1,5,15)
  loss = train_step(ep_class_images,ep_query_images,ep_class_labels,ep_query_labels,1,5,15)
  if i %1000 == 0 :
    print('epoch',i,loss)
  if i % 5000 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

In [0]:
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fa10a7e3b38>

In [0]:
def test_step(ep_class_images, ep_query_images,ep_class_labels,ep_query_labels, K, C, N) :
    correct_labels = []
    correct_pred = []
    out_class_fm = embedding_out(ep_class_images)
    out_query_fm = embedding_out(ep_query_images)
    emb_classes, econcat_out = concatenation(out_class_fm,out_query_fm,K,C,N)
    relation_scores = tf.concat(values=[rs(econcat_out[0]),rs(econcat_out[1]), rs(econcat_out[2]), rs(econcat_out[3]), rs(econcat_out[4])],axis=1)
    predicted_labels = tf.math.argmax(relation_scores, axis =1)
    real_one_hot = np.zeros([C*N,1],dtype=np.int64)
    for i in range(C*N) :  
      x = ep_class_labels.index(ep_query_labels[i])
      real_one_hot[i] = x
    real_one_hot = tf.convert_to_tensor(real_one_hot,dtype=tf.int64)
    
    for i, j in zip(predicted_labels, real_one_hot) :
       if (i == j) :
         correct_labels.append(i)
    x = len(correct_labels)
    return x

In [0]:
epochs = 600
correct = 0
for i in range(epochs) :
  ep_class_images, ep_class_labels, ep_query_images, ep_query_labels = new_episode(test_classes,class_images,1,5,15)
  x = test_step(ep_class_images,ep_query_images,ep_class_labels,ep_query_labels,1,5,15)
  correct += x
  print('epoch',i)
print('accuracy',correct/(6*75))

epoch 0
epoch 1
epoch 2
epoch 3
epoch 4
epoch 5
epoch 6
epoch 7
epoch 8
epoch 9
epoch 10
epoch 11
epoch 12
epoch 13
epoch 14
epoch 15
epoch 16
epoch 17
epoch 18
epoch 19
epoch 20
epoch 21
epoch 22
epoch 23
epoch 24
epoch 25
epoch 26
epoch 27
epoch 28
epoch 29
epoch 30
epoch 31
epoch 32
epoch 33
epoch 34
epoch 35
epoch 36
epoch 37
epoch 38
epoch 39
epoch 40
epoch 41
epoch 42
epoch 43
epoch 44
epoch 45
epoch 46
epoch 47
epoch 48
epoch 49
epoch 50
epoch 51
epoch 52
epoch 53
epoch 54
epoch 55
epoch 56
epoch 57
epoch 58
epoch 59
epoch 60
epoch 61
epoch 62
epoch 63
epoch 64
epoch 65
epoch 66
epoch 67
epoch 68
epoch 69
epoch 70
epoch 71
epoch 72
epoch 73
epoch 74
epoch 75
epoch 76
epoch 77
epoch 78
epoch 79
epoch 80
epoch 81
epoch 82
epoch 83
epoch 84
epoch 85
epoch 86
epoch 87
epoch 88
epoch 89
epoch 90
epoch 91
epoch 92
epoch 93
epoch 94
epoch 95
epoch 96
epoch 97
epoch 98
epoch 99
epoch 100
epoch 101
epoch 102
epoch 103
epoch 104
epoch 105
epoch 106
epoch 107
epoch 108
epoch 109
epoch 110
