In [1]:
from IPython.display import IFrame
IFrame("https://bahramjafrasteh.github.io/#services", 1900,1500)

#***Unit 2: Hands-on Session: Applied deep learning for 3D medical image classification🚀***


In this notebook you will learn the basics of deep learning for **3D medical image classification**.
The recommended Python packages for this notebook are [NiBabel](https://nipy.org/nibabel/) and [PyTorch](https://pytorch.org/).


❓ If you have any questions, please post them on the #questions Discord channel 👉 https://discord.gg/KqtpN5SF

## Objectives of this notebook 🏆
At the end of the notebook, you will:
- 🔍 Be able to correctly **load weights of a trained model**.
- 📚 Be able to send your **input** image to the network.
- 🧮 Be able to run your first and very basic **deep learning model**.


![picture](https://drive.google.com/uc?export=view&id=1wcqdmu9zSaYn_-uVGLd45V8UMYTM4sWx)

In this course, you will learn how to:

- 📖 use a Python library to build a DL architecture.
- 👩🧑 use features of **well-known Python libraries** such as **NiBabel**, **PyTorch**, and **Matplotlib** to correctly run inference with a pretrained model.


## Prerequisites 🏯
To follow this course you need basic knowledge of 3D images, their characteristics, and image acquisition techniques 🤝.
- 💪 Basic knowledge of **Python** 👩🏼‍💻 programming is **essential**.
- 🥳 A mathematical understanding of rotation matrices and algebraic operations is very **helpful**.
- Basic knowledge of tensor operations is also necessary to follow this course.

###**1) Installing the required packages**

**1. First Step**: Installing the required packages ⏳

This notebook relies on several packages, so we install them here.

First we need to know which Python version we have. Let's find it 🔎
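
The version check mentioned above is not shown in the install cell. A minimal sketch, using only the standard library, could look like this (the 3.6 threshold is an illustrative assumption, not a requirement stated by this notebook):

```python
import platform
import sys

# Human-readable version string, e.g. "3.x.y".
python_version = platform.python_version()
print(python_version)

# Hypothetical minimum version, used here only to illustrate a comparison.
is_recent_enough = sys.version_info >= (3, 6)
print(is_recent_enough)
```

In a notebook, `!python --version` gives the same information for the interpreter that runs shell commands.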

In [None]:
%%capture
!pip3 install torch==1.2.0+cu92 torchvision==0.4.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
!pip3 install numpy
!pip3 install matplotlib
!pip3 install nibabel
!pip3 install SimpleITK
!pip3 install pydrive
!pip3 install mathjax

###**2) Building the deep CNN model**

In this section, we build a first, very basic deep convolutional neural network for classifying an image with the PyTorch library.
It contains three 3D convolutional layers.
The output size $O$ along each axis, for an input of size $I$, kernel size $K$, padding $P$, dilation $D$ and stride $S$, is computed as follows:

$O = \lfloor (I + 2P - D(K - 1) - 1)/S \rfloor + 1$
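
To make the formula concrete, here is a small sketch that applies it three times in a row with the layer settings used by the model in this section (kernel 3, stride 4, padding 1, dilation 1). The helper name `conv_output_size` and the input side length of 64 are assumptions chosen only for illustration.

```python
def conv_output_size(i, k, s, p=0, d=1):
    """Output length along one axis for a convolution, per the formula above."""
    return (i + 2 * p - d * (k - 1) - 1) // s + 1

# Hypothetical input side length, chosen only for illustration.
size = 64
for layer in range(1, 4):
    size = conv_output_size(size, k=3, s=4, p=1)
    print(f"after conv{layer}: {size}")
# after conv1: 16
# after conv2: 4
# after conv3: 1
```

Multiplying the final size along the three axes by the number of output channels gives the number of features that reach the linear layer.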


In [None]:
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    """Three strided 3D convolutions followed by one linear layer."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Each convolution uses a 3x3x3 kernel, stride 4 and padding 1.
        self.conv1 = nn.Conv3d(1, 8, kernel_size=(3, 3, 3),
                               stride=(4, 4, 4), padding=3 // 2,
                               bias=True)
        self.activation = nn.ReLU()
        self.conv2 = nn.Conv3d(8, 16, kernel_size=(3, 3, 3),
                               stride=(4, 4, 4), padding=3 // 2,
                               bias=True)
        self.conv3 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3),
                               stride=(4, 4, 4), padding=3 // 2,
                               bias=True)

        # The linear layer expects 256 flattened features.
        self.last = nn.Linear(256, 1)
        self.last_activaiton = nn.Sigmoid()

    def forward(self, x, class_ind=None, train=True):
        # x[0] selects the first element along the leading dimension of x.
        y = self.activation(self.conv1(x[0]))
        y = self.activation(self.conv2(y))
        # Flatten the channel and the three spatial dimensions into one vector.
        y = self.activation(self.conv3(y)).flatten(-4, -1)
        return self.last(y)


def inference(model, x):
  """Run a forward pass of `model` on `x`."""
  return model(x)


def load_model_weights(model, file, device):
  """Load a checkpoint from `file` and copy its 'state_dict' entry into `model`."""
  state_dict = torch.load(file, map_location=device)
  model.load_state_dict(state_dict['state_dict'])
  print('successfully loaded weights')
  return model


###**3) Copy the network weights and the input image for the inference**

In [None]:
import os

import numpy as np
from google.colab import auth
from oauth2client.client import GoogleCredentials
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive

# Local folder that will hold the weights and the input image.
local_download_path = os.path.expanduser('~/weight')
os.makedirs(local_download_path, exist_ok=True)

# Authenticate the Colab user and build a PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# List the files in the shared Drive folder.
files = drive.ListFile(
    {'q': "'1QoKMmDM07DbLphCaH_Zo8B35uwk46G5J' in parents"}).GetList()

# Download only the two files this notebook needs.
wanted = ('file_inp.nii.gz', 'Hybrid_latest.pth')
for f in files:
  if f['title'] in wanted:
    flocal = os.path.join(local_download_path, f['title'])
    drive.CreateFile({'id': f['id']}).GetContentFile(flocal)
    print(flocal)
    print('successfully copied')

/root/weight/Hybrid_latest.pth
successfully copied
/root/weight/file_inp.nii.gz
successfully copied


###**4) Load model weights**

Task 🔑 Load the model on the GPU

📜(*hint) Runtime → Change runtime type

In [None]:
model = SimpleNet()
load_model_weights(model, '/root/weight/Hybrid_latest.pth', 'cuda:0')

successfully loaded weights


SimpleNet(
  (conv1): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(4, 4, 4), padding=(1, 1, 1))
  (activation): ReLU()
  (conv2): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(4, 4, 4), padding=(1, 1, 1))
  (conv3): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(4, 4, 4), padding=(1, 1, 1))
  (last): Linear(in_features=256, out_features=1, bias=True)
  (last_activaiton): Sigmoid()
)

###**5) Inference with the trained model on the input image**

Now read the image and run inference.

In [None]:
inputimg = torch.from_numpy(image.astype(np.float32))
inputimg = inputimg.reshape((1,1,*inputimg.shape))

output = model.forward(inputimg)
print(output.item())

NameError: ignored

👉 Task Please fix the problems with reading the input image.

📜(*hint) Find the ideal input size for the network, then resize the image.

The output probability for the classification of this image will be "0.51"
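
As one possible way to approach the resizing step, the sketch below resamples a 3D array to a new shape with nearest-neighbour sampling in NumPy. The function name `resize_nearest`, the toy volume, and the target shape are all assumptions made for illustration; the notebook does not say which interpolation method was used, and other methods may give different results.

```python
import numpy as np

def resize_nearest(volume, target_shape):
    """Resize a 3D array to target_shape by nearest-neighbour sampling."""
    # For every output index along an axis, compute the source index it maps to.
    index_vectors = [
        (np.arange(t) * s) // t
        for s, t in zip(volume.shape, target_shape)
    ]
    # np.ix_ builds an open mesh so the three index vectors select a full 3D grid.
    return volume[np.ix_(*index_vectors)]

# Toy 4x6x8 volume standing in for a loaded image (hypothetical data).
volume = np.arange(4 * 6 * 8, dtype=np.float32).reshape(4, 6, 8)
resized = resize_nearest(volume, (2, 3, 4))
print(resized.shape)  # (2, 3, 4)
```

The same function also upsamples, since each output index is simply mapped back to the closest earlier source index.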

###**Questions**

 🕵🏽‍♂️ What do you suggest to improve this network?

 🕵🏽‍♂️ Do you see any problem with this network?

 🕵🏽‍♂️ Can you guess what this kind of network could be used for?