# Blood vessel detection using UNET

First we divide images into adequate folders, segmenting train and test images to 128x128 windows:

In [3]:
from skimage.util import view_as_windows
from random import sample
import cv2
import shutil 
import os
import glob

TRAIN_IMG_DIR = 'data/train_images'
VAL_IMG_DIR = 'data/val_images'
TEST_IMG_DIR = 'data/test_images'

TRAIN_MASK_DIR = 'data/train_masks'
VAL_MASK_DIR = 'data/val_masks'
TEST_MASK_DIR = 'data/test_masks'

IMAGE_SIZE =  64

def create_folders():
    if not os.path.exists(TRAIN_IMG_DIR):
        os.makedirs(TRAIN_IMG_DIR)

    if not os.path.exists(VAL_IMG_DIR):
        os.makedirs(VAL_IMG_DIR)

    if not os.path.exists(TEST_IMG_DIR):
        os.makedirs(TEST_IMG_DIR)

    if not os.path.exists(TRAIN_MASK_DIR):
        os.makedirs(TRAIN_MASK_DIR)

    if not os.path.exists(VAL_MASK_DIR):
        os.makedirs(VAL_MASK_DIR)

    if not os.path.exists(TEST_MASK_DIR):
        os.makedirs(TEST_MASK_DIR)
        
def create_dataset(num_samples):
    create_folders()
    
    img_filenames = glob.glob("data/images/*")
    img_filenames.sort()
    images = [cv2.imread(img) for img in img_filenames]

    mask_filenames = glob.glob("data/manual1/*")
    mask_filenames.sort()
    masks = [cv2.imread(img, 0) for img in mask_filenames]

    # copy test images to test folder
    for i, (img_path, mask_path) in enumerate(zip(img_filenames[:5], mask_filenames[:5])):
        shutil.copy(img_path, TEST_IMG_DIR + "/" + str(i) + ".png")
        shutil.copy(mask_path, TEST_MASK_DIR + "/" + str(i) + ".png")

    # create windows for train and val images
    all_windows = []
    all_mask_windows = []
    for img, mask in zip(images[5:], masks[5:]):
        windows = view_as_windows(img, (IMAGE_SIZE, IMAGE_SIZE, 3), step=IMAGE_SIZE)
        width, height, _, _, _, _ = windows.shape
        windows = windows.flatten().reshape(width*height, IMAGE_SIZE, IMAGE_SIZE, 3)
        all_windows.extend(windows)

        mask_windows = view_as_windows(mask, (IMAGE_SIZE, IMAGE_SIZE), step=IMAGE_SIZE)
        mask_windows = mask_windows.flatten().reshape(width*height, IMAGE_SIZE, IMAGE_SIZE)
        all_mask_windows.extend(mask_windows)
    
    # sample created data
    data = list(zip(all_windows, all_mask_windows))
    after_sampling = sample(data, num_samples)
    all_windows, all_mask_windows = zip(*after_sampling)

    for i, (window, mask_window) in enumerate(zip(all_windows, all_mask_windows)):
        if i <= len(all_windows) * 0.8: # add to train set
            cv2.imwrite(TRAIN_IMG_DIR + "/" + str(i) + ".png", window)
            cv2.imwrite(TRAIN_MASK_DIR + "/" + str(i) + ".tiff", mask_window)
        else: # add to val set
            cv2.imwrite(VAL_IMG_DIR + "/" + str(i) + ".png", window)
            cv2.imwrite(VAL_MASK_DIR + "/" + str(i) + ".tiff", mask_window)
            
create_dataset(500)

Next we define UNET architecture:

In [4]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

Then we create dataset

In [13]:
from torch.utils.data import Dataset

class BloodVesselDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = glob.glob(image_dir + "/*")
        self.images.sort()
        self.masks = glob.glob(mask_dir + "/*")
        self.masks.sort()
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        img_path = self.images[index]
        mask_path = self.masks[index]
        image = cv2.imread(img_path)
        mask = cv2.imread(mask_path, 0)
        mask = mask/255
        
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        
        return image, mask

In [14]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim

# HYPERPARAMETERS
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 100
NUM_WORKERS = 0
PIN_MEMORY = True
LOAD_MODEL = True


def train(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)
        
        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)
            
        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        # update tqdm loop
        loop.set_postfix(loss=loss.item())

        
train_transform = A.Compose(
[
    A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
#     A.Rotate(limit=35, p=1.0),
#     A.HorizontalFlip(p=0.5),
#     A.VerticalFlip(p=0.1),
    ToTensorV2()
])

val_transform = A.Compose(
[
    A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
    ToTensorV2()
])


In [15]:
import torchvision
from torch.utils.data import DataLoader

def save_checkpoint(state, filename):
    print("===>Saving checkpoint")
    torch.save(state, filename)
    
def load_checkpoint(checkpoint, model):
    print("===>Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    
def get_loaders(
        train_dir,
        train_mask_dir,
        val_dir,
        val_mask_dir,
        batch_size,
        train_transform,
        val_transform,
        num_workers=4,
        pin_memory=True):
    
    train_ds = BloodVesselDataset(
        image_dir=train_dir,
        mask_dir=train_mask_dir,
        transform=train_transform)
    
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True)
    
    val_ds = BloodVesselDataset(
        image_dir=val_dir,
        mask_dir=val_mask_dir,
        transform=val_transform)
    
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False)
    
    return train_loader, val_loader

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    model.train()

def save_predictions_as_imgs(
    loader, model, epoch, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(x.float(), f"{folder}img_{epoch}_{idx}.png")
        torchvision.utils.save_image(
            preds, f"{folder}pred_{epoch}_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1).float(), f"{folder}mask_{epoch}_{idx}.png")

    model.train()

In [16]:
if __name__ == '__main__':
    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform,
        NUM_WORKERS,
        PIN_MEMORY
    )

    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train(train_loader, model, optimizer, loss_fn, scaler)
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        save_checkpoint(checkpoint, filename=f"checkpoint_{epoch}")
        check_accuracy(val_loader, model, device=DEVICE)
        save_predictions_as_imgs(val_loader, model, epoch, device=DEVICE)

100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:23<00:00,  5.53s/it, loss=0.548]


===>Saving checkpoint
Got 375566/405504 with acc 92.62
Dice score: 0.0002562673074989772


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:18<00:00,  5.34s/it, loss=0.492]


===>Saving checkpoint
Got 374247/405504 with acc 92.29
Dice score: 0.07535925310903786


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:20<00:00,  5.40s/it, loss=0.455]


===>Saving checkpoint
Got 381410/405504 with acc 94.06
Dice score: 0.38200644253064536


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:22<00:00,  5.48s/it, loss=0.405]


===>Saving checkpoint
Got 383986/405504 with acc 94.69
Dice score: 0.4669656569191685


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:16<00:00,  5.26s/it, loss=0.578]


===>Saving checkpoint
Got 384237/405504 with acc 94.76
Dice score: 0.48057238979469086


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:06<00:00,  2.55s/it, loss=0.371]


===>Saving checkpoint
Got 386261/405504 with acc 95.25
Dice score: 0.4971950416528611


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:07<00:00,  2.59s/it, loss=0.443]


===>Saving checkpoint
Got 378879/405504 with acc 93.43
Dice score: 0.46933037283001233


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:06<00:00,  2.56s/it, loss=0.503]


===>Saving checkpoint
Got 373838/405504 with acc 92.19
Dice score: 0.47324797563633186


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:08<00:00,  2.64s/it, loss=0.327]


===>Saving checkpoint
Got 388919/405504 with acc 95.91
Dice score: 0.5533723333359474


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:15<00:00,  2.91s/it, loss=0.347]


===>Saving checkpoint
Got 389315/405504 with acc 96.01
Dice score: 0.5697705560389749


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:05<00:00,  2.54s/it, loss=0.418]


===>Saving checkpoint
Got 389534/405504 with acc 96.06
Dice score: 0.5307772209615279


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:06<00:00,  2.54s/it, loss=0.536]


===>Saving checkpoint
Got 390334/405504 with acc 96.26
Dice score: 0.5988139963529411


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:05<00:00,  2.53s/it, loss=0.344]


===>Saving checkpoint
Got 390415/405504 with acc 96.28
Dice score: 0.5729874459251224


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:04<00:00,  2.46s/it, loss=0.274]


===>Saving checkpoint
Got 390152/405504 with acc 96.21
Dice score: 0.5301679678503587


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:02<00:00,  2.41s/it, loss=0.281]


===>Saving checkpoint
Got 391267/405504 with acc 96.49
Dice score: 0.572175927600337


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:05<00:00,  2.53s/it, loss=0.264]


===>Saving checkpoint
Got 390548/405504 with acc 96.31
Dice score: 0.5455616274677384


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:06<00:00,  2.55s/it, loss=0.296]


===>Saving checkpoint
Got 390885/405504 with acc 96.39
Dice score: 0.5804533256915142


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:05<00:00,  2.54s/it, loss=0.248]


===>Saving checkpoint
Got 392427/405504 with acc 96.78
Dice score: 0.6114912914563319


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:04<00:00,  2.49s/it, loss=0.298]


===>Saving checkpoint
Got 392144/405504 with acc 96.71
Dice score: 0.59062735883703


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:06<00:00,  2.55s/it, loss=0.253]


===>Saving checkpoint
Got 391947/405504 with acc 96.66
Dice score: 0.5910992928000106


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:05<00:00,  2.53s/it, loss=0.236]


===>Saving checkpoint
Got 391959/405504 with acc 96.66
Dice score: 0.5871432063743454


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:04<00:00,  2.48s/it, loss=0.482]


===>Saving checkpoint
Got 390675/405504 with acc 96.34
Dice score: 0.544704214527946


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:03<00:00,  2.45s/it, loss=0.192]


===>Saving checkpoint
Got 391809/405504 with acc 96.62
Dice score: 0.5850315119318891


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:06<00:00,  2.57s/it, loss=0.359]


===>Saving checkpoint
Got 391722/405504 with acc 96.60
Dice score: 0.5936034688727279


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:05<00:00,  2.51s/it, loss=0.273]


===>Saving checkpoint
Got 391992/405504 with acc 96.67
Dice score: 0.5979129621351331


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:07<00:00,  2.61s/it, loss=0.241]


===>Saving checkpoint
Got 390999/405504 with acc 96.42
Dice score: 0.5697378297928085


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:04<00:00,  2.49s/it, loss=0.217]


===>Saving checkpoint
Got 392216/405504 with acc 96.72
Dice score: 0.5942492431988237


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:05<00:00,  2.54s/it, loss=0.161]


===>Saving checkpoint
Got 392283/405504 with acc 96.74
Dice score: 0.6015024024687764


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:07<00:00,  2.59s/it, loss=0.217]


===>Saving checkpoint
Got 392290/405504 with acc 96.74
Dice score: 0.5922361468218744


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:08<00:00,  2.63s/it, loss=0.267]


===>Saving checkpoint
Got 392757/405504 with acc 96.86
Dice score: 0.6163566257437447


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:10<00:00,  2.73s/it, loss=0.349]


===>Saving checkpoint
Got 390450/405504 with acc 96.29
Dice score: 0.5332459917429689


100%|███████████████████████████████████████████████████████████████████████| 26/26 [02:21<00:00,  5.43s/it, loss=0.27]


===>Saving checkpoint
Got 391786/405504 with acc 96.62
Dice score: 0.5808116852445601


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:22<00:00,  5.49s/it, loss=0.155]


===>Saving checkpoint
Got 392155/405504 with acc 96.71
Dice score: 0.5942904691319424


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:21<00:00,  5.45s/it, loss=0.164]


===>Saving checkpoint
Got 392101/405504 with acc 96.69
Dice score: 0.6015361128029323


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:22<00:00,  5.48s/it, loss=0.278]


===>Saving checkpoint
Got 391149/405504 with acc 96.46
Dice score: 0.5636303714094765


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:22<00:00,  5.47s/it, loss=0.123]


===>Saving checkpoint
Got 392661/405504 with acc 96.83
Dice score: 0.5985104431554042


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:19<00:00,  5.35s/it, loss=0.114]


===>Saving checkpoint
Got 393037/405504 with acc 96.93
Dice score: 0.6332036548140951


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:21<00:00,  5.44s/it, loss=0.205]


===>Saving checkpoint
Got 389923/405504 with acc 96.16
Dice score: 0.5299055137574153


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:21<00:00,  5.44s/it, loss=0.219]


===>Saving checkpoint
Got 392407/405504 with acc 96.77
Dice score: 0.6000832702916756


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:22<00:00,  5.47s/it, loss=0.128]


===>Saving checkpoint
Got 391918/405504 with acc 96.65
Dice score: 0.5856994313040158


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:21<00:00,  5.45s/it, loss=0.129]


===>Saving checkpoint
Got 391902/405504 with acc 96.65
Dice score: 0.6151175271675952


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:20<00:00,  5.41s/it, loss=0.408]


===>Saving checkpoint
Got 391828/405504 with acc 96.63
Dice score: 0.5822364761783927


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:22<00:00,  5.49s/it, loss=0.194]


===>Saving checkpoint
Got 392168/405504 with acc 96.71
Dice score: 0.6021419649675404


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:23<00:00,  5.51s/it, loss=0.103]


===>Saving checkpoint
Got 392315/405504 with acc 96.75
Dice score: 0.5962836886901035


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:22<00:00,  5.47s/it, loss=0.138]


===>Saving checkpoint
Got 391395/405504 with acc 96.52
Dice score: 0.6206789683389137


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:18<00:00,  5.34s/it, loss=0.142]


===>Saving checkpoint
Got 393283/405504 with acc 96.99
Dice score: 0.6416632421209424


100%|█████████████████████████████████████████████████████████████████████| 26/26 [02:19<00:00,  5.37s/it, loss=0.0946]


===>Saving checkpoint
Got 389877/405504 with acc 96.15
Dice score: 0.603485828434309


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:24<00:00,  5.54s/it, loss=0.266]


===>Saving checkpoint
Got 391572/405504 with acc 96.56
Dice score: 0.5691352295712725


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:25<00:00,  5.61s/it, loss=0.262]


===>Saving checkpoint
Got 391419/405504 with acc 96.53
Dice score: 0.5672372081853471


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:26<00:00,  5.62s/it, loss=0.139]


===>Saving checkpoint
Got 393010/405504 with acc 96.92
Dice score: 0.6155987794292284


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:12<00:00,  5.10s/it, loss=0.188]


===>Saving checkpoint
Got 392625/405504 with acc 96.82
Dice score: 0.6097074449812212


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:07<00:00,  4.91s/it, loss=0.151]


===>Saving checkpoint
Got 392795/405504 with acc 96.87
Dice score: 0.6152428008356401


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:13<00:00,  5.12s/it, loss=0.143]


===>Saving checkpoint
Got 392841/405504 with acc 96.88
Dice score: 0.6143923461765685


100%|███████████████████████████████████████████████████████████████████████| 26/26 [02:18<00:00,  5.31s/it, loss=0.12]


===>Saving checkpoint
Got 393374/405504 with acc 97.01
Dice score: 0.6302368828993392


100%|██████████████████████████████████████████████████████████████████████| 26/26 [02:16<00:00,  5.26s/it, loss=0.119]


===>Saving checkpoint
Got 390798/405504 with acc 96.37
Dice score: 0.5496007729282008


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:49<00:00,  4.23s/it, loss=0.225]


===>Saving checkpoint
Got 391979/405504 with acc 96.66
Dice score: 0.5874530425502574


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.80s/it, loss=0.086]


===>Saving checkpoint
Got 391135/405504 with acc 96.46
Dice score: 0.564347063992293


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:11<00:00,  2.76s/it, loss=0.215]


===>Saving checkpoint
Got 391953/405504 with acc 96.66
Dice score: 0.5880860358738987


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.80s/it, loss=0.111]


===>Saving checkpoint
Got 392732/405504 with acc 96.85
Dice score: 0.6055161489941067


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.78s/it, loss=0.0658]


===>Saving checkpoint
Got 392900/405504 with acc 96.89
Dice score: 0.608764208622567


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:10<00:00,  2.71s/it, loss=0.217]


===>Saving checkpoint
Got 393248/405504 with acc 96.98
Dice score: 0.6399539467452584


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:09<00:00,  2.67s/it, loss=0.154]


===>Saving checkpoint
Got 392833/405504 with acc 96.88
Dice score: 0.6174354645277491


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.79s/it, loss=0.0854]


===>Saving checkpoint
Got 391751/405504 with acc 96.61
Dice score: 0.5745814079319702


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:11<00:00,  2.75s/it, loss=0.0641]


===>Saving checkpoint
Got 393045/405504 with acc 96.93
Dice score: 0.6310805703902392


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.77s/it, loss=0.0637]


===>Saving checkpoint
Got 391817/405504 with acc 96.62
Dice score: 0.5899611572811034


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:10<00:00,  2.70s/it, loss=0.0637]


===>Saving checkpoint
Got 392648/405504 with acc 96.83
Dice score: 0.6080955361873521


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:11<00:00,  2.74s/it, loss=0.0633]


===>Saving checkpoint
Got 393077/405504 with acc 96.94
Dice score: 0.6190469405930663


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.78s/it, loss=0.0475]


===>Saving checkpoint
Got 392975/405504 with acc 96.91
Dice score: 0.6215846613636794


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.81s/it, loss=0.102]


===>Saving checkpoint
Got 393202/405504 with acc 96.97
Dice score: 0.6282686395907797


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.82s/it, loss=0.0718]


===>Saving checkpoint
Got 392290/405504 with acc 96.74
Dice score: 0.5967804695120803


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.80s/it, loss=0.0618]


===>Saving checkpoint
Got 393129/405504 with acc 96.95
Dice score: 0.6269542839248496


100%|███████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.80s/it, loss=0.05]


===>Saving checkpoint
Got 393020/405504 with acc 96.92
Dice score: 0.6191780501059659


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.84s/it, loss=0.153]


===>Saving checkpoint
Got 392867/405504 with acc 96.88
Dice score: 0.616342472020848


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.81s/it, loss=0.0876]


===>Saving checkpoint
Got 392773/405504 with acc 96.86
Dice score: 0.6059404302474133


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.84s/it, loss=0.0567]


===>Saving checkpoint
Got 393079/405504 with acc 96.94
Dice score: 0.6065143502988156


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.82s/it, loss=0.424]


===>Saving checkpoint
Got 393054/405504 with acc 96.93
Dice score: 0.6246787800953868


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:11<00:00,  2.75s/it, loss=0.134]


===>Saving checkpoint
Got 392452/405504 with acc 96.78
Dice score: 0.6083451059776027


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:11<00:00,  2.76s/it, loss=0.0554]


===>Saving checkpoint
Got 392077/405504 with acc 96.69
Dice score: 0.5880102800999261


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.79s/it, loss=0.408]


===>Saving checkpoint
Got 393710/405504 with acc 97.09
Dice score: 0.635079527825413


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.77s/it, loss=0.156]


===>Saving checkpoint
Got 392784/405504 with acc 96.86
Dice score: 0.6243475274065701


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:10<00:00,  2.70s/it, loss=0.107]


===>Saving checkpoint
Got 393675/405504 with acc 97.08
Dice score: 0.6403325403879673


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:11<00:00,  2.75s/it, loss=0.103]


===>Saving checkpoint
Got 392408/405504 with acc 96.77
Dice score: 0.5866641451868899


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:10<00:00,  2.72s/it, loss=0.0571]


===>Saving checkpoint
Got 392848/405504 with acc 96.88
Dice score: 0.6079589243215232


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.84s/it, loss=0.0439]


===>Saving checkpoint
Got 393064/405504 with acc 96.93
Dice score: 0.6219824389366667


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:12<00:00,  2.78s/it, loss=0.186]


===>Saving checkpoint
Got 390977/405504 with acc 96.42
Dice score: 0.5651216178776048


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.84s/it, loss=0.0448]


===>Saving checkpoint
Got 392804/405504 with acc 96.87
Dice score: 0.6264207658156062


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:11<00:00,  2.75s/it, loss=0.115]


===>Saving checkpoint
Got 393238/405504 with acc 96.98
Dice score: 0.6434175747221776


100%|███████████████████████████████████████████████████████████████████████| 26/26 [01:14<00:00,  2.88s/it, loss=0.13]


===>Saving checkpoint
Got 392750/405504 with acc 96.85
Dice score: 0.6361426534047696


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:14<00:00,  2.88s/it, loss=0.0401]


===>Saving checkpoint
Got 392509/405504 with acc 96.80
Dice score: 0.6130306725979254


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:15<00:00,  2.92s/it, loss=0.068]


===>Saving checkpoint
Got 392044/405504 with acc 96.68
Dice score: 0.5889213428312507


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:15<00:00,  2.89s/it, loss=0.774]


===>Saving checkpoint
Got 393041/405504 with acc 96.93
Dice score: 0.6162830280425137


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.83s/it, loss=0.193]


===>Saving checkpoint
Got 388900/405504 with acc 95.91
Dice score: 0.5482661306854747


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.82s/it, loss=0.109]


===>Saving checkpoint
Got 393028/405504 with acc 96.92
Dice score: 0.6186646503451885


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:14<00:00,  2.85s/it, loss=0.129]


===>Saving checkpoint
Got 392340/405504 with acc 96.75
Dice score: 0.5872729793260408


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:16<00:00,  2.92s/it, loss=0.363]


===>Saving checkpoint
Got 392953/405504 with acc 96.90
Dice score: 0.6206884137183655


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:15<00:00,  2.90s/it, loss=0.249]


===>Saving checkpoint
Got 393273/405504 with acc 96.98
Dice score: 0.6375874030035362


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:13<00:00,  2.83s/it, loss=0.0604]


===>Saving checkpoint
Got 393142/405504 with acc 96.95
Dice score: 0.6263436867014944


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:14<00:00,  2.87s/it, loss=0.0775]


===>Saving checkpoint
Got 391806/405504 with acc 96.62
Dice score: 0.5755876731733588


100%|█████████████████████████████████████████████████████████████████████| 26/26 [01:16<00:00,  2.92s/it, loss=0.0376]


===>Saving checkpoint
Got 393363/405504 with acc 97.01
Dice score: 0.6330792096114478


100%|██████████████████████████████████████████████████████████████████████| 26/26 [01:17<00:00,  2.99s/it, loss=0.302]


===>Saving checkpoint
Got 393635/405504 with acc 97.07
Dice score: 0.6340698404816861


In [28]:
def batch(iterable, batch_size=16):
    iter_length = len(iterable)
    for idx in range(0, iter_length, batch_size):
        yield iterable[idx:min(idx + batch_size, iter_length)]
        

def test(model, img, mask):
    model.eval()
    windows = view_as_windows(img, (IMAGE_SIZE, IMAGE_SIZE, 3), step=IMAGE_SIZE)
    width, height, _, _, _, _ = windows.shape
    windows = windows.flatten().reshape(width*height, IMAGE_SIZE, IMAGE_SIZE, 3)

    mask_windows = view_as_windows(mask, (IMAGE_SIZE, IMAGE_SIZE), step=IMAGE_SIZE)
    mask_windows = mask_windows.flatten().reshape(width*height, IMAGE_SIZE, IMAGE_SIZE)
    
    all_preds = []
    for x, y in zip(windows, mask_windows):
        x = x.to(device)
        y = y.to(device).unsqueeze(1)
        preds = torch.sigmoid(model(x))
        preds = (preds > 0.5).float()
        all_preds.append(preds)
    
    print(all_preds)
  
    
test_img_filenames = glob.glob()
test_img_filenames.sort()
test_images = [cv2.imread(img) for img in test_img_filenames]

test_mask_filenames = glob.glob()
test_mask_filenames.sort()
test_masks = [cv2.imread(img, 0) for img in test_mask_filenames]

test_ds = BloodVesselDataset(
        image_dir="data/test_images/*",
        mask_dir="data/test_masks/*")
    
test_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory)
    
for img, mask in zip(test_images, test_masks):
    test(model, img, mask)
    break

TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not numpy.ndarray

## Bibliography:
https://www.youtube.com/watch?v=IHq1t7NxS8k&ab_channel=AladdinPersson