In [1]:
import re
import json
import torch
import torch.nn as nn
from network.unet import Unet
from network.feature_extractor import ResNet, ResnetOriginal
from network.full_model import EndNetwork, FullModel

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

from datasets import LoadDataset, CustomOutput, Knit
from datasets.custom_output import image_tensor, study_label_5

from train.full import FullTrainingCLI

from network.helper_functions import get_confusion_matrix

from utils.device import device

In [2]:
torch.cuda.is_available()

True

In [3]:
labelsn = 0
args_list = f'--do-batch-norm --epochs=200 --batch-size=100 --unet-weights=_trainings/22-09_17-30_aDAOG_cB3D_b24_e200_BN_lr0.01_lrsp10/unet_e200.ckpt --resnet-weights=_weights/balanced_BCEwLLoss_labels{labelsn}_resnet18_e12.ckpt --resnet-fc-cutoff=9 --learning-rate=0.0001 --use-lr-scheduler --lr-sch-patience=5 --resnet-no-sigmoid-activation --criterion=CEBAL --resnet-trainable --cuda-device=1'.split(' ')
args_list

['--do-batch-norm',
 '--epochs=200',
 '--batch-size=100',
 '--unet-weights=_trainings/22-09_17-30_aDAOG_cB3D_b24_e200_BN_lr0.01_lrsp10/unet_e200.ckpt',
 '--resnet-weights=_weights/balanced_BCEwLLoss_labels0_resnet18_e12.ckpt',
 '--resnet-fc-cutoff=9',
 '--learning-rate=0.0001',
 '--use-lr-scheduler',
 '--lr-sch-patience=5',
 '--resnet-no-sigmoid-activation',
 '--criterion=CEBAL',
 '--resnet-trainable',
 '--cuda-device=1']

In [4]:
cli = FullTrainingCLI()
args = cli.get_args(*args_list)
args

Namespace(absolute_training_size=5200, adam_regul_factor=0, augmentation=<Augmentation.NA: <trafo.compose.Compose object at 0x7f1f2ddc4650>>, batch_size=100, criterion=<Criterion.CEBAL: CrossEntropyLoss()>, cuda_device=1, do_batch_norm=True, epochs=200, feature_shape=512, get_abbrev_only=False, get_cuda_device_count_only=False, get_path_only=False, learning_rate=0.0001, lr_sch_patience=5, no_drop_last=False, no_dropout=False, path=None, path_prefix='_full_training', resnet_fc_cutoff=9, resnet_no_sigmoid_activation=True, resnet_out_shape=None, resnet_trainable=True, resnet_weights='_weights/balanced_BCEwLLoss_labels0_resnet18_e12.ckpt', unet_trainable=False, unet_weights='_trainings/22-09_17-30_aDAOG_cB3D_b24_e200_BN_lr0.01_lrsp10/unet_e200.ckpt', use_lr_scheduler=True)

In [5]:
# Get Networks
unet = Unet(batch_norm=args.do_batch_norm)
unet.load_state_dict(torch.load(args.unet_weights,
                                map_location=device))

# Get ResNet
pattern = re.compile(r"(_e(\d+)\.ckpt)$")
load_epoch = int(pattern.search(args.resnet_weights).group(2))
print("Attempting to load resnet from epoch", load_epoch)
resnet_config_file = pattern.sub("_net_config.json",
                                 args.resnet_weights)

with open(resnet_config_file) as f:
    resnet_config = json.load(f)

sigmoid_activation = args.resnet_no_sigmoid_activation is not True
if resnet_config["network"] == "ResNet":
    dims = [int(n) for n
            in re.search(r"resnet(\d+)", resnet_config_file).group(1)]
    print("trying to load", f"dims={dims}",
          f"out_shape={args.resnet_out_shape}")
    resnet = ResNet(dims, out_shape=args.resnet_out_shape,
                    sigmoid_activation=sigmoid_activation)
    resnet.load_state_dict(torch.load(args.resnet_weights,
                                      map_location=device))
    children = list(resnet.end.children())
    # cut off everything from last Linear layer:
    for i, child in enumerate(reversed(children)):
        if isinstance(child, nn.Linear):
            resnet.end = nn.Sequential(*children[:-1-i])
            break
    else:
        raise RuntimeError("No nn.Linear found in ResNet")
elif resnet_config["network"] == "ResnetOriginal":
    if args.resnet_fc_cutoff is None:
        raise RuntimeError("Please specify --resnet-fc-cutoff=n")
    resnet_type = resnet_config["type"]
    shapes = list(resnet_config["shapes"])
    trainable_level = int(resnet_config["trainable_level"])
    trainable_resnet = bool(resnet_config["trainable_resnet"])
    resnet = ResnetOriginal(type=resnet_type, shapes=shapes,
                            trainable_resnet=trainable_resnet,
                            trainable_level=trainable_level,
                            sigmoid_activation=sigmoid_activation)
    resnet.load_state_dict(torch.load(args.resnet_weights,
                                      map_location=device))
    resnet.fc.block = nn.Sequential(*list(resnet.fc.block.children())
                                    [:-args.resnet_fc_cutoff])

# Full network
end_network = EndNetwork(features_shape=args.feature_shape,
                         use_dropout=args.no_dropout is not True)

model = FullModel(unet=unet, feature_extractor=resnet, end=end_network,
                  threshold=0.5,
                  unet_trainable=args.unet_trainable,
                  feature_extractor_trainable=args.resnet_trainable)


Attempting to load resnet from epoch 12


In [6]:
loaded_data = LoadDataset("_data/preprocessed256_new",
                                  image_dtype=float, label_dtype=float)
knit_data = Knit(loaded_data, study_csv="_data/train_study_level.csv",
                 image_csv="_data/train_image_level.csv")
dataset_plain = CustomOutput(knit_data, image_tensor, study_label_5)
dataset_aug = CustomOutput(knit_data, image_tensor, study_label_5,
                           trafo=args.augmentation.value)
dataset_aug.trafo.max_transformands = 1

l_data = len(dataset_aug)
indices = list(range(l_data))
train_size = args.absolute_training_size
val_size = l_data - train_size
print("Training: ", train_size, "Validation: ", val_size)

train_indices, val_indices = train_test_split(indices, random_state=4,
                                              train_size=train_size,
                                              test_size=val_size)

train_set = Subset(dataset_aug, train_indices)
val_set = Subset(dataset_plain, val_indices)

dataloader_train = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True, num_workers=0)
dataloader_val = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=True, num_workers=0,
                            pin_memory=True)

Training:  5200 Validation:  1134


  warn(f"Guessed {name} from {dataset.__class__.__name__}")
  warn(f"{e.__class__.__name__}: {e}")
  warn(f"Guessed {name} from {dataset.__class__.__name__}")


In [11]:
get_confusion_matrix(model.cuda(), dataloader_val) # nach 80

array([[199,  16,  79,   8,   2],
       [ 32, 348, 110,  12,  11],
       [ 44,  67,  61,  10,   4],
       [  7,  26,  31,   5,   3],
       [  4,  28,   8,   3,  16]])

In [24]:
model.load_state_dict(torch.load("_full_training/25-09_15-42_aNA_cCEBAL_b100_e200_BN_lr0.0001_lrsp5_nosig_fcc9/_e100.ckpt"))

<All keys matched successfully>

In [13]:
get_confusion_matrix(model.cuda(), dataloader_val) # nach 10

array([[197,  16,  79,  10,   2],
       [ 32, 341, 117,  12,  11],
       [ 42,  65,  65,  10,   4],
       [  7,  23,  32,   8,   2],
       [  4,  26,  10,   3,  16]])

In [14]:
get_confusion_matrix(model.cuda(), dataloader_train) # nach 10

array([[1415,    5,   12,    0,    0],
       [  17, 1994,  305,    1,   24],
       [   0,   14,  841,    0,    8],
       [   0,    0,    0,  319,    0],
       [   0,    0,    0,    1,  244]])

In [16]:
model.load_state_dict(torch.load("_full_training/25-09_10-29_aNA_cCE_b100_e200_BN_lr1e-05_lrsp10_nosig_fcc9/_e100.ckpt"))

<All keys matched successfully>

In [30]:
get_confusion_matrix(model.cuda(), dataloader_val) # nach 10

array([[196,  16,  81,   9,   2],
       [ 31, 348, 110,  13,  11],
       [ 43,  67,  62,  10,   4],
       [  7,  24,  32,   6,   3],
       [  4,  28,   8,   3,  16]])

In [31]:
get_confusion_matrix(model.cuda(), dataloader_train) # nach 10

array([[1421,    4,    7,    0,    0],
       [   4, 2123,  191,    1,   22],
       [   0,    2,  853,    0,    8],
       [   0,    0,    0,  318,    1],
       [   0,    0,    0,    0,  245]])

In [29]:
for p in model.end.parameters():
    print(torch.min(p))

tensor(-0.3334, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.2611, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0882, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0785, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0896, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0833, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0668, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0553, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0706, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0579, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0360, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0186, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0537, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0444, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0456, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0403, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0342, device='cuda:0', grad_fn=<MinBackward1>)
tensor(-0.0244, device='cuda:0'