In [None]:
import numpy as np
print(np.__version__)
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
import os
import random

# Set Seeds for Reproducibility

In [None]:
SEED = 42
# Seed every RNG this notebook relies on: PyTorch's global generator (used e.g. by
# random_split and DataLoader shuffling below), Python's `random`, and NumPy's
# legacy global RNG.
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(SEED)

# Get Data

In [None]:
from datasets import CustomOutput, LoadDataset
from datasets.custom_output import float_mask, image_tensor

# Preprocessed images and their segmentation labels, both loaded with dtype=float.
loaded_data = LoadDataset(
    "_data/preprocessed256_new",
    image_dtype=float,
    label_dtype=float,
)

# Wrap the raw dataset so that each sample comes out as (image, mask); presumably
# image_tensor / float_mask convert to a tensor and a float mask — TODO confirm.
dataset = CustomOutput(loaded_data, image_tensor, float_mask)

In [None]:
# Total number of samples; the train/validation split below is derived from it.
len(dataset)

In [None]:
# Index of the sample to inspect (reused below when checking the mask shape).
i_img = 42

# Fetch the sample once: indexing the dataset may re-run loading/transforms.
sample = dataset[i_img]
image, mask = sample[0], sample[1]

fig, ax = plt.subplots(1, 2)
# image[0] is the first entry along the leading axis — presumably channel 0 of a
# channel-first image tensor, TODO confirm. cmap is passed explicitly instead of
# plt.gray(), which would change the default colormap for every later figure.
ax[0].imshow(image[0], cmap="gray")
ax[0].set_title(f"image #{i_img} (channel 0)")
ax[1].imshow(mask, cmap="gray")
ax[1].set_title(f"mask #{i_img}")
plt.show()

In [None]:
# Shape of the mask of the sample displayed above (same index as the plot).
dataset[i_img][1].shape

In [None]:
# Split roughly 5/6 train, 1/6 validation. The training-set size is rounded down
# to a multiple of batch_size, so every training batch is full. The validation
# set receives everything else, so its size is generally NOT a multiple of
# batch_size.
batch_size = 16
n_samples = len(dataset)
n_chunks = n_samples // (batch_size * 6)
if n_chunks == 0:
    # Without this guard the training set would silently be empty.
    raise ValueError(
        f"Dataset has {n_samples} samples; need at least {batch_size * 6} "
        f"for a 5/6 train split with batch_size={batch_size}."
    )
n_train = n_chunks * batch_size * 5
split = [n_train, n_samples - n_train]
print(split)

# Dedicated seeded generator: re-running this cell reproduces the same split
# instead of drawing a fresh one from the (already advanced) global torch RNG,
# which would mix previous validation samples into training.
train_set, val_set = torch.utils.data.random_split(
    dataset, split, generator=torch.Generator().manual_seed(42)
)

In [None]:
# Training batches are reshuffled every epoch; pin_memory stays at its default here.
dataloader_train = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, num_workers=0)

# Validation is NOT shuffled. Its last batch is usually partial (see the split
# above), so shuffling would change which samples land in it each epoch and
# make per-batch-averaged validation metrics fluctuate (if OurModel averages
# per batch — not visible here). Pinned memory only speeds up host-to-GPU
# copies, so enable it only when CUDA is available.
dataloader_val = DataLoader(val_set, batch_size=batch_size,
                            shuffle=False, num_workers=0,
                            pin_memory=torch.cuda.is_available())

# Get Network

In [None]:
from network.unet import Unet

# Set Up Training

In [None]:
from network.Model import OurModel

In [None]:
import torch.nn as nn

# Binary segmentation loss. nn.BCELoss expects probabilities in [0, 1], so this
# assumes Unet's forward pass ends in a sigmoid — TODO confirm (otherwise use
# nn.BCEWithLogitsLoss). Earlier alternative: nn.CrossEntropyLoss().cuda().
criterion = nn.BCELoss().cuda()
network = Unet()

# Each run gets its own output directory named after its start time (minute
# resolution). If that directory already exists (e.g. the cell is re-run within
# the same minute), append a numeric suffix instead of reusing it, so the
# previous run's files in it are presumably not overwritten by OurModel.
base_path = f"./_trainings/{datetime.now().strftime('%d-%m_%H-%M')}"
path = base_path
n_retry = 1
while os.path.exists(path):
    print(f"PATH {path} already exists")
    path = f"{base_path}_{n_retry}"
    n_retry += 1
print(f"Make {path} directory")
os.makedirs(path)

In [None]:
# Bundle network, loss and hyper-parameters into the project's training wrapper;
# its outputs go to the run directory `path` created above.
Model = OurModel(
    name="unet",
    network=network,
    criterion=criterion,
    path_dir=path,
    lr=0.001,
    batch_size=batch_size,
    verbose=True,
    segmentation=True,
)
Model.save_configuration()

In [None]:
# First positional argument of OurModel.train — presumably the number of epochs, TODO confirm.
n_epochs = 100
Model.train(
    n_epochs,
    dataloader_train,
    validate=True,
    dataloader_val=dataloader_val,
    save_observables=True,
)

In [None]:
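# Optional: load previously saved weights instead of training from scratch.
# NOTE(review): the leading "/" makes this an absolute filesystem path; it was
# probably meant to point into a run directory under ./_trainings — confirm.
#Model.load_weights(f"/Unet_first_try_e50.ckpt")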

In [None]:
# Visual sanity check: predictions vs. ground truth on one validation batch.
# Disabled by default; set the flag to True to run it. It is kept as real code
# behind a flag, not inside a string literal, so it is syntax-checked and does
# not get echoed back as the cell's output.
SHOW_VAL_PREDICTIONS = False

if SHOW_VAL_PREDICTIONS:
    net = Model.network
    was_training = net.training
    # Use whatever device the network's parameters live on instead of assuming CUDA.
    device = next(net.parameters()).device
    x, y = next(iter(dataloader_val))
    net.eval()
    try:
        with torch.no_grad():  # inference only: no autograd graph, so no .detach() needed
            y_hat = net(x.float().to(device))
    finally:
        net.train(was_training)  # don't leave the network stuck in eval mode

    y_hat = y_hat.cpu().numpy()
    y = y.numpy()
    n_rows = len(y_hat)
    # squeeze=False keeps `ax` 2-D even when the batch holds a single sample.
    fig, ax = plt.subplots(n_rows, 3, figsize=(10, 3 * n_rows), squeeze=False)
    for col, title in enumerate(("prediction", "rounded prediction", "ground truth")):
        ax[0, col].set_title(title)
    for i in range(n_rows):
        # np.squeeze drops a singleton channel axis if the network/mask has one —
        # output and mask shapes are not visible here, TODO confirm.
        pred = np.squeeze(y_hat[i])
        ax[i, 0].imshow(pred, cmap="gray")
        ax[i, 1].imshow(np.round(pred), cmap="gray")
        ax[i, 2].imshow(np.squeeze(y[i]), cmap="gray")
    plt.show()