In [1]:
import os
import cv2
import numpy as np
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Lambda, Dense, Dropout, Conv2D, MaxPooling2D, Flatten,Activation,BatchNormalization
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.losses import CategoricalCrossentropy

from tensorflow.keras.applications.efficientnet import preprocess_input,EfficientNetB0

from tensorflow.keras import callbacks as cb
from tensorflow.keras import backend as K
from tensorflow.keras import utils
from sklearn.model_selection import train_test_split

# Enable on-demand memory growth on every visible GPU so TensorFlow does not
# grab all device memory when it initialises.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [2]:
# Split class directories (not individual images) into disjoint source
# (meta-train) and target (meta-test) class sets.
SPLIT_SEED = 42  # fixed so the class split is reproducible across runs
# glob order depends on the filesystem; sort so the seeded split is stable.
all_classes = sorted(glob('/home/jovyan/data/fungi/images/*'))
source_classes, target_classes = train_test_split(all_classes, test_size=0.1, random_state=SPLIT_SEED)
sorce_len = len(source_classes)
target_len = len(target_classes)
print(f"total {len(all_classes)} classes=source {sorce_len} + target {target_len} classes")

total 1394 classes=source 1254 + target 140 classes


In [3]:
# Input image geometry (width, height, channels) fed to the encoder.
# A 64x64x3 variant was used for quick experiments.
W = H = 224
CH = 3
def load_img(path, width = W):
    """Load a JPEG, center-crop it to a square and resize it to ``width x width``.

    Used through ``tf.py_function`` in the dataset cell, so it runs eagerly and
    ``path`` may be a scalar string tensor as well as a Python string.

    :param path: path to a JPEG file
    :param width: side length, in pixels, of the returned square image
    :return: tf.Tensor of shape (width, width, CH) with values in [0, 1]
             (cv2.resize drops a singleton channel axis, so this assumes CH == 3)
    """
    raw = tf.io.read_file(path)
    # decode_jpeg yields uint8 pixels; dividing by 255 rescales them to [0, 1].
    img = (tf.image.decode_jpeg(raw, channels=CH) / 255).numpy()

    # Center crop to the largest square that fits inside the image.
    side = min(img.shape[:2])
    top = (img.shape[0] - side) // 2
    left = (img.shape[1] - side) // 2
    img = img[top:top + side, left:left + side]

    # cv2.resize takes dsize as (width, height); both are `width` here.
    img = cv2.resize(img, (width, width))
    return tf.constant(img)

In [4]:
# Few-shot episode layout: WAYS classes per task, SHOTS support images per class,
# QUERIES query images per task, and BATCH_SIZE tasks per training batch.
WAYS, SHOTS, QUERIES, BATCH_SIZE = 3, 5, 1, 4

In [5]:
min(len(glob(f'{sdir}/*.JPG')) for sdir in source_classes)

6

In [6]:
## exclude classes with too few examples
# Keep only classes with strictly more than SHOTS + QUERIES *.JPG files.
source_classes = list(filter(
    lambda class_dir: len(glob(class_dir + '/*.JPG')) > SHOTS + QUERIES,
    source_classes,
))

sorce_len = len(source_classes)
print(f"source {sorce_len} + target {target_len} classes")

source 1242 + target 140 classes


In [7]:
min(len(glob(f'{sdir}/*.JPG')) for sdir in source_classes)

7

In [8]:
def map_fun(string):
    # Run the eager cv2/numpy loader inside the tf.data graph; yields float32 images.
    return tf.py_function(func=load_img, inp=[string], Tout=tf.float32)

# One dataset per source class: its *.JPG files in shuffled order, decoded by map_fun.
all_sub = [
    tf.data.Dataset.list_files(f'{class_dir}/*.JPG', shuffle=True).map(map_fun)
    for class_dir in source_classes
]

def gen():
    """Yield few-shot training episodes built from ``all_sub`` (one dataset per class).

    Each call visits the classes in a new random order and groups them into
    disjoint tasks of ``WAYS`` classes (leftover classes are skipped). Per task
    it yields ``(images, labels)``:

    * ``images``: the ``WAYS*SHOTS`` support images, grouped by way (way 0's
      shots, then way 1's, ...), followed by the ``QUERIES`` query images;
    * ``labels``: tuple of ``QUERIES`` one-hot vectors of length ``WAYS`` giving
      the way each query image was drawn from.

    Supports and the query for a class come from a single draw of that class's
    files, so a query image is never also one of its own class's supports.
    """
    order = np.random.permutation(len(all_sub))
    for task in range(len(all_sub) // WAYS):
        picked = [all_sub[i] for i in order[WAYS * task:WAYS * (task + 1)]]
        # Ways the queries come from; replace=False => at most one query per way.
        idxs = np.random.choice(range(WAYS), size=QUERIES, replace=False)
        n_query = np.bincount(idxs, minlength=WAYS)

        supports, held_out = [], []
        for way, sub in enumerate(picked):
            needed = SHOTS + int(n_query[way])
            # Draw supports and queries in ONE batch: every new iterator
            # reshuffles the file list, so separate draws could return a
            # support image again as the query.
            batch = next(iter(sub.batch(needed)))
            if batch.shape[0] < needed:
                raise ValueError(f"class has fewer than {needed} images for this episode")
            supports.append(batch[:SHOTS])
            held_out.append(batch[SHOTS:])

        support = tf.concat(supports, axis=0)
        query = tf.concat([held_out[idx] for idx in idxs], axis=0)
        labels = tuple(keras.utils.to_categorical(idx, num_classes=WAYS) for idx in idxs)
        yield tf.concat([support, query], axis=0), labels
# Batches of BATCH_SIZE episodes. Each episode is
#   images: (WAYS*SHOTS + QUERIES, W, H, CH) float32
#   labels: tuple of QUERIES float32 one-hot vectors of shape (WAYS,)
dd = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.float32, (tf.float32,) * QUERIES),
    output_shapes=((WAYS * SHOTS + QUERIES, W, H, CH), ((WAYS,),) * QUERIES),
).batch(BATCH_SIZE)

In [9]:
def conv_net(input_shape):
    """Small CNN encoder: four [Conv 64@3x3 valid -> BatchNorm -> ReLU -> MaxPool]
    stages followed by Flatten, returned as a Sequential model.

    :param input_shape: shape of one input image, passed to every Conv2D layer
    """
    stack = []
    for _ in range(4):
        stack.extend([
            Conv2D(64, (3, 3), padding='valid', input_shape=input_shape),
            BatchNormalization(),
            Activation('relu'),
            MaxPooling2D(),
        ])
    stack.append(Flatten())
    return Sequential(stack)
def pretrain_net(input_shape):
    """Build the shared image encoder: a randomly initialised EfficientNetB0
    backbone without its classification head, followed by Flatten.

    :param input_shape: (height, width, channels) of a single image
    :return: keras Model mapping a batch of images with pixel values in [0, 1]
             (as produced by load_img) to flat feature vectors
    """
    # Pass the shape to the backbone too so its output (and the Flatten size) is static.
    base_model = EfficientNetB0(weights=None, include_top=False, input_shape=input_shape)
    x_in = Input(shape=input_shape)
    # Keras' EfficientNet rescales internally and expects pixels in [0, 255]
    # (its preprocess_input is a pass-through), while load_img yields [0, 1].
    x = preprocess_input(x_in * 255.0)
    out = Flatten()(base_model(x))
    return Model(inputs=x_in, outputs=out)
def euclidean_distance(f_1,f_2):
    """
    Euclidean (L2) distance between feature vectors, reduced over the last axis.
    https://en.wikipedia.org/wiki/Euclidean_distance
    :param f_1: feature tensor
    :param f_2: feature tensor broadcast-compatible with f_1
    :return: tensor of distances (inputs' last axis reduced away); never below
             sqrt(K.epsilon())
    """
    squared = K.sum(K.square(f_1 - f_2), axis=-1)
    # The derivative of sqrt is infinite at 0; clamping keeps gradients finite
    # (instead of NaN) when the two vectors coincide.
    return K.sqrt(K.maximum(squared, K.epsilon()))

In [10]:
base_dim = (W,H,CH)
base_network = pretrain_net(base_dim)

# Episode input: WAYS*SHOTS support images (grouped by way) followed by QUERIES query images.
x_in = Input(shape=(WAYS*SHOTS+QUERIES,W,H,CH))
# The shared encoder is applied to each image slot separately.
latent_s = [base_network(x_in[:,ii]) for ii in range(WAYS*SHOTS)]           # support features
latent_q = [base_network(x_in[:,WAYS*SHOTS+ii]) for ii in range(QUERIES)]   # query features

# Prototype of each way = mean of its SHOTS support embeddings. It depends only
# on the support set, so compute it once instead of once per query.
prototypes = [
    tf.reduce_mean(tf.stack(latent_s[ww*SHOTS:(ww+1)*SHOTS], axis=-1), axis=-1)
    for ww in range(WAYS)
]

# Per query: softmax over negative distances to the prototypes, i.e. one
# probability per way, in the same way order as the support images.
y = []
for qq in range(QUERIES):
    dist_scores = [euclidean_distance(latent_q[qq], proto) for proto in prototypes]
    y.append(tf.nn.softmax(-tf.stack(dist_scores, axis=-1), axis=-1))

model = Model(inputs=x_in, outputs=tuple(y))

In [13]:
lr = 0.001        # base learning rate of the step schedule (read, never mutated)
LR_FLOOR = 1e-8   # lower bound of the schedule (same as ReduceLROnPlateau's min_lr)

def scheduler(epoch):
    """Step-decay learning rate for ``cb.LearningRateScheduler``.

    Halves the base rate ``lr`` at epoch 0 and again every 3 epochs
    (epochs 0-2 -> lr/2, 3-5 -> lr/4, ...), never going below ``LR_FLOOR``.
    The result depends only on ``epoch``, so re-running training (or this cell)
    restarts the same schedule instead of continuing to shrink a global.

    :param epoch: 0-based epoch index
    :return: learning rate (float) for that epoch
    """
    return max(lr * 0.5 ** (epoch // 3 + 1), LR_FLOOR)
# NOTE(review): when both callbacks are passed to fit (as in the training cell),
# LearningRateScheduler sets the learning rate from `scheduler` at the start of
# every epoch. That overwrites the reductions ReduceLROnPlateau makes and also
# replaces Adam's initial learning_rate below before the first step. Consider
# keeping only one of them.
reduce_lr = cb.ReduceLROnPlateau(monitor='loss', factor=0.4, patience=2, min_lr=1e-8)
lr_sched = cb.LearningRateScheduler(scheduler)
tensorboard = cb.TensorBoard()
# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
opt = tf.keras.optimizers.Adam(learning_rate=0.002)
model.compile(loss=CategoricalCrossentropy(), optimizer=opt, metrics=['categorical_accuracy'])

In [14]:
# %%time
model.fit(
    dd,
    epochs=1000,
    verbose=1,
    workers=4,  # NOTE(review): Keras documents `workers` for generator/Sequence inputs; dd is a tf.data.Dataset
    callbacks=[reduce_lr, lr_sched, tensorboard],
)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
E

<tensorflow.python.keras.callbacks.History at 0x7feaf56177d0>