In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm.notebook import tqdm

from sklearn.utils.class_weight import compute_class_weight

In [2]:
# Name of the saved base network; it is also the directory name under model/.
base_model_name = 'inception' 
# Load the previously saved Keras model from model/<base_model_name>.
base_model = tf.keras.models.load_model(f'model/{base_model_name}')

In [3]:
# Print the layer-by-layer summary of the loaded base model (layer names are
# needed below to pick the feature-extraction layers).
base_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 299, 299, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 149, 149, 32  864         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 149, 149, 32  96         ['conv2d[0][0]']                 
 alization)                     )                                                             

 batch_normalization_5 (BatchNo  (None, 35, 35, 64)  192         ['conv2d_5[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 batch_normalization_7 (BatchNo  (None, 35, 35, 64)  192         ['conv2d_7[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 batch_normalization_10 (BatchN  (None, 35, 35, 96)  288         ['conv2d_10[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_11 (BatchN  (None, 35, 35, 32)  96          ['conv2d_11[0][0]']              
 ormalizat

                                                                                                  
 batch_normalization_22 (BatchN  (None, 35, 35, 64)  192         ['conv2d_22[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_22 (Activation)     (None, 35, 35, 64)   0           ['batch_normalization_22[0][0]'] 
                                                                                                  
 conv2d_20 (Conv2D)             (None, 35, 35, 48)   13824       ['mixed1[0][0]']                 
                                                                                                  
 conv2d_23 (Conv2D)             (None, 35, 35, 96)   55296       ['activation_22[0][0]']          
                                                                                                  
 batch_nor

                                                                                                  
 max_pooling2d_2 (MaxPooling2D)  (None, 17, 17, 288)  0          ['mixed2[0][0]']                 
                                                                                                  
 mixed3 (Concatenate)           (None, 17, 17, 768)  0           ['activation_26[0][0]',          
                                                                  'activation_29[0][0]',          
                                                                  'max_pooling2d_2[0][0]']        
                                                                                                  
 conv2d_34 (Conv2D)             (None, 17, 17, 128)  98304       ['mixed3[0][0]']                 
                                                                                                  
 batch_normalization_34 (BatchN  (None, 17, 17, 128)  384        ['conv2d_34[0][0]']              
 ormalizat

                                                                  'activation_39[0][0]']          
                                                                                                  
 conv2d_44 (Conv2D)             (None, 17, 17, 160)  122880      ['mixed4[0][0]']                 
                                                                                                  
 batch_normalization_44 (BatchN  (None, 17, 17, 160)  480        ['conv2d_44[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_44 (Activation)     (None, 17, 17, 160)  0           ['batch_normalization_44[0][0]'] 
                                                                                                  
 conv2d_45 (Conv2D)             (None, 17, 17, 160)  179200      ['activation_44[0][0]']          
          

 ormalization)                                                                                    
                                                                                                  
 activation_54 (Activation)     (None, 17, 17, 160)  0           ['batch_normalization_54[0][0]'] 
                                                                                                  
 conv2d_55 (Conv2D)             (None, 17, 17, 160)  179200      ['activation_54[0][0]']          
                                                                                                  
 batch_normalization_55 (BatchN  (None, 17, 17, 160)  480        ['conv2d_55[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_55 (Activation)     (None, 17, 17, 160)  0           ['batch_normalization_55[0][0]'] 
          

                                                                                                  
 batch_normalization_65 (BatchN  (None, 17, 17, 192)  576        ['conv2d_65[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_65 (Activation)     (None, 17, 17, 192)  0           ['batch_normalization_65[0][0]'] 
                                                                                                  
 conv2d_61 (Conv2D)             (None, 17, 17, 192)  147456      ['mixed6[0][0]']                 
                                                                                                  
 conv2d_66 (Conv2D)             (None, 17, 17, 192)  258048      ['activation_65[0][0]']          
                                                                                                  
 batch_nor

                                                                                                  
 conv2d_70 (Conv2D)             (None, 17, 17, 192)  147456      ['mixed7[0][0]']                 
                                                                                                  
 conv2d_74 (Conv2D)             (None, 17, 17, 192)  258048      ['activation_73[0][0]']          
                                                                                                  
 batch_normalization_70 (BatchN  (None, 17, 17, 192)  576        ['conv2d_70[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_74 (BatchN  (None, 17, 17, 192)  576        ['conv2d_74[0][0]']              
 ormalization)                                                                                    
          

 batch_normalization_76 (BatchN  (None, 8, 8, 320)   960         ['conv2d_76[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_78 (Activation)     (None, 8, 8, 384)    0           ['batch_normalization_78[0][0]'] 
                                                                                                  
 activation_79 (Activation)     (None, 8, 8, 384)    0           ['batch_normalization_79[0][0]'] 
                                                                                                  
 activation_82 (Activation)     (None, 8, 8, 384)    0           ['batch_normalization_82[0][0]'] 
                                                                                                  
 activation_83 (Activation)     (None, 8, 8, 384)    0           ['batch_normalization_83[0][0]'] 
          

                                                                                                  
 activation_91 (Activation)     (None, 8, 8, 384)    0           ['batch_normalization_91[0][0]'] 
                                                                                                  
 activation_92 (Activation)     (None, 8, 8, 384)    0           ['batch_normalization_92[0][0]'] 
                                                                                                  
 batch_normalization_93 (BatchN  (None, 8, 8, 192)   576         ['conv2d_93[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_85 (Activation)     (None, 8, 8, 320)    0           ['batch_normalization_85[0][0]'] 
                                                                                                  
 mixed9_1 

In [4]:
# Feature extractor for "local" features: maps the base model's inputs to the
# output of its first convolution layer ('conv2d', shown as (None, 149, 149, 32)
# in the summary above).
local_model = tf.keras.Model(
    inputs=base_model.inputs,
    outputs=base_model.get_layer(name='conv2d').output,
    name="local_feature",
)

In [5]:
# Feature extractor for "global" features: maps the base model's inputs to the
# output of the layer named 'global_average_pooling2d'.
# NOTE(review): that layer is not visible in the truncated summary above —
# presumably it yields a 2048-d vector per image (see build_graph's default); confirm.
global_model = tf.keras.Model(
    inputs=base_model.inputs,
    outputs=base_model.get_layer(name='global_average_pooling2d').output,
    name='global_feature',
)

In [6]:
def get_class_label_weight(df):
    """Build the label-to-index mapping and per-class weights from ``df['dx']``.

    Args:
        df: DataFrame with a ``'dx'`` column holding the class name of each row.

    Returns:
        A ``(class_label, weights)`` tuple:
        ``class_label`` maps each distinct class name to its position in the
        sorted unique labels, and ``weights`` maps that position to the
        weight sklearn's ``compute_class_weight`` assigns with
        ``class_weight='balanced'``.
    """
    dx_values = list(df['dx'])
    unique_labels = np.unique(dx_values)

    balanced = compute_class_weight(class_weight='balanced',
                                    classes=unique_labels,
                                    y=dx_values)

    class_label = {}
    weights = {}
    # compute_class_weight returns one weight per entry of `classes`, in order,
    # so the same index addresses both the label and its weight.
    for idx, name in enumerate(unique_labels):
        class_label[name] = idx
        weights[idx] = balanced[idx]

    return class_label, weights

In [7]:
def get_img(img_path, image_size=299):
    """Load one image and prepare it for the InceptionV3-based models.

    Args:
        img_path: Path of the image file to read.
        image_size: Side length, in pixels, of the square the image is
            resized to while loading.

    Returns:
        The image as an array after ``inception_v3.preprocess_input``.
    """
    target = (image_size, image_size)
    pil_img = tf.keras.preprocessing.image.load_img(img_path, target_size=target)
    pixels = tf.keras.utils.img_to_array(pil_img)
    return tf.keras.applications.inception_v3.preprocess_input(pixels)

In [8]:
# Training ground truth; later cells read its 'image_id' and 'dx' (class name) columns.
train_df = pd.read_csv('data/train_truth.csv')

In [9]:
# class_label: class name -> integer index; weights: integer index -> class weight.
# Both are derived from the training set only and reused at evaluation time.
class_label, weights = get_class_label_weight(train_df)

In [10]:
def shuffle_df(x, random_state=None):
    """Return all rows of ``x`` in random order.

    Args:
        x: DataFrame (or Series) to shuffle; it is not modified.
        random_state: Optional seed / RandomState forwarded to ``sample`` so a
            shuffle can be reproduced. ``None`` (the default) keeps the
            original unseeded behaviour.

    Returns:
        A new object with the same rows as ``x``, permuted; the original
        index labels are kept.
    """
    return x.sample(frac=1, random_state=random_state)

In [11]:
def get_local_feature(df, label, num = 5):
    """Load ``num`` randomly chosen training images of class ``label``.

    Args:
        df: DataFrame with ``'dx'`` (class name) and ``'image_id'`` columns.
        label: Class name to draw images from; also the sub-directory of
            ``data/train`` the files are read from.
        num: Number of images to draw.

    Returns:
        ``np.ndarray`` stacking the ``num`` preprocessed images returned by
        ``get_img``.

    Raises:
        ValueError: If ``df`` has no row whose ``'dx'`` equals ``label``.
    """
    candidates = df.loc[df['dx'] == label, 'image_id']
    if candidates.empty:
        raise ValueError(f"no rows with dx == {label!r} to sample local features from")

    # A class with fewer than `num` images would make sample() raise; fall back
    # to sampling with replacement only in that case so the output always
    # contains exactly `num` images.
    picked = candidates.sample(num, replace=len(candidates) < num)

    return np.array([get_img(f'data/train/{label}/{img}.jpg') for img in picked])

In [12]:
class Causal_Model(tf.keras.Model):
    """Classifier head combining one global feature with pooled local features.

    Each row of the batch pairs a global feature vector with the local
    feature map of one reference image. The per-row class probabilities are
    averaged over the batch axis, so one call yields a single prediction.
    """

    def __init__(self, num, num_classes=7, class_weights=None):
        """Create the layers.

        Args:
            num: Number of local reference images per query; stored only.
            num_classes: Size of the softmax output.
            class_weights: Optional class weights; stored only.
        """
        super().__init__()
        self.num = num

        # Class Weight
        self.class_weights = class_weights

        # Pooling for local features for each image
        self.pooled = tf.keras.layers.GlobalAveragePooling2D()

        # Regularisation on the concatenated features. Built once here rather
        # than on every call so it is a tracked sub-layer of the model.
        self.dropout = tf.keras.layers.Dropout(0.2)

        # Post-Concatenation
        self.dense_1 = tf.keras.layers.Dense(1024, activation='relu')

        # Prediction Layer
        self.prediction = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, global_feature, local_feature, training=None):
        """Run the head on one group of (global, local) features.

        Args:
            global_feature: 2-D tensor of global features, one row per local
                feature map.
            local_feature: 4-D tensor of local feature maps; it is reduced by
                global average pooling before concatenation.
            training: Forwarded to the dropout layer. ``None`` lets Keras
                decide from the surrounding call context.

        Returns:
            The softmax outputs averaged over axis 0 (shape ``(num_classes,)``).
        """
        local_feature = self.pooled(local_feature)
        x = tf.concat([global_feature, local_feature], 1)
        x = self.dropout(x, training=training)
        x = self.dense_1(x)
        x = self.prediction(x)
        return tf.math.reduce_mean(x, axis = 0)

    def build_graph(self, local_feature_shape = (149, 149, 32), global_feature_shape = (2048,)):
        """Wrap ``call`` in a functional model, e.g. to print a summary.

        Args:
            local_feature_shape: Per-sample shape of the local feature map.
            global_feature_shape: Per-sample shape of the global feature; a
                bare int is accepted and treated as a 1-tuple.

        Returns:
            A ``tf.keras.Model`` whose inputs are
            ``[global_feature, local_feature]`` placeholders.
        """
        # `(2048)` is just the int 2048; normalise so Input always gets a tuple.
        if isinstance(global_feature_shape, int):
            global_feature_shape = (global_feature_shape,)

        local_feature = tf.keras.layers.Input(shape = local_feature_shape, dtype='float32')
        global_feature = tf.keras.layers.Input(shape = global_feature_shape)

        return tf.keras.Model(inputs=[global_feature, local_feature],
                              outputs=self.call(global_feature, local_feature))

In [13]:
# Instantiate the classifier head with the default 7 classes and no stored
# class weights, then print the summary of its functional wrapper.
cm = Causal_Model(num = 5)
cm.build_graph().summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 149, 149, 3  0           []                               
                                2)]                                                               
                                                                                                  
 input_2 (InputLayer)           [(None, 2048)]       0           []                               
                                                                                                  
 global_average_pooling2d (Glob  (None, 32)          0           ['input_1[0][0]']                
 alAveragePooling2D)                                                                              
                                                                                              

In [14]:
# Integer class targets; from_logits=False because Causal_Model's last layer
# already applies softmax.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
# Adam with a fixed learning rate of 1e-5.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)

In [15]:
@tf.function
def loss_model(model, global_feature, local_feature, y, class_weight, training=False):
    """Compute the weighted loss of ``model`` for one training example.

    Args:
        model: Callable taking ``(global_feature, local_feature, training=...)``.
        global_feature: Global features passed to ``model``.
        local_feature: Local features passed to ``model``.
        y: Target passed to the module-level ``loss_object``.
        class_weight: Passed to ``loss_object`` as ``sample_weight``.
        training: Forwarded to ``model``.

    Returns:
        The loss value produced by ``loss_object``.
    """
    predictions = model(global_feature, local_feature, training=training)
    return loss_object(y, predictions, sample_weight=class_weight)

@tf.function
def grad_cm(model, global_feature, local_feature, y, class_weight):
    """Compute the training loss and its gradients for one example.

    Args:
        model: Model whose ``trainable_variables`` are differentiated.
        global_feature: Global features passed to ``loss_model``.
        local_feature: Local features passed to ``loss_model``.
        y: Target class index.
        class_weight: Weight applied to the loss of this example.

    Returns:
        ``(loss_value, gradients)`` where ``gradients`` is aligned with
        ``model.trainable_variables``.
    """
    with tf.GradientTape() as tape:
        # This is the training-time forward pass, so run the model in training
        # mode; with training=False the model's Dropout layer is never applied.
        loss_value = loss_model(model, global_feature, local_feature, y, class_weight, training=True)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

In [16]:
def repeat_tensor(arr, repeat):
    """Repeat the entries of ``arr`` along axis 0.

    Args:
        arr: Tensor (or array-like) to repeat.
        repeat: Repeat count(s), passed to ``tf.repeat`` as ``repeats``.

    Returns:
        The tensor produced by ``tf.repeat`` along axis 0.
    """
    return tf.repeat(arr, repeat, 0)

In [25]:
epochs = range(1)
num = 5  # local reference images drawn per training example
batch_size = 4  # NOTE(review): not used by this loop — each step trains on a single query image

for epoch in tqdm(epochs, desc='epoch'):
    loss, steps = 0, 0

    # Visit the training rows in a fresh random order every epoch instead of
    # the order they appear in the CSV.
    epoch_df = shuffle_df(train_df)

    for _, row in tqdm(epoch_df.iterrows(), total=len(epoch_df)):
        img_id = row['image_id']
        label = row['dx']

        img_path = f'data/train/{label}/{img_id}.jpg'
        img = get_img(img_path).reshape(-1, 299, 299, 3)

        # One global feature vector for the query image, repeated so it can be
        # paired row-by-row with the `num` local feature maps.
        global_feature = global_model(img)
        global_feature = repeat_tensor(global_feature, [num])

        # Pass `num` explicitly: the model concatenates global and local
        # features along axis 1, so both must have the same number of rows.
        local_feature = local_model(get_local_feature(train_df, label, num))

        label_idx = np.array([class_label[label]])
        weight = np.array([weights[label_idx[0]]])

        loss_value, grads = grad_cm(cm, global_feature, local_feature, label_idx, weight)
        optimizer.apply_gradients(zip(grads, cm.trainable_weights))

        loss += loss_value.numpy()
        steps += 1

    # max(..., 1) avoids a ZeroDivisionError when train_df is empty.
    print(f"[{epoch}] Loss : {loss / max(steps, 1)}")

epoch:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/8012 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [18]:
# Test ground truth; the evaluation cell reads its 'image_id' and 'dx' columns.
test_df = pd.read_csv('data/test_truth.csv')

In [19]:
def get_test_local_feature(df, num = 5):
    """Load ``num`` randomly chosen test images, regardless of class.

    Args:
        df: DataFrame with an ``'image_id'`` column.
        num: Number of rows to sample (without replacement).

    Returns:
        ``np.ndarray`` stacking the preprocessed images returned by ``get_img``
        for the sampled ids, read from ``data/test``.
    """
    sampled_ids = df.sample(num)['image_id']
    return np.array([get_img(f'data/test/{img}.jpg') for img in sampled_ids])

In [24]:
# Top-1 accuracy on the test set. The local reference images are drawn at
# random on every iteration, so the reported number varies between runs.
acc = 0
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    img_id = row['image_id']
    img = get_img(f'data/test/{img_id}.jpg').reshape(-1, 299, 299, 3)

    label = class_label[row['dx']]

    global_feature = global_model(img)
    global_feature = repeat_tensor(global_feature, [num])

    # Pass `num` explicitly so the local batch always has as many rows as the
    # repeated global feature it is concatenated with.
    local_feature = local_model(get_test_local_feature(test_df, num))
    pred = np.argmax(cm(global_feature, local_feature))

    if pred == label:
        acc += 1

# max(..., 1) avoids a ZeroDivisionError when test_df is empty.
print('Acc : ', (acc / max(len(test_df), 1)))

  0%|          | 0/1001 [00:00<?, ?it/s]

Acc :  0.8061938061938062
