In [None]:
import random

import numpy as np
import PIL.Image as Image
import torch
import torch.nn as nn
from matplotlib.pyplot import imshow
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

from datasets import LoadDataset, CustomOutput, Knit
from datasets.custom_output import image_tensor, study_label_5
from trafo.randomize.default_augmentation import default_augmentation_only_geometric

from network.feature_extractor import ResNet, ResnetOriginal
from network.full_model import EndNetwork, FullModel
from network.training import FullTraining, get_balanced_crossentropy_loss
from network.unet import Unet

In [None]:
# Seed every RNG the pipeline may draw from so that Restart & Run All reproduces results.
# torch.manual_seed seeds the CPU generator and all CUDA devices, so a separate
# torch.cuda.manual_seed call is unnecessary. Python's `random` and NumPy are seeded as
# well because the augmentation code may draw from them.
# NOTE(review): TODO confirm which RNG trafo.randomize actually uses.
SEED = 4
torch.manual_seed(SEED)  # returns a Generator; not the last line, so nothing is displayed
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# Training configuration.
save_weights_at = "./_weights"  # output directory, passed to FullTraining as path_dir
BATCH_SIZE = 2                  # samples per batch for both dataloaders and FullTraining

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Checkpoint files restored into the networks further down.
weights_unet = "./_trainings/19-09_19-26_aDAOG_cB3D_b48_e150_BN/unet_e150.ckpt"  # Unet
weights_resnet = "./_weights/resnet1111_e18.ckpt"                                 # ResNet([1,1,1,1])
weights_oresnet = "./_weights/resnet18_orig_tr=6_09_21_e18.ckpt"                  # ResnetOriginal("resnet18")

# Load data

In [None]:
# Load the preprocessed images/labels and combine them with the study-level and
# image-level CSV annotations via Knit.
loaded_data = LoadDataset(
    "_data/preprocessed256_new",
    image_dtype=float,
    label_dtype=float,
)
knit_data = Knit(
    loaded_data,
    study_csv="_data/train_study_level.csv",
    image_csv="_data/train_image_level.csv",
)

# Un-augmented view of the samples (used for the validation subset).
dataset_plain = CustomOutput(knit_data, image_tensor, study_label_5)

# Augmented view of the same samples (used for the training subset).
# NOTE(review): this assignment mutates the imported module-level transform object,
# so it also affects any other user of default_augmentation_only_geometric in this
# kernel. max_transformands presumably caps the number of transforms applied per
# sample — TODO confirm.
default_augmentation_only_geometric.max_transformands = 1
dataset_aug = CustomOutput(
    knit_data,
    image_tensor,
    study_label_5,
    trafo=default_augmentation_only_geometric,
)

In [None]:
# Split sample indices into a fixed-size training set and a validation set holding
# the remainder. random_state=4 keeps the split identical across runs.
l_data = len(dataset_aug)
indices = list(range(l_data))
train_size = 5200
val_size = l_data - train_size
# Fail early with a readable message, instead of an opaque train_test_split error,
# when the dataset is too small for the hard-coded training size.
if not 0 < train_size < l_data:
    raise ValueError(
        f"train_size={train_size} must be between 1 and len(dataset) - 1 = {l_data - 1}"
    )

print("Training: ", train_size, "Validation: ", val_size)
train_indices, val_indices = train_test_split(
    indices, random_state=4, train_size=train_size, test_size=val_size
)

# Training indices select from the augmented view and validation indices from the
# plain view. NOTE(review): this assumes both CustomOutput wrappers over knit_data
# have the same length and ordering — confirm.
train_set = Subset(dataset_aug, train_indices)
val_set = Subset(dataset_plain, val_indices)

# Pinned host memory only helps when batches are copied to a GPU; on a CPU-only
# machine PyTorch warns about it, so enable it for both loaders only when on CUDA.
use_pinned_memory = device.type == "cuda"

dataloader_train = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=0, pin_memory=use_pinned_memory)
# Validation order does not need to be random. shuffle=False also keeps the
# validation loader from drawing from the global torch RNG each time it is iterated.
dataloader_val = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=0, pin_memory=use_pinned_memory)

# Build the networks

## Load the pretrained U-Net

In [None]:
# U-Net (with batch norm) restored from its checkpoint. The assembled FullModel
# below is built with unet_trainable=False.
unet = Unet(batch_norm=True)
unet.load_state_dict(
    torch.load(weights_unet, map_location=device)
)

## Load the pretrained ResNet feature extractors

In [None]:
# Custom ResNet with layer configuration [1, 1, 1, 1] and a 14-way head,
# restored from its checkpoint.
# NOTE(review): `resnet` is not passed to the FullModel assembled below (it receives
# `oresnet`), and feature_shape is reassigned in the next cell. This cell looks like
# a leftover alternative backbone — consider removing it.
resnet = ResNet([1, 1, 1, 1], out_shape=14)
resnet.load_state_dict(torch.load(weights_resnet, map_location=device))
# Drop the last two modules of the head. This presumably leaves 512-wide features
# (matching feature_shape) — TODO confirm against ResNet.end.
resnet.end = nn.Sequential(*tuple(resnet.end.children())[:-2])
feature_shape = 512

In [None]:
# ResNet-18 based extractor (ResnetOriginal) with fc head shapes [512, 124, 32, 14],
# restored from its checkpoint; this is the feature extractor used by FullModel.
oresnet = ResnetOriginal(
    type="resnet18",
    shapes=[512, 124, 32, 14],
    trainable_resnet=True,
    trainable_level=6,
)
oresnet.load_state_dict(torch.load(weights_oresnet, map_location=device))
# Truncate the fc head by dropping its last 9 modules. This presumably leaves the
# 512-wide output that EndNetwork consumes — TODO confirm against ResnetOriginal.fc.
oresnet.fc.block = nn.Sequential(*tuple(oresnet.fc.block.children())[:-9])
feature_shape = 512

## Create the end network

In [None]:
# Final network stage; its input width is the feature extractor's output width.
end_network = EndNetwork(features_shape=feature_shape)

## Assemble the full model

In [None]:
# Combine the pretrained U-Net, the truncated ResNet-18 feature extractor and the
# new end network. The *_trainable=False flags are meant to freeze the two
# pretrained parts.
# NOTE(review): threshold=0.5 presumably binarises the U-Net output — TODO confirm.
# NOTE(review): the model is not moved to `device` in this notebook; presumably
# FullTraining handles that — confirm.
model = FullModel(
    unet=unet,
    feature_extractor=oresnet,
    end=end_network,
    threshold=0.5,
    unet_trainable=False,
    feature_extractor_trainable=False,
)

In [None]:
# Number of trainable parameters (tensors with requires_grad=True).
sum(tensor.numel() for tensor in filter(lambda t: t.requires_grad, model.parameters()))

In [None]:
# Total number of parameters, frozen and trainable alike.
sum(map(torch.numel, model.parameters()))

# Set up training

In [None]:
# Cross-entropy loss over the 5 study-level labels (see study_label_5), built from the
# training subset; presumably class-weighted by label frequency — TODO confirm.
loss = get_balanced_crossentropy_loss(train_set, verbose=True, shape=5)

In [None]:
# Training driver; checkpoints go under save_weights_at.
# data_trafo=None presumably because augmentation already happens inside dataset_aug
# — TODO confirm.
training = FullTraining(
    "full_training",
    model,
    loss,
    batch_size=BATCH_SIZE,
    verbose_level=2,
    path_dir=save_weights_at,
    data_trafo=None,
)

In [None]:
# The first argument is presumably the number of epochs — TODO confirm against FullTraining.train.
# NOTE(review): dataloader_val is built above but never passed to training here.
NUM_EPOCHS = 24
training.train(NUM_EPOCHS, dataloader=dataloader_train)