In [15]:
import torch
import torch.nn as nn
from torch.nn import functional as f
from pathlib import Path
import numpy as np
import torchvision
from torch.utils.data import DataLoader, Dataset
import os
from torchsummary import torchsummary
from PIL import Image
import cv2
import collections

In [5]:
# Local

ORIGINAL_IMAGE = 'E:/dataSets/NDI_images/20220725/20220725/Observed_Crop_200x200pix'
TARGET_IMAGE = 'E:/dataSets/NDI_images/20220725/20220725/Calculated_200x200/grayscale'

In [3]:
binarized_path = Path.joinpath(Path(TARGET_IMAGE), 'binarized')
file_names = list(Path(TARGET_IMAGE).glob('*.jpg'))
for file_name in file_names:
    file_stem = file_name.name
    img = cv2.imread(str(file_name), cv2.IMREAD_UNCHANGED)
    threshold, result = cv2.threshold(img, 90, 255, cv2.THRESH_BINARY)
    cv2.imwrite(str(Path.joinpath(binarized_path, file_stem)), result)

In [35]:
class NDIDatasetReconstruction(Dataset):
    def __init__(self):
        super(NDIDatasetReconstruction, self).__init__()
        original_images = list(Path(ORIGINAL_IMAGE).glob('*.jpg'))
        origins, targets = [], []
        to_tensor_func = torchvision.transforms.ToTensor()
        for original_image in original_images:
            origins.append(to_tensor_func(Image.open(str(original_image))))
            targets.append(to_tensor_func(Image.open(str(Path.joinpath(binarized_path, original_image.name.split('_')[0] + '.jpg')))).type(torch.long).squeeze(0))
        random_index = np.random.permutation(len(origins))
        self.origins, self.targets = [], []
        for index in random_index:
            self.origins.append(origins[index])
            self.targets.append(targets[index])

    def __getitem__(self, idx):
        return self.origins[idx], self.targets[idx]

    def __len__(self):
        return len(self.origins)

In [61]:
from Unet_model import UNet, dice_loss


model = UNet(1, 2, use_BN=True)
device = torch.device('cuda:0')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-5, weight_decay=1e-8, momentum=0.9)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)

In [62]:
def train(net, criterion, optimizer, epochs, device):
    net.cuda(device)
    train_iter = DataLoader(NDIDatasetReconstruction(), batch_size=8, shuffle=True, drop_last=True)
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        for origin, target in train_iter:
            origin, target = origin.cuda(device), target.cuda(device)
            assert origin.shape[1] == net.n_channels
            masks_pred = net(origin)
            loss = criterion(masks_pred, target) + dice_loss(f.softmax(masks_pred, dim=1).float(), f.one_hot(target, net.n_classes).permute(0, 3, 1, 2).float(), multiclass=True)
            epoch_loss += loss.item()
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch + 1}, Loss {epoch_loss}')

In [63]:
train(model, criterion, optimizer, 50, device)

Epoch 1, Loss 9.419089198112488
Epoch 2, Loss 7.825797915458679
Epoch 3, Loss 6.613701105117798
Epoch 4, Loss 5.989596426486969
Epoch 5, Loss 5.642834305763245
Epoch 6, Loss 5.453350901603699
Epoch 7, Loss 5.297598838806152
Epoch 8, Loss 5.179667294025421
Epoch 9, Loss 5.084425330162048
Epoch 10, Loss 5.019053936004639
Epoch 11, Loss 4.949221611022949
Epoch 12, Loss 4.887756288051605
Epoch 13, Loss 4.834590494632721
Epoch 14, Loss 4.7789976596832275
Epoch 15, Loss 4.751592755317688
Epoch 16, Loss 4.743165194988251
Epoch 17, Loss 4.6976786851882935
Epoch 18, Loss 4.639142990112305
Epoch 19, Loss 4.580848932266235
Epoch 20, Loss 4.524388492107391
Epoch 21, Loss 4.473062694072723
Epoch 22, Loss 4.43853622674942
Epoch 23, Loss 4.395389378070831
Epoch 24, Loss 4.3832796812057495
Epoch 25, Loss 4.356720566749573
Epoch 26, Loss 4.322990596294403
Epoch 27, Loss 4.287537753582001
Epoch 28, Loss 4.251213788986206
Epoch 29, Loss 4.215823292732239
Epoch 30, Loss 4.184724807739258
Epoch 31, Loss 4.

KeyboardInterrupt: 

In [44]:
model.cpu()
to_tensor_func = torchvision.transforms.ToTensor()
to_image_func = torchvision.transforms.ToPILImage()
original_images = list(Path(ORIGINAL_IMAGE).glob('*.jpg'))
origin = to_tensor_func(Image.open(str(np.random.choice(original_images)))).unsqueeze(0)
result = model(origin)
torch.argmax(result, dim=1).shape

torch.Size([1, 200, 200])

In [56]:
def reconstruction_test(net, origin, target):
    to_image_func = torchvision.transforms.ToPILImage()
    preds = net(origin)
    preds = torch.argmax(preds, dim=1)
    preds = to_image_func(preds.type(torch.float))
    target = to_image_func(target)
    preds = cv2.cvtColor(np.asarray(preds), cv2.COLOR_RGB2BGR)
    target = cv2.cvtColor(np.asarray(target), cv2.COLOR_RGB2BGR)
    stacked_image = np.hstack([target, preds])
    cv2.imshow('stacked', stacked_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

In [57]:
original_images = list(Path(ORIGINAL_IMAGE).glob('*.jpg'))
original_image = np.random.choice(original_images)
print(original_image.name.split('_')[0])
origin_data = to_tensor_func(Image.open(str(original_image))).unsqueeze(0)
target_data = to_tensor_func(Image.open(str(Path.joinpath(Path(binarized_path), original_image.name.split('_')[0] + '.jpg'))))
reconstruction_test(model, origin, target_data)

49
