In [8]:
import random

import torch
import torch.nn as nn
import torch.optim as optim

# ↓ required for docker to avoid permission errors with .cache dir
torch.hub.set_dir("cache")  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
from torchvision.transforms import GaussianBlur, ColorJitter, Compose

from cp_toolbox.deep_learning.torch.generators import SegmentationGenerator
from cp_toolbox.utils import utils

from cst.torch.cst_model import train_cst


cuda


In [9]:
# dtaset parameters
input_path = "/home/schroederubuntu/projects/cp_toolbox_data/epithelium_segmentation/v1/X"
target_path = "/home/schroederubuntu/projects/cp_toolbox_data/epithelium_segmentation/v1/y"

# training parameters
epochs = 10
batch_size = 2
img_size = (512, 512)

# distortion parameters
brightness = (0.65, 1.35)
contrast = (0.8, 1.2)
saturation = (0.8, 1.2)
hue = (-0.1, 0.1)  #  -0.5 <= min <= max <= 0.5
blur_kernel = 5.
blur_sigma = 2.


tile_paths = utils.list_file_paths(input_path + "/2", [".png"])
tile_paths2 = random.sample(
    utils.list_file_paths(input_path + "/1", [".png"]), 0)
tile_paths = tile_paths + tile_paths2

### filtering for p16, cd3 or cd8 or whatever
stain = "p16"
# tile_paths = [i for i in tile_paths if stain in i]  # to use all, just comment here

train_tile_paths, \
val_tile_paths, _ = utils.train_val_test_split(
    tile_paths, proportion_train=0.8, proportion_val=0.2)


print("Number of tiles: " + str(len(tile_paths)))
print("Train: {} - Validation: {}".format(
    len(train_tile_paths), len(val_tile_paths)))

Number of tiles: 116
Train: 92 - Validation: 23


In [10]:
""" image data generators / loaders """
train_dataset = SegmentationGenerator(
    batch_size=batch_size,
    img_size=img_size,
    image_paths=train_tile_paths,
    input_path=input_path,
    target_path=target_path,
    shuffle=True,
    rotate=True
    
)
 # drop_last avoids risk of last batch being n=1, which makes training loop fail
train_loader = train_dataset.data_loader(drop_last=True) 

val_dataset = SegmentationGenerator(
#     batch_size=batch_size,
    batch_size=1,
    img_size=img_size,
    image_paths=val_tile_paths,
    input_path=input_path,
    target_path=target_path,
    shuffle=False,
    rotate=True
    
)
 # drop_last avoids risk of last batch being n=1, which makes training loop fail
val_loader = val_dataset.data_loader(drop_last=True) 

In [11]:
""" distortion layer """
dist_layer = Compose([
    ColorJitter(brightness=brightness, contrast=contrast, 
                saturation=saturation, hue=hue), 
    GaussianBlur(blur_kernel, blur_sigma)
])

In [12]:
""" load model """
model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', 
                       weights=None, num_classes=1)
# ↓ add sigmoid layer in classifier block
model.classifier.add_module("sigmoid", nn.Sigmoid())  

""" optimizer """
opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Using cache found in cache/pytorch_vision_v0.10.0


In [13]:
len(train_loader)

46

In [14]:
train_cst(
    model=model,
    train_loader=train_loader,
    device=device,
    optimizer=opt,
    dist_layer=dist_layer,
    val_loader=val_loader,
    alpha=2,
    epochs=epochs,
    model_save_path = "testing",
    model_base_name= "new_model_test"
)

Epoch 0 train - batch 15/46 - cst_loss: 0.716 - l0: 0.693 - lstab: 0.006:  33%|████████▊                  | 15/46 [00:06<00:13,  2.36it/s]


KeyboardInterrupt: 