In [1]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm
import json
import time
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import sys
from pathlib import Path
module_path = Path('.').resolve().parent
if str(module_path) not in sys.path:
    sys.path.append(str(module_path))

In [2]:
TOP_S = 100
BATCH_SIZE = 32
EPOCHS = range(1)

In [3]:
""" Get data generator """
from src.datagens import DataSource, WSIDatagen
from src.config import SLIDE_DIR

DATASET_PATH = Path('/mnt/data/crc_ml/data/processed/Prostata/level1/r512px/c256px/t512px/no_overlap/datasets/1602530237.h5')
CONFIG_PATH = Path('/mnt/data/crc_ml/configs/level1.json')

with CONFIG_PATH.open('r') as config_json:
    config = json.load(config_json)
config['batch_size'] = BATCH_SIZE

train_ds = DataSource()
train_ds.add_dataset(DATASET_PATH, 'train')
test_ds = DataSource()
test_ds.add_dataset(DATASET_PATH, 'test')

dg = WSIDatagen(config, slide_dir=SLIDE_DIR)
train_gen = dg.get_sequential_generator(train_ds)
test_gen = dg.get_sequential_generator(test_ds)

In [4]:
""" Create model """
from src.training.models import PretrainedModel
from tensorflow.keras.applications import VGG16
from tensorflow.keras.optimizers import SGD, RMSprop, Adam
from tensorflow.keras.losses import BinaryCrossentropy

model = PretrainedModel(config, conv_net=VGG16).model
loss_fn = BinaryCrossentropy(from_logits=False)
optimizer = Adam(1e-4)

[2020-10-28 02:39:01][INFO ][models] Applying L2 regularizer.
[2020-10-28 02:39:03][INFO ][models] Building vgg16 model.
[2020-10-28 02:39:03][INFO ][models] Model input size: [512, 512, 3]


In [None]:
def forward_pass(gen):
    # Initialize aggregator
    results = tf.zeros(shape=(0,1))
    for batch in tqdm(gen):
        results = tf.concat([results, model(batch)], axis=0)
        
    _, idx = tf.math.top_k(tf.reshape(results,[-1]), k=TOP_S, sorted=False)
    return idx

@tf.function
def train_step(x, y, w):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss_value = loss_fn(y, y_pred, sample_weight=w)
    grads = tape.gradient(loss_value, model.trainable_weights)
    return grads, loss_value
        
    
for epoch in EPOCHS:
    print(f'Epoch: {epoch}', flush=True)
    
    # Process all slides
    for slide_idx in range(len(test_gen.sampler.datasource)):
        grads = None
        start_time = time.time()

        # Prepare slide information
        slide_name = test_gen.prepare_slide(slide_idx)
        slide_label = int(slide_name[-1])
        
        # Get top predictions
        test_gen.batch_size = BATCH_SIZE
        top_indices = forward_pass(test_gen)
        
        test_gen.batch_size = 1
        for batch_id in tqdm(top_indices):
            # Prepare inputs
            x, _ = test_gen[batch_id]
            if slide_label == 0:
                y  = tf.zeros((x.shape[0], 1))
                w = tf.ones((x.shape[0], 1)) * 0.61
            else:
                y  = tf.ones((x.shape[0], 1))
                w = tf.ones((x.shape[0], 1)) * 2.87
            
            # Accumulate gradients
            if grads is None:
                grads, loss_val = train_step(x, y, w)
            else:
                new_grads, loss_val = train_step(x, y, w)
                grads = [g1+g2 for g1, g2 in zip(grads, new_grads)]
                
        # Performing update
        grads = [grad / (TOP_S*len(test_gen.sampler.datasource)) for grad in grads]
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
            
        end_time = time.time()
        print(f'{slide_name} ({end_time - start_time:.4f} sec) Loss: {loss_val:.2f}')

Epoch: 0


100%|██████████| 1808/1808 [02:33<00:00, 11.79it/s]
100%|██████████| 100/100 [00:08<00:00, 11.41it/s]


TP-2019_6785-06-0 (162.5075 sec) Loss: 0.17


100%|██████████| 2418/2418 [03:40<00:00, 10.98it/s]
100%|██████████| 100/100 [00:09<00:00, 10.92it/s]


TP-2019_6786-06-0 (229.7466 sec) Loss: 0.08


100%|██████████| 2852/2852 [04:13<00:00, 11.26it/s]
100%|██████████| 100/100 [00:09<00:00, 10.44it/s]


TP-2019_7207-06-1 (263.1209 sec) Loss: 32.38


100%|██████████| 2316/2316 [03:31<00:00, 10.95it/s]
100%|██████████| 100/100 [00:09<00:00, 10.11it/s]


TP-2019_6785-13-1 (221.6510 sec) Loss: 3.88


100%|██████████| 3768/3768 [05:56<00:00, 10.57it/s]
100%|██████████| 100/100 [00:10<00:00,  9.95it/s]
  0%|          | 0/1076 [00:00<?, ?it/s]

TP-2019_6786-08-0 (366.8641 sec) Loss: 0.34


100%|██████████| 1076/1076 [01:31<00:00, 11.77it/s]
100%|██████████| 100/100 [00:08<00:00, 12.40it/s]


TP-2019_6785-01-0 (99.5939 sec) Loss: 0.95


100%|██████████| 3584/3584 [05:26<00:00, 10.99it/s]
100%|██████████| 100/100 [00:10<00:00,  9.67it/s]


TP-2019_6786-09-0 (336.7472 sec) Loss: 0.40


100%|██████████| 4539/4539 [08:40<00:00,  8.72it/s]
100%|██████████| 100/100 [00:12<00:00,  8.20it/s]
  0%|          | 0/1176 [00:00<?, ?it/s]

TP-2019_6786-07-0 (533.2962 sec) Loss: 0.85


100%|██████████| 1176/1176 [01:43<00:00, 11.31it/s]
100%|██████████| 100/100 [00:10<00:00,  9.60it/s]
  0%|          | 0/2107 [00:00<?, ?it/s]

TP-2019_7207-11-0 (114.5120 sec) Loss: 0.54


100%|██████████| 2107/2107 [03:17<00:00, 10.69it/s]
100%|██████████| 100/100 [00:09<00:00, 10.39it/s]
  0%|          | 0/1870 [00:00<?, ?it/s]

TP-2019_7207-10-1 (206.9846 sec) Loss: 1.31


100%|██████████| 1870/1870 [02:27<00:00, 12.67it/s]
100%|██████████| 100/100 [00:08<00:00, 11.69it/s]


TP-2019_7488-03-1 (156.2842 sec) Loss: 0.42


100%|██████████| 2322/2322 [03:21<00:00, 11.53it/s]
100%|██████████| 100/100 [00:09<00:00, 10.89it/s]
  0%|          | 0/2205 [00:00<?, ?it/s]

TP-2019_2941-05-0 (210.9257 sec) Loss: 0.35


100%|██████████| 2205/2205 [03:23<00:00, 10.83it/s]
100%|██████████| 100/100 [00:10<00:00,  9.65it/s]


TP-2019_2623-06-0 (214.2351 sec) Loss: 0.45


100%|██████████| 2781/2781 [04:12<00:00, 10.99it/s]
100%|██████████| 100/100 [00:12<00:00,  8.00it/s]


TP-2019_7333-03-1 (265.7073 sec) Loss: 0.77


100%|██████████| 1714/1714 [02:38<00:00, 10.80it/s]
100%|██████████| 100/100 [00:09<00:00, 10.07it/s]
  0%|          | 0/2199 [00:00<?, ?it/s]

TP-2019_6785-14-1 (168.8890 sec) Loss: 0.75


100%|██████████| 2199/2199 [03:12<00:00, 11.45it/s]
100%|██████████| 100/100 [00:08<00:00, 11.85it/s]
  0%|          | 0/1227 [00:00<?, ?it/s]

TP-2019_7333-01-1 (200.7013 sec) Loss: 0.55


100%|██████████| 1227/1227 [01:38<00:00, 12.47it/s]
100%|██████████| 100/100 [00:08<00:00, 12.08it/s]
  0%|          | 0/1546 [00:00<?, ?it/s]

TP-2019_6786-01-1 (106.7852 sec) Loss: 0.77


100%|██████████| 1546/1546 [02:10<00:00, 11.81it/s]
100%|██████████| 100/100 [00:08<00:00, 11.42it/s]
  0%|          | 0/1179 [00:00<?, ?it/s]

TP-2019_6785-09-0 (139.8358 sec) Loss: 1.06


100%|██████████| 1179/1179 [01:35<00:00, 12.30it/s]
100%|██████████| 100/100 [00:10<00:00,  9.32it/s]


TP-2019_7207-08-0 (106.7181 sec) Loss: 1.54


100%|██████████| 2694/2694 [03:49<00:00, 11.73it/s]
100%|██████████| 100/100 [00:08<00:00, 11.14it/s]


TP-2019_7488-04-1 (239.0238 sec) Loss: 0.51


100%|██████████| 3026/3026 [04:26<00:00, 11.36it/s]
100%|██████████| 100/100 [00:08<00:00, 11.31it/s]


TP-2019_7207-04-1 (275.6143 sec) Loss: 0.36


100%|██████████| 2670/2670 [03:54<00:00, 11.40it/s]
100%|██████████| 100/100 [00:08<00:00, 11.22it/s]
  0%|          | 0/954 [00:00<?, ?it/s]

TP-2019_7207-13-1 (243.3243 sec) Loss: 1.41


100%|██████████| 954/954 [01:15<00:00, 12.60it/s]
100%|██████████| 100/100 [00:08<00:00, 12.33it/s]
  0%|          | 0/1877 [00:00<?, ?it/s]

TP-2019_7207-02-0 (83.9550 sec) Loss: 1.40


100%|██████████| 1877/1877 [02:39<00:00, 11.76it/s]
100%|██████████| 100/100 [00:08<00:00, 11.42it/s]


TP-2019_7362-08-0 (168.5639 sec) Loss: 1.42


100%|██████████| 1824/1824 [02:40<00:00, 11.38it/s]
100%|██████████| 100/100 [00:10<00:00,  9.51it/s]


TP-2019_6887-04-0 (170.9796 sec) Loss: 1.31


100%|██████████| 2634/2634 [04:23<00:00, 10.01it/s]
100%|██████████| 100/100 [00:08<00:00, 11.44it/s]


TP-2019_7333-04-1 (272.1535 sec) Loss: 0.36


100%|██████████| 1935/1935 [02:38<00:00, 12.21it/s]
100%|██████████| 100/100 [00:09<00:00, 10.71it/s]
  0%|          | 0/1362 [00:00<?, ?it/s]

TP-2019_7333-07-0 (168.0300 sec) Loss: 0.78


100%|██████████| 1362/1362 [02:06<00:00, 10.77it/s]
100%|██████████| 100/100 [00:09<00:00, 10.20it/s]


TP-2019_6887-02-0 (136.4520 sec) Loss: 1.06


100%|██████████| 3419/3419 [05:49<00:00,  9.79it/s]
100%|██████████| 100/100 [00:10<00:00,  9.14it/s]


TP-2019_2623-09-0 (360.6613 sec) Loss: 1.73


100%|██████████| 1959/1959 [03:03<00:00, 10.66it/s]
100%|██████████| 100/100 [00:10<00:00,  9.65it/s]


TP-2019_7362-06-1 (194.4647 sec) Loss: 0.71


100%|██████████| 3178/3178 [04:56<00:00, 10.73it/s]
100%|██████████| 100/100 [00:09<00:00, 10.36it/s]


TP-2019_2623-12-1 (306.0107 sec) Loss: 0.68


100%|██████████| 1505/1505 [02:00<00:00, 12.47it/s]
100%|██████████| 100/100 [00:08<00:00, 11.68it/s]


TP-2019_6887-06-0 (129.4614 sec) Loss: 0.90


100%|██████████| 1838/1838 [02:34<00:00, 11.93it/s]
100%|██████████| 100/100 [00:08<00:00, 12.32it/s]


TP-2019_6887-08-0 (162.4408 sec) Loss: 0.56


100%|██████████| 3181/3181 [04:51<00:00, 10.93it/s]
100%|██████████| 100/100 [00:10<00:00,  9.25it/s]
  0%|          | 0/1567 [00:00<?, ?it/s]

TP-2019_7207-07-1 (302.1279 sec) Loss: 1.03


100%|██████████| 1567/1567 [02:31<00:00, 10.36it/s]
100%|██████████| 100/100 [00:10<00:00,  9.47it/s]


TP-2019_7333-08-0 (161.9572 sec) Loss: 0.84


100%|██████████| 2101/2101 [03:07<00:00, 11.22it/s]
100%|██████████| 100/100 [00:08<00:00, 11.55it/s]


TP-2019_6887-03-0 (196.1786 sec) Loss: 0.65


100%|██████████| 2800/2800 [04:14<00:00, 11.02it/s]
100%|██████████| 100/100 [00:09<00:00, 10.23it/s]


TP-2019_7207-12-1 (264.2006 sec) Loss: 0.22


100%|██████████| 3763/3763 [06:03<00:00, 10.36it/s]
100%|██████████| 100/100 [00:10<00:00,  9.75it/s]


TP-2019_6786-10-0 (373.6433 sec) Loss: 0.53


100%|██████████| 1873/1873 [02:42<00:00, 11.54it/s]
100%|██████████| 100/100 [00:09<00:00, 10.75it/s]


TP-2019_7333-10-0 (171.8886 sec) Loss: 0.89


100%|██████████| 3128/3128 [05:57<00:00,  8.74it/s]
100%|██████████| 100/100 [00:11<00:00,  8.77it/s]
  0%|          | 0/2307 [00:00<?, ?it/s]

TP-2019_2824-07-0 (369.5599 sec) Loss: 0.67


100%|██████████| 2307/2307 [03:29<00:00, 10.99it/s]
100%|██████████| 100/100 [00:07<00:00, 12.61it/s]


TP-2019_6786-13-0 (218.0147 sec) Loss: 0.50


100%|██████████| 3302/3302 [05:18<00:00, 10.38it/s]
100%|██████████| 100/100 [00:08<00:00, 11.47it/s]
  0%|          | 0/2265 [00:00<?, ?it/s]

TP-2019_7488-02-1 (327.1475 sec) Loss: 0.32


100%|██████████| 2265/2265 [03:15<00:00, 11.58it/s]
100%|██████████| 100/100 [00:08<00:00, 11.72it/s]
  0%|          | 0/1434 [00:00<?, ?it/s]

TP-2019_6887-07-1 (204.2756 sec) Loss: 1.83


100%|██████████| 1434/1434 [02:14<00:00, 10.70it/s]
100%|██████████| 100/100 [00:09<00:00, 11.07it/s]
  0%|          | 0/984 [00:00<?, ?it/s]

TP-2019_2941-01-1 (143.2508 sec) Loss: 0.38


100%|██████████| 984/984 [01:33<00:00, 10.56it/s]
100%|██████████| 100/100 [00:08<00:00, 11.66it/s]


TP-2019_6887-09-0 (101.8862 sec) Loss: 1.27


100%|██████████| 2198/2198 [03:12<00:00, 11.40it/s]
100%|██████████| 100/100 [00:10<00:00,  9.97it/s]
  0%|          | 0/1266 [00:00<?, ?it/s]

TP-2019_6887-10-0 (203.0226 sec) Loss: 0.62


100%|██████████| 1266/1266 [01:45<00:00, 12.00it/s]
100%|██████████| 100/100 [00:08<00:00, 12.14it/s]


TP-2019_6887-01-0 (113.8982 sec) Loss: 0.94


100%|██████████| 2963/2963 [04:57<00:00,  9.97it/s]
100%|██████████| 100/100 [00:09<00:00, 10.00it/s]


TP-2019_7207-14-1 (307.5084 sec) Loss: 1.38


100%|██████████| 1702/1702 [02:19<00:00, 12.17it/s]
100%|██████████| 100/100 [00:09<00:00, 11.02it/s]
  0%|          | 0/1482 [00:00<?, ?it/s]

TP-2019_7333-02-1 (149.2048 sec) Loss: 0.44


100%|██████████| 1482/1482 [02:00<00:00, 12.30it/s]
100%|██████████| 100/100 [00:07<00:00, 12.53it/s]
  0%|          | 0/2212 [00:00<?, ?it/s]

TP-2019_2824-01-1 (128.5756 sec) Loss: 1.18


100%|██████████| 2212/2212 [03:13<00:00, 11.42it/s]
100%|██████████| 100/100 [00:10<00:00,  9.56it/s]


TP-2019_6785-03-0 (204.4465 sec) Loss: 0.62


100%|██████████| 2548/2548 [04:07<00:00, 10.30it/s]
100%|██████████| 100/100 [00:09<00:00, 10.35it/s]
  0%|          | 0/878 [00:00<?, ?it/s]

TP-2019_2824-13-0 (257.3498 sec) Loss: 0.79


100%|██████████| 878/878 [01:17<00:00, 11.35it/s]
100%|██████████| 100/100 [00:08<00:00, 11.78it/s]
  0%|          | 0/2014 [00:00<?, ?it/s]

TP-2019_6785-11-0 (85.9659 sec) Loss: 1.07


100%|██████████| 2014/2014 [03:08<00:00, 10.68it/s]
100%|██████████| 100/100 [00:11<00:00,  8.72it/s]


TP-2019_7362-03-0 (200.2371 sec) Loss: 1.37


100%|██████████| 1493/1493 [02:23<00:00, 10.43it/s]
100%|██████████| 100/100 [00:08<00:00, 11.86it/s]
  0%|          | 0/2363 [00:00<?, ?it/s]

TP-2019_7333-09-0 (151.8564 sec) Loss: 0.69


  4%|▍         | 95/2363 [00:09<03:09, 11.98it/s]