# 1.导入包

In [2]:
import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

import wandb
from evaluate import evaluate
from unet import UNet
from utils.data_loading_copy import BratsDataset
from utils.dice_score import dice_loss
from os.path import splitext, isfile, join
from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split
import tarfile

  from .autonotebook import tqdm as notebook_tqdm


# 2.导入数据集

## 路径定义

In [None]:
TRAIN_DATASET_PATH = './brain_images'
VALIDATION_DATASET_PATH = TRAIN_DATASET_PATH + '/val_brain_images'
TEST_DATASET_PATH = TRAIN_DATASET_PATH  + './test_brain_images'

## 解压文件(只执行一次)

In [None]:
# file = tarfile.open('/home/sucheng/Pytorch-UNet/data/BraTS2021_Training_Data.tar')
# file.extractall(TRAIN_DATASET_PATH)
# file.close()

# file = tarfile.open('/home/sucheng/Pytorch-UNet/data/BraTS2021_00621.tar')
# file.extractall(VALIDATION_DATASET_PATH)
# file.close()

# file = tarfile.open('/home/sucheng/Pytorch-UNet/data/BraTS2021_00495.tar')
# file.extractall(TEST_DATASET_PATH)
# file.close()

## 划分数据集

In [None]:
def pathListIntoIds(dirList):
    x = []
    for i in range(0,len(dirList)):
        x.append(dirList[i][dirList[i].rfind('/')+1:])
    return x

train_and_val_directories = [f.path for f in os.scandir(TRAIN_DATASET_PATH) if f.is_dir()]
train_and_test_ids = pathListIntoIds(train_and_val_directories)


train_test_ids, val_ids = train_test_split(train_and_test_ids,test_size=0.2)
train_ids, test_ids = train_test_split(train_test_ids,test_size=0.15)

## 定义数据类

In [None]:
class BratsDataset(Dataset):
    def __init__(self, list_IDs, dim=(IMG_SIZE,IMG_SIZE), batch_size = 1, n_channels = 2, shuffle=True):
        # 初始化Initialization
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, idx):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[idx*self.batch_size:(idx+1)*self.batch_size]
        # Find list of IDs
        Batch_ids = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(Batch_ids)

        return X, y
    
    def on_epoch_end(self):
        'Updates indexes after each epoch '
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation(self, Batch_ids):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X = torch.zeros((self.batch_size * VOLUME_SLICES, 2, *self.dim, self.n_channels))  # 使用torch.Tensor代替NumPy数组
        y = torch.zeros((self.batch_size * VOLUME_SLICES, 4, IMG_SIZE, IMG_SIZE))  # 只有一个通道的标签图像
        Y = torch.zeros((self.batch_size*VOLUME_SLICES, *self.dim, 4))
        # Generate data
        for c, i in enumerate(Batch_ids):
            case_path = os.path.join(TRAIN_DATASET_PATH, i)

            data_path = os.path.join(case_path, f'{i}_flair.nii.gz')
            flair = nib.load(data_path).get_fdata()

            data_path = os.path.join(case_path, f'{i}_t1ce.nii.gz')
            ce = nib.load(data_path).get_fdata()

            data_path = os.path.join(case_path, f'{i}_seg.nii.gz')
            seg = nib.load(data_path).get_fdata()

            for j in range(VOLUME_SLICES):
                X[j+(VOLUME_SLICES*c),:,:,0] = torch.from_numpy(cv2.resize(flair[:,:,j+VOLUME_START_AT], (IMG_SIZE, IMG_SIZE)))

                X[j+(VOLUME_SLICES*c),:,:,1] = torch.from_numpy(cv2.resize(ce[:,:,j+VOLUME_START_AT], (IMG_SIZE, IMG_SIZE)))


                y[j +VOLUME_SLICES*c,:,:] = torch.from_numpy(cv2.resize(seg[:,:,j+VOLUME_START_AT], (IMG_SIZE, IMG_SIZE)))

        # Generate masks(One-Hot encoding)
        y[y==4] = 3
        y = torch.nn.functional.one_hot(y.to(torch.int64), num_classes=4).permute(0, 3, 1, 2).float()  # 进行One-Hot编码
        #Y = tf.image.resize(mask, (IMG_SIZE, IMG_SIZE));
        #Y = np.array(Y).reshape(1,128,128,128)
        return X / X.max(), y  # 返回归一化后的输入图像和One-Hot编码后的标签

In [3]:
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
    return parser.parse_args()

args = get_args()

usage: ipykernel_launcher.py [-h] [--epochs E] [--batch-size B]
                             [--learning-rate LR] [--load LOAD]
                             [--scale SCALE] [--validation VAL] [--amp]
                             [--bilinear] [--classes CLASSES]
ipykernel_launcher.py: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9003 --control=9001 --hb=9000 --Session.signature_scheme="hmac-sha256" --Session.key=b"4b1c6b08-c2d9-415d-83ef-f438aad19e2e" --shell=9002 --transport="tcp" --iopub=9004 --f=/home/sucheng/.local/share/jupyter/runtime/kernel-v2-425752aVgMWZZfe1ed.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
model = model.to(memory_format=torch.channels_last)