<a href="https://colab.research.google.com/github/Jamaleldin/MRNET-for-Knee-Diagnosis/blob/contribution/MRI_Contrast.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Needed Imports

In [0]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, BatchNormalization
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras import optimizers

# Extractor Model

In [0]:
# This is the feature extractor model based on the architecture used in the paper.
def extractor():
    model = Sequential()
    # The main architecture is based on cnn layers of 3x3 filters
    # The extractor consisits of 3 similar blocks:
    # The first block has 32 filters in convolutional layers, the second has 64 and the third has 128
    # Each block has 2 convolutional layers separated by batch normalizations and a relu activation layer
    # between each 2 blocks there is a max pooling layer of size 2x2
    
    # First block
    model.add(Conv2D(32, (3, 3), padding='same', input_shape=(256,256,3)))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('relu'))
    model.add(Conv2D(32, (3, 3)))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # Second block
    model.add(Conv2D(64, (3, 3), padding='same'))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('relu'))
    model.add(Conv2D(64, (3, 3)))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # Third block
    model.add(Conv2D(128, (3, 3), padding='same'))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('relu'))
    model.add(Conv2D(128, (3, 3)))
    model.add(BatchNormalization(axis=1))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2))) 
   
   
    # This part is to be removed as the same approach used before
    # model.add(Flatten(name='flat_021'))
    # model.add(Dense(1024,name='dense_022'))
    # model.add(BatchNormalization(axis=1,name='batch_0023'))
    # model.add(Activation('relu',name='relu_024'))
    # model.add(Dense(1,name='dense_025'))
    # model.add(Activation('sigmoid',name='soft_026'))

    model.compile(loss='binary_crossentropy', optimizer=optimizers.Adam(1e-5), metrics=['acc'])
    return model

In [15]:
extractor().summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_19 (Conv2D)           (None, 256, 256, 32)      896       
_________________________________________________________________
batch_normalization_18 (Batc (None, 256, 256, 32)      1024      
_________________________________________________________________
activation_18 (Activation)   (None, 256, 256, 32)      0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 254, 254, 32)      9248      
_________________________________________________________________
batch_normalization_19 (Batc (None, 254, 254, 32)      1016      
_________________________________________________________________
activation_19 (Activation)   (None, 254, 254, 32)      0         
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 127, 127, 32)     

# Train Extractors


### Sagittal Abnormal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.sagittal, processing.abnormal)


### Sagittal ACL

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.sagittal, processing.acl)


### Sagittal Meniscal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.sagittal, processing.meniscal)

### Axial Abnormal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.axial, processing.abnormal)

### Axial ACL

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.axial, processing.acl)

### Axial Meniscal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.axial, processing.meniscal)

### Coronal Abnormal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.coronal, processing.abnormal)

### Coronal ACL

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.coronal, processing.acl)

### Coronal Meniscal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.coronal, processing.meniscal)

# Train Classifiers

In [0]:
def classifier():
  classifier = Sequential()
  classifier.add(Dense(1024)
  classifier.add(BatchNormalization(axis=1)
  classifier.add(Activation('relu'))
  classifier.add(Dense(1))
  classifier.add(Activation('sigmoid'))
  classifier.compile(loss='binary_crossentropy', optimizer=optimizers.Adam(1e-5), metrics=['acc'])
  return classifier


### Sagittal Abnormal

In [0]:
loaded_model = processing.load_model(processing.inception, processing.sagittal, processing.abnormal, processing.extractor)
extractor_model = Model(inputs=loaded_model.input, outputs=loaded_model.get_layer('block5_2').output)
extractor_model.summary()
#making classifier
classifier_model = classifier()
#train classifier
processing.train_classifier(extractor_model, classifier_model, processing.inception, processing.sagittal, processing.abnormal)


### Sagittal ACL

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.sagittal, processing.acl)


### Sagittal Meniscal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.sagittal, processing.meniscal)

### Axial Abnormal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.axial, processing.abnormal)

### Axial ACL

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.axial, processing.acl)

### Axial Meniscal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.axial, processing.meniscal)

### Coronal Abnormal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.coronal, processing.abnormal)

### Coronal ACL

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.coronal, processing.acl)

### Coronal Meniscal

In [0]:
contrast_cnn = getCNN()
processing.train_extractor(contrast_cnn, processing.inception, processing.coronal, processing.meniscal)