# Brain PET Image Disease Prediction Challenge: CNN Version

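In this notebook we train a CNN on PET images to decide whether a given scan comes from a patient with MCI (mild cognitive impairment).

A convolutional neural network (CNN) is a deep learning model widely used in image recognition, computer vision, and pattern recognition. CNNs work well on grid-structured data such as images: they learn and extract image features automatically, and they perform strongly on classification, localization, and segmentation tasks.

We use PyTorch to train the CNN.

- A CNN generally reaches higher accuracy, but takes longer to train
- Tuning a CNN model requires a GPU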
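
To make "extracting features" concrete, here is a minimal NumPy sketch of the sliding-window operation that a convolutional layer applies. The function name `conv2d_valid`, the toy image, and the hand-written edge kernel are illustrative assumptions and not part of this notebook's pipeline; a real CNN learns its kernels during training.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and return the response map.

    Like deep learning frameworks, this computes cross-correlation (the kernel is not flipped).
    """
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=np.float32)
    for r in range(out_h):
        for c in range(out_w):
            # Element-wise product of the window and the kernel, summed to one number
            out[r, c] = np.sum(image[r:r + kh, c:c + kw] * kernel)
    return out

# Toy 5x5 image (hypothetical data): dark on the left, bright on the right
toy_image = np.array([[0, 0, 0, 1, 1]] * 5, dtype=np.float32)
# Hand-written kernel that responds to a dark-to-bright step from left to right
edge_kernel = np.array([[-1, 0, 1]] * 3, dtype=np.float32)

response = conv2d_valid(toy_image, edge_kernel)
print(response.shape)  # (3, 3)
print(response)        # zero over the flat region, positive where the window covers the edge
```

The response is zero where the window sees only the flat dark region and positive where it straddles the edge, which is the sense in which a kernel "detects" a feature.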


## Step 1: Data preparation

Import the required libraries and helper functions.

In [3]:
# nibabel has to be installed from an alternative package index here
!pip install nibabel -i https://pypi.douban.com/simple --trusted-host pypi.douban.com
import glob                # collect file paths
import numpy as np
import pandas as pd
import nibabel as nib      # read medical image data
from nibabel.viewers import OrthoSlicer3D    # image visualization
from collections import Counter              # counting statistics

Looking in indexes: https://pypi.douban.com/simple


DEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063


In [2]:
import os, sys, glob, argparse
import pandas as pd
import numpy as np
from tqdm import tqdm

Notes on the imported libraries:

| Library / Module | Description | Common Features |
|--------------|-------------|-----------------|
| os           | Interact with the operating system | os.path.join(), os.listdir(), os.makedirs(), os.remove(), os.system(), os.getenv() |
| sys          | Access interpreter variables and functions | sys.argv, sys.exit(), sys.platform, sys.path |
| glob         | Find pathnames matching a pattern | glob.glob(), glob.iglob() |
| argparse     | Parse command-line arguments | Argument parsing with flags, positional arguments, default values, help messages |
| pandas       | Data manipulation and analysis | DataFrames, Series, data alignment, grouping, merging, handling missing data |
| numpy        | Arrays, matrices, mathematical functions | Multidimensional arrays, array operations, broadcasting, linear algebra, random generation |
| tqdm         | Progress bar for loops | Adding progress bars to loops using tqdm() function | 


In [6]:
pip install opencv-python

Note: you may need to restart the kernel to use updated packages.




In [9]:
!pip install torch -q
# pip's success messages are noisy; the -q flag suppresses them



In [10]:
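import cv2  # requires opencv-python (installed above)
from PIL import Image
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
import torch  # requires torch (installed above)
# Seed PyTorch's random number generator
torch.manual_seed(0)
# benchmark=True lets cuDNN pick the fastest kernels; together with deterministic=False,
# runs on a GPU may not be bit-for-bit reproducible despite the fixed seed
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True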


| Library / Module | Description | Common Features |
|--------------|-------------|-----------------|
| cv2                       | OpenCV library for computer vision         | Image processing, computer vision tasks, video capture and analysis                                                |
| PIL (Image module)        | Python Imaging Library for image processing | Opening, manipulating, and saving various image formats                                                            |
| sklearn.model_selection  | Scikit-learn module for data splitting     | Data splitting for training and testing, cross-validation, stratified sampling, k-fold cross-validation            |
| torch                     | PyTorch library for deep learning          | Tensors, neural network modules, optimization algorithms, autograd for automatic differentiation                  |

In [12]:
!pip install torchvision -q



In [38]:
import torch

In [41]:
import torchvision.models as models
# torchvision must be installed first (see above); it is a separate companion library to torch
# with tools specific to computer vision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
# This is the Dataset class we will use later!! (The parent class of XunFeiDataSet)
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# nibabel is the NIfTI imaging package, used to read .nii images
import nibabel as nib
from nibabel.viewers import OrthoSlicer3D

| Library Name            | Description                                   | Common Features                                                                           |
|-------------------------|-----------------------------------------------|--------------------------------------------------------------------------------------------|
| torchvision.models      | Models and pre-trained models in PyTorch     | Various pre-defined deep learning models for image classification and feature extraction |
| torchvision.transforms  | Data transformations for images in PyTorch   | Image data augmentation, normalization, resizing, cropping, and more                     |
| torchvision.datasets   | Datasets for PyTorch                          | Access to standard datasets for training and testing, data loaders                        |
| torch.nn (nn module)    | Neural network layers and modules in PyTorch | Layers, activation functions, loss functions, containers, custom module definitions       |
| torch.nn.functional    | Functional interface to nn modules in PyTorch | Functional alternatives to some nn module operations                                      |
| torch.optim            | Optimization algorithms in PyTorch           | SGD, Adam (adaptive moment estimation, an extension of SGD), RMSprop, optimization functions                                                 |
| torch.autograd         | Automatic differentiation in PyTorch        | Computing gradients for gradient-based optimization                                       |
| torch.utils.data        | Dataset and DataLoader classes in PyTorch    | Custom dataset creation, batching, shuffling, and parallel data loading                   |
| nibabel.viewers         | Viewers for NIfTI (.nii) images              | OrthoSlicer3D: orthogonal slice viewer for 3D volumes                                     |

In [63]:
# An important library for computer vision
%pip install --user opencv-python



In [68]:
# albumentations is a library built on top of opencv, and focuses specifically on efficient and customizable image
# augmentation for machine learning and deep learning tasks. It streamlines the process of applying complex 
# augmentation pipelines to large datasets while utilizing the powerful image processing capabilities of OpenCV.
%pip install --user albumentations -q



In [92]:
import albumentations as A

| Library Name     | Description                                      | Common Features                                                                                |
|------------------|--------------------------------------------------|------------------------------------------------------------------------------------------------|
| albumentations   | Image augmentation library for machine learning | Wide range of image augmentation techniques, compatible with various deep learning frameworks (compared with opencv "more user friendly") |


All required libraries are now imported; next we load and preprocess the data.

In [88]:
train_path = glob.glob('./PETdata/train/*/*')
# Again, note that the train folder contains two subfolders, MCI and NC; files in both must be read, hence /*/*
test_path = glob.glob('./PETdata/Test/*')

np.random.shuffle(train_path)
np.random.shuffle(test_path)

# DATA_CACHE is a dictionary that caches loaded images:
# the keys are file paths, the values are the 3D image arrays read from the nibabel dataobj.
DATA_CACHE = {}

# Define the class XunFeiDataset. Its arguments are a list of image paths ("img_path")
# and the image transform to apply ("transform").
# Declaring the parent class Dataset (torch.utils.data.Dataset) means that every XunFeiDataset instance
# also inherits the methods of Dataset.
# "Dataset" here is the class imported earlier with `from torch.utils.data import Dataset`.
class XunFeiDataset(Dataset):
    # A method whose name has only two leading underscores is treated as private: Python mangles the name,
    # so it is not meant to be called directly from outside the class.
    # __init__ and __getitem__ have double underscores on both sides. That does not make them private;
    # they are special "dunder" (double underscore) methods.
    # __init__ is the constructor that initializes the attributes of a new instance.
    # For example, given IMAGES = XunFeiDataset(img_path=['abc', 'def', ...], transform=A.Compose([...]))
    # (a composition of transforms from the albumentations library, see the markdown table below),
    # we get IMAGES.img_path == ['abc', 'def', ...] and IMAGES.transform == A.Compose([...]).
    def __init__(self, img_path, transform=None):
        self.img_path = img_path
        # transform may be None, in which case no augmentation is applied
        self.transform = transform

    # The __getitem__ method implements indexing for instances of the class.
    # e.g. IMAGES[3] does the following:
    # check whether the path IMAGES.img_path[3] is in DATA_CACHE, load the image if it is not, ...
    # (without __getitem__, IMAGES[3] would raise an error; DataLoader relies on this method to fetch samples)
    def __getitem__(self, index):
        # Checks whether the specified path of (IMAGES.img_path[index]) is already present 
        # in the DATA_CACHE's keys.
        if self.img_path[index] in DATA_CACHE:
            # If the path is found in DATA_CACHE's keys, the corresponding cached np3darray image from the cache is loaded
            # and assigned to the img variable.
            img = DATA_CACHE[self.img_path[index]]
        else:
            # If the path is not in DATA_CACHE's keys, it uses the nib.load function from the nibabel library 
            # to load the NIfTI format image from the path specified by self.img_path[index] to the img variable.
            img = nib.load(self.img_path[index]) 
            # The dataobj attribute holds the actual image data as a NumPy array-like object.
            # Indexing with [:, :, :, 0] selects the first volume along the fourth axis
            # (this assumes the image is 4D), which yields a single 3D array.
            img = img.dataobj[:,:,:, 0]
            # Finally, the loaded or extracted image is cached in the DATA_CACHE dictionary using the given path as the key.
            DATA_CACHE[self.img_path[index]] = img
    
        # Randomly pick 50 slice indices along the last axis (np.random.choice samples with replacement
        # by default, so slices can repeat) and stack those slices into the new 3D array
        idx = np.random.choice(range(img.shape[-1]), 50)
        img = img[:, :, idx]
        # Finally converts the data stored in the img 3Darray to floats for later use
        img = img.astype(np.float32)
        
        # If a transform was supplied (e.g. the A.Compose pipeline in the next code cell), apply it.
        # albumentations returns a dict; the transformed array is stored under the 'image' key.
        if self.transform is not None:
            img = self.transform(image = img)['image']
        
        
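        # Transpose the array so that the slice (z) axis comes first: (H, W, 50) -> (50, H, W),
        # i.e. the 50 slices become the input channels
        img = img.transpose([2,0,1])
        # The label is 1 if the path contains 'NC', otherwise 0 (MCI)
        return img,torch.from_numpy(np.array(int('NC' in self.img_path[index])))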
        
    def __len__(self):
        return len(self.img_path)
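
The shape bookkeeping in `__getitem__` can be checked on a synthetic volume without any NIfTI file. This is a sketch under assumptions: the 128x128x63 volume size and the helper name `sample_slices` are made up for illustration, and a seeded `np.random.default_rng` stands in for the notebook's `np.random.choice`.

```python
import numpy as np

def sample_slices(volume, n_slices=50, rng=None):
    """Draw `n_slices` slices along the last axis (with replacement) and move them to axis 0."""
    rng = np.random.default_rng() if rng is None else rng
    depth = volume.shape[-1]
    idx = rng.choice(depth, size=n_slices, replace=True)
    sampled = volume[:, :, idx].astype(np.float32)   # shape (H, W, n_slices)
    return sampled.transpose(2, 0, 1), idx           # shape (n_slices, H, W)

# Synthetic volume (hypothetical size) in which every voxel of slice z holds the value z
volume = np.zeros((128, 128, 63), dtype=np.int16)
volume += np.arange(63, dtype=np.int16)

channels_first, idx = sample_slices(volume, rng=np.random.default_rng(0))
print(channels_first.shape)   # (50, 128, 128)
print(channels_first.dtype)   # float32
```

Because each synthetic slice is filled with its own index, `channels_first[k, 0, 0]` equals `idx[k]`, which confirms that channel `k` of the output really is the sampled slice.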


Where the methods that will be passed into the `transform` argument actually come from:

| Method                  | Module                                               |
|-------------------------|------------------------------------------------------|
| A.RandomRotate90        | albumentations.augmentations.geometric.rotate       |
| A.RandomCrop            | albumentations.augmentations.crops.transforms       |
| A.HorizontalFlip        | albumentations.augmentations.geometric.transforms   |
| A.RandomContrast        | albumentations.augmentations.transforms             |
| A.RandomBrightnessContrast | albumentations.augmentations.transforms          |

They all live in the albumentations library, just scattered across different modules.
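
For intuition about what two of these transforms do to an array, here is a simplified NumPy sketch. The helpers `random_crop` and `horizontal_flip` are hypothetical stand-ins written for this illustration, not the albumentations implementations, and the 128x128x50 input size is an assumption.

```python
import numpy as np

def random_crop(img, height, width, rng):
    """Cut a height x width window at a random position; trailing axes are kept whole."""
    top = rng.integers(0, img.shape[0] - height + 1)
    left = rng.integers(0, img.shape[1] - width + 1)
    return img[top:top + height, left:left + width]

def horizontal_flip(img, p, rng):
    """Mirror the image left to right with probability p."""
    if rng.random() < p:
        return img[:, ::-1]
    return img

rng = np.random.default_rng(0)
img = np.arange(128 * 128 * 50, dtype=np.float32).reshape(128, 128, 50)

cropped = random_crop(img, 120, 120, rng)
flipped = horizontal_flip(cropped, p=1.0, rng=rng)   # p=1.0 forces the flip for the demo
print(cropped.shape)   # (120, 120, 50)
print(flipped.shape)   # (120, 120, 50)
```

Both operations act on the first two (spatial) axes only, so all 50 slices of a sample are cropped and flipped consistently.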

In [116]:
# train_loader is a DataLoader instance. It does not inherit from XunFeiDataset; it wraps a Dataset instance
# (here a XunFeiDataset, which subclasses torch.utils.data.Dataset) together with
# the sampling parameters: batch_size, shuffle, num_workers...
# It provides an iterable over the Dataset instance that yields batches.
train_loader = torch.utils.data.DataLoader(
    # The loader is given a XunFeiDataset instance
        ## with img_path set to every path in train_path except the last ten
        ## and transform set to a composition of transforms from the albumentations library,
    # plus the sampling parameters batch_size=2, shuffle=True, ...

    # In terms of the class above: self.transform = A.Compose([A.RandomRotate90(), ..., A.RandomBrightnessContrast(p=0.5)])
    XunFeiDataset(img_path = train_path[:-10], transform = 
            # These are methods for transforming images
            A.Compose([
            A.RandomRotate90(),
            A.RandomCrop(120, 120),
            A.HorizontalFlip(p=0.5),
            A.RandomContrast(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
        ])
        # The num_workers parameter specifies the number of worker processes used for data loading.
        # Workers load batches independently and in parallel, which can significantly speed up data loading.
        # Choose the number of workers based on your system's capabilities and the nature of the dataset;
        # too many workers can overload the CPU and RAM.

        # pin_memory=True is only relevant when training on a GPU: batches are placed in page-locked
        # host memory, which can speed up data transfers from CPU to GPU.
    ), batch_size=2, shuffle=True, num_workers=1, pin_memory=False
    
)

# The validation loader uses the same XunFeiDataset class,
# but only the last ten paths in train_path.
# Only random cropping is applied here,
# since the real images we need to classify are never flipped or contrast-adjusted.
val_loader = torch.utils.data.DataLoader(
    XunFeiDataset(train_path[-10:],
            A.Compose([
            A.RandomCrop(120, 120),
        ])
    ), batch_size=2, shuffle=False, num_workers=1, pin_memory=False
)

# The test loader also uses a XunFeiDataset instance, this time built from test_path.
# Its random transforms act as test-time augmentation: step 4 sums the predictions of several passes.
test_loader = torch.utils.data.DataLoader(
    XunFeiDataset(test_path,
            A.Compose([
            A.RandomCrop(128, 128),
            A.HorizontalFlip(p=0.5),
            A.RandomContrast(p=0.5),
        ])
    ), batch_size=2, shuffle=False, num_workers=1, pin_memory=False
)
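
The hold-out split used above (`train_path[:-10]` for training, `train_path[-10:]` for validation) and the resulting batch counts can be sketched in plain Python. The file names, the 25/25 class sizes, and the helper names `holdout_split` and `num_batches` are assumptions for illustration only.

```python
import math
from collections import Counter

def holdout_split(paths, n_val):
    """Return (train, val) where val is the last n_val items of the list."""
    if not 0 < n_val < len(paths):
        raise ValueError("n_val must be between 1 and len(paths) - 1")
    return paths[:-n_val], paths[-n_val:]

def num_batches(n_samples, batch_size):
    """Number of batches a loader yields when the last partial batch is kept."""
    return math.ceil(n_samples / batch_size)

# Hypothetical, unshuffled path list: 25 MCI files followed by 25 NC files
paths = [f"./PETdata/train/MCI/{i}.nii" for i in range(25)] + \
        [f"./PETdata/train/NC/{i}.nii" for i in range(25)]

train, val = holdout_split(paths, 10)
val_classes = Counter("NC" if "/NC/" in p else "MCI" for p in val)

print(len(train), len(val))                      # 40 10
print(num_batches(len(train), 2), num_batches(len(val), 2))  # 20 5
print(val_classes)                               # Counter({'NC': 10})
```

On this unshuffled list the validation set contains a single class, which is why shuffling `train_path` before slicing matters; a stratified split would be an alternative design that guarantees both classes appear.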




## Step 2: Define a custom CNN model

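The XunFeiNet architecture chosen here is a torchvision ResNet-18 with a modified input layer and output layer, as defined in the next cell.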

In [125]:
# defines a new class named XunFeiNet that inherits from nn.Module. 
# This means that XunFeiNet is a PyTorch neural network model.
class XunFeiNet(nn.Module):
    def __init__(self):
        # When calling super() with the subclass and an instance of that subclass (self), 
        # it allows you to access and call methods from its parent classes.
        # Here, the code calls the constructor of the parent class (nn.Module) to properly initialize the class.
        # This could also be done by nn.Module.__init__(self)
        # However when class hierarchy is complicated, it's best to use "super" instead of calling parentclass.__init__(self)
        # because "super" ensures that the constructors of all parent classes are called in the correct order.
        super(XunFeiNet, self).__init__()

        # Create a ResNet-18 model from the torchvision models module.
        # The positional True asks for pre-trained weights (transfer learning);
        # newer torchvision releases prefer the `weights=` keyword argument over this legacy form.
        model = models.resnet18(True)
        # Replace the first convolutional layer with one that accepts 50 input channels
        # (one per sampled slice in each img) instead of the default 3.
        # Kernel size, stride, and padding keep the ResNet-18 defaults.
        # The new layer is randomly initialized, so the pre-trained conv1 weights are not reused.
        model.conv1 = torch.nn.Conv2d(50, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Set the pooling layer before the classifier to adaptive average pooling with 1x1 output
        # (torchvision's ResNet-18 already uses this, so the behavior is unchanged).
        model.avgpool = nn.AdaptiveAvgPool2d(1)
        # Replace the fully connected layer (last layer before softmax) with a new linear layer
        # for binary classification (512 features in, 2 output classes)
        model.fc = nn.Linear(512, 2)
        self.resnet = model
        
    # defines the forward pass of the neural network
    def forward(self, img): 
        # performs the forward pass by passing the input img through the modified ResNet-18 model,
        # then return the output tensor
        out = self.resnet(img)
        return out


model = XunFeiNet()
# Moving the model to the GPU speeds up computation (assuming a GPU is available).
# CUDA is not available on my system because I don't have a compatible NVIDIA GPU,
# so the .to('cuda') and .cuda() calls are commented out and everything stays on the CPU.
# model = model.to('cuda')
# Cross-entropy loss for the two-class output
criterion = nn.CrossEntropyLoss()#.cuda()
# AdamW optimizer that updates the model's parameters during training with learning rate 0.001
optimizer = torch.optim.AdamW(model.parameters(), 0.001)
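
The effect of swapping `conv1` and `fc` can be checked with a little arithmetic. The sketch below uses the standard convolution output-size formula and counts weights for the layer shapes given in the cell above; the helper names are hypothetical, and the 120-pixel input side is taken from the `RandomCrop(120, 120)` used for training.

```python
def conv_output_size(n, kernel, stride, padding):
    """Spatial output size of a convolution: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

def conv_weight_count(in_channels, out_channels, kernel):
    """Number of weights in a square-kernel Conv2d without bias."""
    return in_channels * out_channels * kernel * kernel

# conv1 halves the spatial size: 120 -> 60 for training crops, 128 -> 64 for test crops
print(conv_output_size(120, kernel=7, stride=2, padding=3))  # 60
print(conv_output_size(128, kernel=7, stride=2, padding=3))  # 64

# Going from 3 to 50 input channels makes conv1 much larger
print(conv_weight_count(3, 64, 7))    # 9408
print(conv_weight_count(50, 64, 7))   # 156800

# The new classifier head: 512 weights per class plus one bias per class
fc_params = 512 * 2 + 2
print(fc_params)                      # 1026
```

Since adaptive average pooling collapses whatever spatial size reaches it to 1x1, the same network accepts both the 120x120 training crops and the 128x128 test crops.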

## Step 3: Model training and validation

In [128]:
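# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

def train(train_loader, model, criterion, optimizer):
    model.train()
    train_loss = 0.0
    for i, (input, target) in enumerate(train_loader):
        input = input.to(device, non_blocking=True)
        # CrossEntropyLoss expects integer class labels of dtype long
        target = target.to(device, non_blocking=True).long()

        output = model(input)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print the batch loss every 20 batches
        if i % 20 == 0:
            print(loss.item())
            
        train_loss += loss.item()
    
    # Mean loss per batch
    return train_loss/len(train_loader)
            
def validate(val_loader, model, criterion):
    model.eval()
    val_acc = 0.0
    
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            input = input.to(device)
            target = target.to(device).long()

            # compute output
            output = model(input)
            loss = criterion(output, target)
            
            # Count the correct predictions in this batch
            val_acc += (output.argmax(1) == target).sum().item()
            
    # Fraction of correctly classified samples
    return val_acc / len(val_loader.dataset)
    
for _  in range(3):
    train_loss = train(train_loader, model, criterion, optimizer)
    val_acc  = validate(val_loader, model, criterion)
    # Accuracy on the training set, measured with the random training augmentations applied
    train_acc = validate(train_loader, model, criterion)
    
    print(train_loss, train_acc, val_acc)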

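RuntimeError: DataLoader worker (pid(s) 16104) exited unexpectedly

This error is raised in the data-loading worker process, before any batch reaches the model. A common cause, especially on Windows, is using `num_workers > 0` with a Dataset class that is defined inside the notebook itself; a usual workaround is to set `num_workers=0` in the three loaders so that data is loaded in the main process.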
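
The two quantities tracked by `train` and `validate`, the cross-entropy loss and the accuracy, can be written out in a few lines of NumPy. This is an illustrative sketch with made-up logits, not PyTorch's implementation; the function names are hypothetical.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean cross-entropy of integer class targets given raw logits (one row per sample)."""
    shifted = logits - logits.max(axis=1, keepdims=True)        # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())

def accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class matches the target."""
    return float((logits.argmax(axis=1) == targets).mean())

# Made-up two-class logits for three samples; class 0 = MCI, class 1 = NC
logits = np.array([[2.0, 0.0],
                   [0.5, 1.5],
                   [-1.0, 3.0]])
targets = np.array([0, 1, 0])

loss_value = cross_entropy(logits, targets)
acc_value = accuracy(logits, targets)
print(loss_value)
print(acc_value)   # two of the three samples are classified correctly
```

The third sample is confidently wrong, so it dominates the loss even though the accuracy only drops by one sample; this is why loss and accuracy are worth reporting together.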

## Step 4: Model prediction and CSV output

In [None]:
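def predict(test_loader, model, criterion):
    model.eval()
    # Run on whichever device the model's parameters live on
    device = next(model.parameters()).device
    
    test_pred = []
    with torch.no_grad():
        # The target computed by the dataset is not a real label for test files, so it is ignored
        for i, (input, _) in enumerate(test_loader):
            input = input.to(device)

            output = model(input)
            test_pred.append(output.cpu().numpy())
            
    return np.vstack(test_pred)
    
# Test-time augmentation: sum the outputs of 10 passes over the randomly transformed test set
pred = None
for _ in range(10):
    if pred is None:
        pred = predict(test_loader, model, criterion)
    else:
        pred += predict(test_loader, model, criterion)
        
submit = pd.DataFrame(
    {
        # File name without its extension, converted to int (independent of the path separator in use)
        'uuid': [int(os.path.splitext(os.path.basename(x))[0]) for x in test_path],
        'label': pred.argmax(1)
})
# Map class indices back to label names: 1 -> NC, 0 -> MCI
submit['label'] = submit['label'].map({1:'NC', 0: 'MCI'})
submit = submit.sort_values(by='uuid')
submit.to_csv('submit2.csv', index=None)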
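
The post-processing in this step (summing the outputs of several passes, taking the arg-max, mapping to label names, and sorting by id) can be traced on a tiny example. The paths and logits below are made up, and `parse_uuid` is a hypothetical helper mirroring the uuid extraction above.

```python
import os
import numpy as np

def parse_uuid(path):
    """Return the file name without its extension as an int, e.g. 'dir/12.nii' -> 12."""
    return int(os.path.splitext(os.path.basename(path))[0])

LABEL_NAMES = {1: "NC", 0: "MCI"}

# Hypothetical test files and the outputs of three prediction passes (rows follow test_paths)
test_paths = ["./PETdata/Test/3.nii", "./PETdata/Test/1.nii"]
passes = [
    np.array([[0.2, 0.9], [1.5, -0.3]]),
    np.array([[0.8, 0.4], [1.1, 0.2]]),   # this pass alone would call the first file MCI
    np.array([[0.1, 0.7], [0.9, 0.3]]),
]

summed = np.sum(passes, axis=0)                       # element-wise sum over passes
labels = [LABEL_NAMES[int(k)] for k in summed.argmax(axis=1)]
rows = sorted(zip(map(parse_uuid, test_paths), labels))
print(rows)   # [(1, 'MCI'), (3, 'NC')]
```

Summing before the arg-max lets the passes outvote a single noisy prediction, which is the point of running `predict` several times over randomly transformed inputs.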