## 癌症检测

#### 流程：

1. 数据下载与预处理：将图片统一缩放为 224×224×3，进行标准化，并对类别标签进行 one-hot 编码。
2. 数据增强
3. 权重初始化
4. L2正则化
5. 建立神经网络：使用预训练的 VGG16，先冻结其卷积层、只训练新添加的全连接层（bottleneck features），再解冻 VGG16 顶部（靠近输出）的卷积层一起微调。
6. 绘制损失曲线，并且可视化第一层的权重。
7. 将数据分为5份，交叉验证。

尝试：fancy PCA，数据预训练（无监督）

In [1]:
# download files
from os.path import isdir, isfile
from os import mkdir
from tqdm import tqdm
from urllib.request import urlretrieve
import urllib

class DLProgress(tqdm):
    """tqdm progress bar whose ``hook`` method can be passed to ``urlretrieve``."""

    # Block count reported at the previous hook call (class-level default;
    # each instance gets its own value after the first call).
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """Advance the bar by the bytes received since the previous call.

        Args:
            block_num: number of blocks transferred so far.
            block_size: size of one block in bytes.
            total_size: total download size in bytes. urlretrieve may pass -1
                when the size is unknown; that is mapped to None so tqdm shows
                an open-ended bar instead of a negative total.
        """
        if total_size is not None and total_size >= 0:
            self.total = total_size
        else:
            self.total = None
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

import os

headers = {'User-Agent':'Mozilla/5.0 (Windows NT 6.1; WOW64; rv:23.0) Gecko/20100101 Firefox/23.0'}

# urlretrieve() takes a plain URL string and cannot be handed a Request that
# carries headers. Installing a global opener is what makes urlretrieve()
# actually send the custom User-Agent.
_opener = urllib.request.build_opener()
_opener.addheaders = list(headers.items())
urllib.request.install_opener(_opener)

DATA_URL = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/skin-cancer/'

# (archive name, message printed when the archive is already present)
for archive, exists_msg in (('train.zip', 'Training file already exists!'),
                            ('valid.zip', 'Validation file already exists!'),
                            ('test.zip', 'Test file already exists!')):
    target = './' + archive
    if isfile(target):
        print(exists_msg)
        continue
    partial = target + '.part'
    with DLProgress(unit='B', unit_scale=True, miniters=1, desc=archive) as pbar:
        urlretrieve(DATA_URL + archive, partial, pbar.hook)
    # Rename only once the download finished, so an interrupted run does not
    # leave a truncated archive that the isfile() check would then accept.
    os.replace(partial, target)

Training file already exists!
Validation file already exists!
Test file already exists!


In [2]:
# unzip data
import zipfile

# Extract the train, validation and test archives (in that order) into the
# current directory, one member at a time so tqdm can show progress.
for archive_path in ('./train.zip', './valid.zip', './test.zip'):
    with zipfile.ZipFile(archive_path, 'r') as archive:
        members = archive.namelist()
        for member in tqdm(members):
            archive.extract(member, './')

In [2]:
from glob import glob
from keras.preprocessing import image
from PIL import ImageFile
import numpy as np
from tqdm import tqdm

# Allow PIL to load image files that are truncated (incomplete) instead of
# raising an error on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
def path_to_tensor(img_path, target_size=(224, 224)):
    """Load one image file as an array of shape (1, height, width, channels).

    Args:
        img_path: path of the image file.
        target_size: ``(height, width)`` the image is resized to; defaults to
            the 224x224 input size of the VGG16 model used in this notebook.

    Returns:
        The image array with a leading batch axis of length 1.
    """
    img = image.load_img(img_path, target_size=target_size)
    x = image.img_to_array(img)
    # Leading batch axis so several images can be stacked with np.vstack.
    return np.expand_dims(x, axis=0)


def paths_to_tensor(img_paths, target_size=(224, 224)):
    """Load several image files and stack them into one (N, height, width, 3) array.

    Args:
        img_paths: iterable of image file paths.
        target_size: ``(height, width)`` every image is resized to.

    Returns:
        The stacked image array. For an empty ``img_paths``, an empty float32
        array of shape (0, height, width, 3) is returned, because np.vstack
        raises on an empty list.
    """
    img_paths = list(img_paths)
    if not img_paths:
        return np.empty((0,) + tuple(target_size) + (3,), dtype='float32')
    list_of_tensor = [path_to_tensor(img_path, target_size) for img_path in tqdm(img_paths)]
    return np.vstack(list_of_tensor)

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
#split data for future test
from keras.utils import np_utils
def process_data(filepath):
    """Load every image under ``filepath`` into one tensor with one-hot labels.

    ``filepath`` must contain one sub-directory per class (for example
    ``./train/melanoma/``); every file inside a class sub-directory is loaded
    with ``paths_to_tensor``. Class indices follow the sorted sub-directory
    paths, so the same class gets the same index in every split.

    Args:
        filepath: directory containing the class sub-directories.

    Returns:
        ``(features, labels)``: the stacked image tensor and the one-hot label
        matrix, which has one column per class sub-directory.

    Raises:
        ValueError: if no class sub-directory or no image file is found.
    """
    # glob() returns entries in a filesystem-dependent order; sort so the
    # class -> index mapping is identical for train/valid/test.
    class_dirs = sorted(folder.replace('\\', '/')
                        for folder in glob(filepath + '/*/') if isdir(folder))
    if not class_dirs:
        raise ValueError('No class sub-directories found in {!r}'.format(filepath))

    feature_chunks = []
    label_chunks = []
    for class_index, folder in enumerate(class_dirs):
        print(folder)
        image_paths = sorted(glob(folder + '*'))
        if not image_paths:
            # An empty class adds no samples but keeps its index / label column.
            continue
        contents = paths_to_tensor(image_paths)
        feature_chunks.append(contents)
        label_chunks.append(np.full(len(contents), class_index))
    if not feature_chunks:
        raise ValueError('No image files found under {!r}'.format(filepath))

    # Concatenate once instead of re-copying the growing tensor per class.
    features = np.concatenate(feature_chunks)
    labels = np_utils.to_categorical(np.concatenate(label_chunks),
                                     num_classes=len(class_dirs))
    print('Feature shape: ', features.shape)
    print('Label shape:', labels.shape)
    return features, labels

In [None]:
from os.path import isdir, isfile
from os import mkdir
from tqdm import tqdm
from urllib.request import urlretrieve
import urllib
# Build (image tensor, one-hot label matrix) for each split with process_data,
# which reads one sub-folder per class under the given directory.
train_features,train_labels = process_data('./train')
valid_features,valid_labels = process_data('./valid')
test_features,test_labels = process_data('./test')

./train./train\melanoma\



  0%|                                                                                          | 0/374 [00:00<?, ?it/s]
  1%|▉                                                                                 | 4/374 [00:00<00:10, 34.46it/s]
  2%|█▌                                                                                | 7/374 [00:00<00:12, 29.77it/s]
  3%|██▏                                                                              | 10/374 [00:00<00:12, 28.71it/s]
  3%|██▊                                                                              | 13/374 [00:00<00:12, 28.43it/s]
  4%|███▍                                                                             | 16/374 [00:00<00:12, 27.85it/s]
  5%|████                                                                             | 19/374 [00:00<00:12, 27.36it/s]
  6%|████▊                                                                            | 22/374 [00:01<00:17, 20.58it/s]
  6%|█████▏                            

 79%|███████████████████████████████████████████████████████████████                 | 295/374 [00:38<00:10,  7.62it/s]
 79%|███████████████████████████████████████████████████████████████▎                | 296/374 [00:39<00:10,  7.58it/s]
 79%|███████████████████████████████████████████████████████████████▌                | 297/374 [00:39<00:10,  7.54it/s]
 80%|███████████████████████████████████████████████████████████████▋                | 298/374 [00:39<00:10,  7.51it/s]
 80%|███████████████████████████████████████████████████████████████▉                | 299/374 [00:40<00:10,  7.44it/s]
 80%|████████████████████████████████████████████████████████████████▏               | 300/374 [00:40<00:09,  7.40it/s]
 80%|████████████████████████████████████████████████████████████████▍               | 301/374 [00:40<00:09,  7.37it/s]
 81%|████████████████████████████████████████████████████████████████▌               | 302/374 [00:41<00:09,  7.32it/s]
 81%|███████████████████████████████████

./train./train\nevus\



  0%|                                                                                         | 0/1372 [00:00<?, ?it/s]
  0%|                                                                                 | 2/1372 [00:00<01:22, 16.53it/s]
  0%|▏                                                                                | 4/1372 [00:00<01:22, 16.57it/s]
  0%|▎                                                                                | 5/1372 [00:00<01:34, 14.47it/s]
  1%|▍                                                                                | 7/1372 [00:00<01:36, 14.10it/s]
  1%|▌                                                                                | 9/1372 [00:00<01:33, 14.50it/s]
  1%|▋                                                                               | 11/1372 [00:00<01:31, 14.82it/s]
  1%|▊                                                                               | 13/1372 [00:00<01:32, 14.62it/s]
  1%|▊                                 

 18%|█████████████▉                                                                 | 242/1372 [00:22<01:45, 10.74it/s]
 18%|█████████████▉                                                                 | 243/1372 [00:22<01:45, 10.72it/s]
 18%|██████████████                                                                 | 244/1372 [00:22<01:45, 10.71it/s]
 18%|██████████████                                                                 | 245/1372 [00:23<01:45, 10.64it/s]
 18%|██████████████▏                                                                | 246/1372 [00:23<01:47, 10.50it/s]
 18%|██████████████▏                                                                | 247/1372 [00:23<01:47, 10.47it/s]
 18%|██████████████▎                                                                | 248/1372 [00:23<01:47, 10.43it/s]
 18%|██████████████▎                                                                | 249/1372 [00:24<01:48, 10.33it/s]
 18%|██████████████▍                    

 30%|████████████████████████                                                       | 417/1372 [00:55<02:06,  7.56it/s]
 30%|████████████████████████                                                       | 418/1372 [00:55<02:06,  7.53it/s]
 31%|████████████████████████▏                                                      | 419/1372 [00:55<02:06,  7.52it/s]
 31%|████████████████████████▏                                                      | 420/1372 [00:56<02:07,  7.48it/s]
 31%|████████████████████████▏                                                      | 421/1372 [00:56<02:07,  7.44it/s]
 31%|████████████████████████▎                                                      | 422/1372 [00:56<02:08,  7.41it/s]
 31%|████████████████████████▎                                                      | 423/1372 [00:57<02:08,  7.37it/s]
 31%|████████████████████████▍                                                      | 424/1372 [00:57<02:08,  7.36it/s]
 31%|████████████████████████▍          

 51%|████████████████████████████████████████▍                                      | 703/1372 [01:28<01:24,  7.91it/s]
 51%|████████████████████████████████████████▋                                      | 706/1372 [01:29<01:23,  7.93it/s]
 52%|████████████████████████████████████████▊                                      | 709/1372 [01:29<01:23,  7.95it/s]
 52%|████████████████████████████████████████▉                                      | 712/1372 [01:29<01:22,  7.98it/s]
 52%|█████████████████████████████████████████▏                                     | 715/1372 [01:29<01:22,  8.00it/s]
 52%|█████████████████████████████████████████▎                                     | 718/1372 [01:29<01:21,  8.01it/s]
 53%|█████████████████████████████████████████▌                                     | 721/1372 [01:29<01:21,  8.03it/s]
 53%|█████████████████████████████████████████▋                                     | 724/1372 [01:29<01:20,  8.05it/s]
 53%|███████████████████████████████████

 68%|█████████████████████████████████████████████████████▊                         | 935/1372 [02:07<00:59,  7.32it/s]
 68%|█████████████████████████████████████████████████████▉                         | 936/1372 [02:08<00:59,  7.31it/s]
 68%|█████████████████████████████████████████████████████▉                         | 937/1372 [02:08<00:59,  7.31it/s]
 68%|██████████████████████████████████████████████████████                         | 938/1372 [02:08<00:59,  7.31it/s]
 68%|██████████████████████████████████████████████████████                         | 939/1372 [02:08<00:59,  7.30it/s]
 69%|██████████████████████████████████████████████████████▏                        | 940/1372 [02:08<00:59,  7.30it/s]
 69%|██████████████████████████████████████████████████████▏                        | 941/1372 [02:08<00:59,  7.30it/s]
 69%|██████████████████████████████████████████████████████▏                        | 942/1372 [02:09<00:58,  7.30it/s]
 69%|███████████████████████████████████

 78%|████████████████████████████████████████████████████████████▉                 | 1071/1372 [02:42<00:45,  6.59it/s]
 78%|████████████████████████████████████████████████████████████▉                 | 1072/1372 [02:42<00:45,  6.59it/s]
 78%|█████████████████████████████████████████████████████████████                 | 1073/1372 [02:42<00:45,  6.58it/s]
 78%|█████████████████████████████████████████████████████████████                 | 1074/1372 [02:43<00:45,  6.58it/s]
 78%|█████████████████████████████████████████████████████████████                 | 1075/1372 [02:43<00:45,  6.58it/s]
 78%|█████████████████████████████████████████████████████████████▏                | 1076/1372 [02:43<00:45,  6.58it/s]
 78%|█████████████████████████████████████████████████████████████▏                | 1077/1372 [02:43<00:44,  6.57it/s]
 79%|█████████████████████████████████████████████████████████████▎                | 1078/1372 [02:44<00:44,  6.57it/s]
 79%|███████████████████████████████████

 88%|████████████████████████████████████████████████████████████████████▌         | 1207/1372 [03:27<00:28,  5.81it/s]
 88%|████████████████████████████████████████████████████████████████████▋         | 1208/1372 [03:28<00:28,  5.80it/s]
 88%|████████████████████████████████████████████████████████████████████▋         | 1209/1372 [03:28<00:28,  5.80it/s]
 88%|████████████████████████████████████████████████████████████████████▊         | 1210/1372 [03:28<00:27,  5.79it/s]
 88%|████████████████████████████████████████████████████████████████████▊         | 1211/1372 [03:29<00:27,  5.79it/s]
 88%|████████████████████████████████████████████████████████████████████▉         | 1212/1372 [03:29<00:27,  5.78it/s]
 88%|████████████████████████████████████████████████████████████████████▉         | 1213/1372 [03:29<00:27,  5.78it/s]
 88%|█████████████████████████████████████████████████████████████████████         | 1214/1372 [03:30<00:27,  5.78it/s]
 89%|███████████████████████████████████

 98%|████████████████████████████████████████████████████████████████████████████▎ | 1343/1372 [04:55<00:06,  4.54it/s]
 98%|████████████████████████████████████████████████████████████████████████████▍ | 1344/1372 [04:56<00:06,  4.53it/s]
 98%|████████████████████████████████████████████████████████████████████████████▍ | 1345/1372 [04:57<00:05,  4.52it/s]
 98%|████████████████████████████████████████████████████████████████████████████▌ | 1346/1372 [04:58<00:05,  4.51it/s]
 98%|████████████████████████████████████████████████████████████████████████████▌ | 1347/1372 [04:59<00:05,  4.49it/s]
 98%|████████████████████████████████████████████████████████████████████████████▋ | 1348/1372 [05:00<00:05,  4.48it/s]
 98%|████████████████████████████████████████████████████████████████████████████▋ | 1349/1372 [05:01<00:05,  4.47it/s]
 98%|████████████████████████████████████████████████████████████████████████████▋ | 1350/1372 [05:02<00:04,  4.46it/s]
 98%|███████████████████████████████████

./train./train\seborrheic_keratosis\



  0%|                                                                                          | 0/254 [00:00<?, ?it/s]
  0%|▎                                                                                 | 1/254 [00:00<01:56,  2.18it/s]
  1%|▋                                                                                 | 2/254 [00:00<01:22,  3.04it/s]
  1%|▉                                                                                 | 3/254 [00:00<01:13,  3.42it/s]
  2%|█▎                                                                                | 4/254 [00:01<01:07,  3.69it/s]
  2%|█▌                                                                                | 5/254 [00:01<01:04,  3.84it/s]
  2%|█▉                                                                                | 6/254 [00:01<01:03,  3.92it/s]
  3%|██▎                                                                               | 7/254 [00:01<01:02,  3.94it/s]
  3%|██▌                               

 54%|███████████████████████████████████████████▍                                    | 138/254 [00:51<00:43,  2.68it/s]
 55%|███████████████████████████████████████████▊                                    | 139/254 [00:52<00:43,  2.63it/s]
 55%|████████████████████████████████████████████                                    | 140/254 [00:53<00:43,  2.61it/s]
 56%|████████████████████████████████████████████▍                                   | 141/254 [00:55<00:44,  2.55it/s]
 56%|████████████████████████████████████████████▋                                   | 142/254 [00:57<00:45,  2.47it/s]
 56%|█████████████████████████████████████████████                                   | 143/254 [00:58<00:45,  2.43it/s]
 57%|█████████████████████████████████████████████▎                                  | 144/254 [00:59<00:45,  2.40it/s]
 57%|█████████████████████████████████████████████▋                                  | 145/254 [01:01<00:46,  2.36it/s]
 57%|███████████████████████████████████

Feature shape:  (2000, 224, 224, 3)
Label shape: (2000, 3)
./valid./valid\melanoma\



  0%|                                                                                           | 0/30 [00:00<?, ?it/s]
  3%|██▊                                                                                | 1/30 [00:00<00:06,  4.71it/s]
  7%|█████▌                                                                             | 2/30 [00:00<00:06,  4.23it/s]
 10%|████████▎                                                                          | 3/30 [00:01<00:10,  2.67it/s]
 13%|███████████                                                                        | 4/30 [00:01<00:09,  2.75it/s]
 17%|█████████████▊                                                                     | 5/30 [00:01<00:08,  2.89it/s]
 20%|████████████████▌                                                                  | 6/30 [00:01<00:07,  3.05it/s]
 23%|███████████████████▎                                                               | 7/30 [00:02<00:08,  2.72it/s]
 27%|██████████████████████▏           

./valid./valid\nevus\



  0%|                                                                                           | 0/78 [00:00<?, ?it/s]
  1%|█                                                                                  | 1/78 [00:00<00:17,  4.29it/s]
  3%|██▏                                                                                | 2/78 [00:00<00:22,  3.40it/s]
  4%|███▏                                                                               | 3/78 [00:00<00:20,  3.58it/s]
  5%|████▎                                                                              | 4/78 [00:01<00:28,  2.60it/s]
  6%|█████▎                                                                             | 5/78 [00:02<00:34,  2.12it/s]
  8%|██████▍                                                                            | 6/78 [00:03<00:38,  1.85it/s]
  9%|███████▍                                                                           | 7/78 [00:03<00:37,  1.88it/s]
 10%|████████▌                         

./valid./valid\seborrheic_keratosis\



  0%|                                                                                           | 0/42 [00:00<?, ?it/s]
  2%|█▉                                                                                 | 1/42 [00:01<01:12,  1.76s/it]
  5%|███▉                                                                               | 2/42 [00:03<01:04,  1.61s/it]
  7%|█████▉                                                                             | 3/42 [00:04<00:58,  1.51s/it]
 10%|███████▉                                                                           | 4/42 [00:06<00:58,  1.54s/it]
 12%|█████████▉                                                                         | 5/42 [00:07<00:52,  1.42s/it]
 14%|███████████▊                                                                       | 6/42 [00:08<00:51,  1.44s/it]
 17%|█████████████▊                                                                     | 7/42 [00:09<00:49,  1.43s/it]
 19%|███████████████▊                  

Feature shape:  (150, 224, 224, 3)
Label shape: (150, 3)
./test./test\melanoma\



  0%|                                                                                          | 0/117 [00:00<?, ?it/s]
  1%|▋                                                                                 | 1/117 [00:00<00:57,  2.01it/s]
  2%|█▍                                                                                | 2/117 [00:00<00:41,  2.75it/s]
  3%|██                                                                                | 3/117 [00:01<00:59,  1.90it/s]
  3%|██▊                                                                               | 4/117 [00:02<00:57,  1.98it/s]
  4%|███▌                                                                              | 5/117 [00:02<00:59,  1.88it/s]
  6%|████▉                                                                             | 7/117 [00:02<00:46,  2.37it/s]
  7%|█████▌                                                                            | 8/117 [00:03<00:43,  2.51it/s]
  8%|██████▎                           

./test./test\nevus\



  0%|                                                                                          | 0/393 [00:00<?, ?it/s]
  0%|▏                                                                                 | 1/393 [00:00<03:10,  2.05it/s]
  1%|▍                                                                                 | 2/393 [00:00<02:27,  2.65it/s]
  1%|▋                                                                                 | 3/393 [00:01<02:21,  2.75it/s]
  1%|▊                                                                                 | 4/393 [00:01<02:09,  3.01it/s]
  1%|█                                                                                 | 5/393 [00:01<02:01,  3.19it/s]
  2%|█▎                                                                                | 6/393 [00:01<01:55,  3.34it/s]
  2%|█▍                                                                                | 7/393 [00:02<01:51,  3.46it/s]
  2%|█▋                                

 35%|███████████████████████████▋                                                    | 136/393 [02:49<05:20,  1.25s/it]
 35%|███████████████████████████▉                                                    | 137/393 [02:50<05:19,  1.25s/it]
 35%|████████████████████████████                                                    | 138/393 [02:52<05:17,  1.25s/it]
 35%|████████████████████████████▎                                                   | 139/393 [02:53<05:16,  1.25s/it]
 36%|████████████████████████████▍                                                   | 140/393 [02:54<05:14,  1.24s/it]
 36%|████████████████████████████▋                                                   | 141/393 [02:55<05:13,  1.24s/it]
 36%|████████████████████████████▉                                                   | 142/393 [02:56<05:11,  1.24s/it]
 36%|█████████████████████████████                                                   | 143/393 [02:57<05:10,  1.24s/it]
 37%|█████████████████████████████▎     

 69%|███████████████████████████████████████████████████████▎                        | 272/393 [08:54<03:57,  1.96s/it]
 69%|███████████████████████████████████████████████████████▌                        | 273/393 [08:54<03:54,  1.96s/it]
 70%|███████████████████████████████████████████████████████▊                        | 274/393 [08:54<03:52,  1.95s/it]
 70%|███████████████████████████████████████████████████████▉                        | 275/393 [08:55<03:49,  1.95s/it]
 70%|████████████████████████████████████████████████████████▏                       | 276/393 [08:56<03:47,  1.94s/it]
 70%|████████████████████████████████████████████████████████▍                       | 277/393 [08:56<03:44,  1.94s/it]
 71%|████████████████████████████████████████████████████████▌                       | 278/393 [08:56<03:42,  1.93s/it]
 71%|████████████████████████████████████████████████████████▊                       | 279/393 [08:56<03:39,  1.92s/it]
 71%|███████████████████████████████████

In [None]:
# standardize data
from sklearn.preprocessing import StandardScaler
# NOTE(review): this scaler is never fitted or used in the visible cells; the
# scaling below is done manually.
Standardizer = StandardScaler()


def scale_pixels(images, center=125.0, scale=125.0):
    """Return ``(images - center) / scale`` as a floating-point array.

    Floating-point input is scaled in place, which avoids an extra copy of the
    large image tensors, and then returned. Integer input is first converted
    to float32, because writing scaled values back into an integer array would
    truncate them.

    Args:
        images: array of raw pixel values.
        center: value subtracted from every pixel.
        scale: divisor applied after centering.

    Returns:
        The scaled array (the same object when ``images`` was already float).
    """
    images = np.asarray(images)
    if not np.issubdtype(images.dtype, np.floating):
        images = images.astype('float32')
    images -= center
    images /= scale
    return images


# Map 0-255 pixel values to about [-1, 1.04].
train_features = scale_pixels(train_features)
valid_features = scale_pixels(valid_features)
test_features = scale_pixels(test_features)

In [None]:
# save the data
import numpy as np
# Persist every split (features and one-hot labels) as .npy files so later
# sessions can reload them without re-reading the images.
for target_file, array_to_save in (('train_features.npy', train_features),
                                   ('train_labels.npy', train_labels),
                                   ('valid_features.npy', valid_features),
                                   ('valid_labels.npy', valid_labels),
                                   ('test_features.npy', test_features),
                                   ('test_labels.npy', test_labels)):
    np.save(target_file, array_to_save)

In [None]:
# using load command to load the data
# this step is not necessary if you have already loaded the data
import numpy as np
# Reload the arrays written by the np.save(...) calls above (same file names),
# so the expensive image loading can be skipped in a new kernel session.
train_features = np.load('train_features.npy')
train_labels = np.load('train_labels.npy')
valid_features = np.load('valid_features.npy')
valid_labels = np.load('valid_labels.npy')
test_features = np.load('test_features.npy')
test_labels = np.load('test_labels.npy')

### autoencoder

首先建立一个卷积自编码器（autoencoder），以无监督方式学习并重建图像，期望从中提取更清晰的特征。

下面使用TensorFlow建立一个自编码器autoencoder。

In [None]:
import tensorflow as tf
inputs_ = tf.placeholder(tf.float32, (None, 224, 224, 3), name='inputs')
targets_ = tf.placeholder(tf.float32, (None, 224, 224, 3), name='targets')

# Encoder
conv1 = tf.layers.conv2d(inputs_, 16, (3,3), padding='same', activation=tf.nn.relu)
# 224x224x16
maxpool1 = tf.layers.max_pooling2d(conv1, (2,2), (2,2), padding='same')
# 112x112x16
conv2 = tf.layers.conv2d(maxpool1, 8, (3,3), padding='same', activation=tf.nn.relu)
# 112x112x8
maxpool2 = tf.layers.max_pooling2d(conv2, (2,2), (2,2), padding='same')
# 56x56x8
conv3 = tf.layers.conv2d(maxpool2, 8, (3,3), padding='same', activation=tf.nn.relu)
# 56x56x8
encoded = tf.layers.max_pooling2d(conv3, (2,2), (2,2), padding='same')
# 28x28x8

# Decoder
upsample1 = tf.image.resize_nearest_neighbor(encoded, (56,56))
# 56x56x8
conv4 = tf.layers.conv2d(upsample1, 8, (3,3), padding='same', activation=tf.nn.relu)
# 56x56x8
upsample2 = tf.image.resize_nearest_neighbor(conv4, (112,112))
# 112x112x8
conv5 = tf.layers.conv2d(upsample2, 8, (3,3), padding='same', activation=tf.nn.relu)
# 112x112x8
upsample3 = tf.image.resize_nearest_neighbor(conv5, (224,224))
# 224x224x8
conv6 = tf.layers.conv2d(upsample3, 16, (3,3), padding='same', activation=tf.nn.relu)
# 224x224x16

logits = tf.layers.conv2d(conv6, 3, (3,3), padding='same', activation=None)
# 224x224x3

# The inputs (which are also the reconstruction targets) are standardized with
# (x - 125) / 125, so they lie in about [-1, 1.04]. A sigmoid output cannot
# produce those values. Sigmoid cross-entropy also requires targets in [0, 1];
# outside that range the loss is unbounded below and training diverges. Use a
# linear output with a mean-squared reconstruction error instead.
decoded = tf.identity(logits, name='decoded')

loss = tf.square(logits - targets_)
cost = tf.reduce_mean(loss)
opt = tf.train.AdamOptimizer(0.001).minimize(cost)

In [None]:
# Keep the session open (no `with` block) so the trained variables remain
# available to later cells that call `sess.run(...)`.
sess = tf.Session()
epochs = 5
batch_size = 200
sess.run(tf.global_variables_initializer())
for e in range(epochs):
    # Consecutive batch_size slices; the final slice holds any remainder.
    for start in range(0, len(train_features), batch_size):
        batch = train_features[start:start + batch_size]
        batch_cost, _ = sess.run([cost, opt], feed_dict={inputs_: batch,
                                                         targets_: batch})
        print("Epoch: {}/{}...".format(e+1, epochs),
              "Training loss: {:.4f}".format(batch_cost))
# Write the trained weights to a checkpoint right away so they can be
# restored into any later session.
tf.train.Saver(max_to_keep=1).save(sess, 'autoencoder.ckpt')

In [None]:
# save the model
saver = tf.train.Saver(max_to_keep=1)
# Save from the session that holds the trained weights. A freshly created
# tf.Session() contains only uninitialized variables, so saving it fails.
saver.save(sess, 'autoencoder.ckpt')

In [None]:
# visualize the results first
import matplotlib.pyplot as plt


def _to_displayable(img):
    """Min-max rescale one image array into [0, 1] for ``imshow``.

    The features are standardized floats rather than 0-255 pixels, and imshow
    expects RGB float data in [0, 1]. A per-image rescale makes both the
    inputs and the reconstructions viewable.
    """
    img = np.asarray(img, dtype='float32')
    lo, hi = img.min(), img.max()
    if hi <= lo:
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)


n_show = 5  # one image per subplot column
fig, axes = plt.subplots(nrows=2, ncols=n_show, sharex=True, sharey=True, figsize=(20,4))
imgs = train_features[:n_show]
reconstructed = sess.run(decoded, feed_dict={inputs_: imgs})

# Top row: inputs; bottom row: reconstructions.
for images, row in zip([imgs, reconstructed], axes):
    for img, ax in zip(images, row):
        # The images are 224x224x3 colour arrays; show them directly instead
        # of reshaping them into a single-channel grid.
        ax.imshow(_to_displayable(img))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
fig.tight_layout(pad=0.1)

In [None]:
# get results from autoencoder
def _decode_in_batches(session, features, batch_size=50):
    """Run the autoencoder's ``decoded`` tensor over ``features`` in chunks.

    Feeding the whole array at once computes activations for every image
    simultaneously; chunking bounds the peak memory use. An empty
    ``features`` is returned unchanged.
    """
    if len(features) == 0:
        return features
    return np.concatenate([
        session.run(decoded, feed_dict={inputs_: features[start:start + batch_size]})
        for start in range(0, len(features), batch_size)])


with tf.Session() as decode_sess:
    # A fresh session starts with uninitialized variables; restore the trained
    # autoencoder from the 'autoencoder.ckpt' checkpoint prefix first.
    tf.train.Saver().restore(decode_sess, 'autoencoder.ckpt')
    train_features = _decode_in_batches(decode_sess, train_features)
    valid_features = _decode_in_batches(decode_sess, valid_features)
    test_features = _decode_in_batches(decode_sess, test_features)

数据增强

In [None]:
# data augmentation
from keras.preprocessing.image import ImageDataGenerator
# Random zoom, shear and horizontal flips applied on the fly to training batches.
datagen = ImageDataGenerator(
    zoom_range = 0.2,
    horizontal_flip = True,
    shear_range = 0.2,
    fill_mode = 'nearest'
)
# fit() only computes data-dependent statistics (featurewise centering /
# normalization, ZCA whitening), none of which are enabled above. Fit on the
# training split only: fitting on validation data as well would overwrite
# those statistics with validation-set values (data leakage) as soon as any of
# these options is switched on.
datagen.fit(train_features)

模型的初步训练也可以使用EarlyStopping，从而加速训练。

In [None]:
# 模型建立
from keras.models import Model
from keras.applications import vgg16
import tensorflow as tf
from keras.callbacks import ModelCheckpoint
from keras.optimizers import SGD

# 首先要对模型进行预训练，即固定VGG16前面的权重，对后面部分进行训练
# Stage 1 ("bottleneck" training): load the ImageNet-pretrained VGG16
# convolutional base without its classifier (include_top=False) and freeze all
# of its layers, so only the head added later is trained.
# NOTE(review): ImageNet VGG16 weights are normally fed inputs prepared with
# vgg16.preprocess_input; the features here are scaled as (x - 125) / 125
# instead — confirm this is intended.
model = vgg16.VGG16(include_top=False,
                    weights='imagenet',
                    input_shape=(224, 224, 3))
print('Model loaded')
for layer in model.layers:
    layer.trainable = False

In [None]:
# add new layer and regularizers
# 版本不符，softmax函数用tf.nn.softmax代替
from keras import regularizers
from keras.layers import Dense, Flatten, Dropout
import tensorflow as tf
# Shapes of the full label and feature arrays. output_shape[1] is the number of
# one-hot classes and sizes the final softmax layer in add_new_layer.
output_shape = train_labels.shape
# NOTE(review): input_shape is not referenced by the visible cells.
input_shape  = train_features.shape
def add_new_layer(model, num_classes=None):
    """Stack a regularized fully connected classification head on ``model``.

    Args:
        model: Keras model whose ``output`` feeds the new head (here the
            frozen VGG16 convolutional base).
        num_classes: width of the softmax output layer. Defaults to
            ``output_shape[1]``, the number of one-hot label columns.

    Returns:
        A new ``keras.models.Model`` from ``model.input`` to the softmax output.
    """
    if num_classes is None:
        num_classes = output_shape[1]
    x = model.output
    x = Flatten()(x)
    # L2 weight penalty, as in the notebook's plan ("L2正则化"). An L1 penalty
    # of 0.01 summed over the millions of weights of this Dense layer dwarfs
    # the classification loss.
    x = Dense(122, activation='relu', kernel_regularizer=regularizers.l2(0.01))(x)
    x = Dropout(0.5)(x)
    x = Dense(50, activation='relu', kernel_regularizer=regularizers.l2(0.01))(x)
    x = Dropout(0.5)(x)
    # tf.nn.softmax rather than 'softmax' to work around a Keras/TensorFlow
    # version mismatch.
    x = Dense(num_classes, activation=tf.nn.softmax)(x)
    # Keras 2 keyword names; input=/output= are the deprecated Keras 1 spelling.
    return Model(inputs=model.input, outputs=x)
def freeze_all_model(model):
    """Mark every layer of ``model`` as non-trainable, in place; returns None."""
    all_layers = model.layers
    for frozen in all_layers:
        frozen.trainable = False

In [None]:
#setting checkpointer and early stopping
from keras.callbacks import ModelCheckpoint,EarlyStopping

# Keep only the best weights seen so far. 'val_loss' is ModelCheckpoint's
# default monitor, spelled out here; note it differs from the early-stopping
# metric below.
checkpointer = ModelCheckpoint(filepath='Cancer_best_weights.hdf5',
                               monitor='val_loss',
                               save_best_only=True,
                               verbose=1)
# Stop once validation accuracy has not improved by at least 0.0003 for
# 2 consecutive epochs.
stopper = EarlyStopping(monitor='val_acc',
                        mode='max',
                        min_delta=0.0003,
                        patience=2)

In [None]:
# set up model
model = add_new_layer(model)
# setting hyperparameters
batch_size = 32
epochs = 10
learning_rate = 0.01

sgd = SGD(lr=learning_rate, decay=1e-4, momentum=0.9, nesterov=True)
# Pass the configured optimizer object. With the string 'sgd', Keras builds a
# default SGD (no momentum, nesterov or decay) and ignores `sgd` entirely.
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [None]:
# train model
model_history = model.fit_generator(
    datagen.flow(train_features, train_labels, batch_size=batch_size),
    # Round up so the final partial batch is also seen every epoch.
    steps_per_epoch=(train_features.shape[0] + batch_size - 1) // batch_size,
    callbacks=[checkpointer, stopper],
    validation_data=(valid_features, valid_labels),
    epochs=epochs,
    shuffle=True,
    verbose=2)

In [None]:
# visualize the loss and accuracy to find a perfect point
# summarize history for accuracy
import matplotlib.pyplot as plt
def _plot_curves(history_dict, key, label):
    """Plot the training and validation curves of one metric from a Keras history dict."""
    plt.plot(history_dict[key])
    plt.plot(history_dict['val_' + key])
    plt.title('model ' + label)
    plt.ylabel(label)
    plt.xlabel('epoch')
    # The second curve comes from validation_data (the valid split), not the test split.
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()


# The fit cell stores its History object in `model_history`. Older Keras
# versions log accuracy as 'acc', newer ones as 'accuracy'.
_acc_key = 'acc' if 'acc' in model_history.history else 'accuracy'
_plot_curves(model_history.history, _acc_key, 'accuracy')
_plot_curves(model_history.history, 'loss', 'loss')

### 初始训练完模型后

很容易看出，训练过程中虽然验证集损失一直在减少，但其准确率没有得到提高，这很可能是因为模型受限于被固定的卷积层权重。

这主要是因为我们只训练了新添加的最后几层，而没有训练前面的卷积层。下面解冻 VGG16 顶部（靠近输出）的卷积层，与新加的层一起训练模型。

In [None]:
# train 
def fine_tune(model, n_trainable=11):
    """Make only the last ``n_trainable`` layers of ``model`` trainable.

    All earlier layers are frozen. The default of 11 reproduces the original
    ``model.layers[-11:]`` split. Unlike negative slicing, ``n_trainable=0``
    freezes every layer; a ``-0`` slice would select all of them. If the model
    has fewer than ``n_trainable`` layers, every layer becomes trainable.
    Call ``model.compile()`` afterwards so Keras picks up the change.

    Args:
        model: Keras model whose ``layers`` are updated in place.
        n_trainable: number of trailing layers to unfreeze.
    """
    layers = model.layers
    first_trainable = max(len(layers) - n_trainable, 0)
    for index, layer in enumerate(layers):
        layer.trainable = index >= first_trainable

In [None]:
fine_tune(model)
# setting hyperparameters
batch_size = 32
epochs = 10
learning_rate = 0.0001

sgd = SGD(lr=learning_rate, momentum=0.9, nesterov=True)
# Pass the configured optimizer object. With the string 'sgd', Keras builds a
# default SGD (lr=0.01, no momentum), ignoring the small fine-tuning learning
# rate meant to protect the pretrained weights.
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [None]:
train_history_2 = 