In [1]:
import os
import random
from PIL import Image
from random import randint
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from math import sqrt
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import networkx as nx

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
from google.colab import drive
# Colab only: mount Google Drive at /content/drive so the Drive-based data and model paths below resolve.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [4]:
# Folders of the test split on the mounted Drive: input images, label masks and binary images.
_deadends_test_root = '/content/drive/MyDrive/data/deadends/test'
input_test_folder = f'{_deadends_test_root}/input'
label_test_folder = f'{_deadends_test_root}/label'
binary_test_folder = f'{_deadends_test_root}/binary'

In [5]:
# Obtaining per-image metadata: porosity, tortuosity and pseudo-permeability.

# NOTE(review): INF is not referenced by any visible cell; it looks like a leftover
# sentinel value — confirm before relying on it or removing it.
INF = 999.0

def calculate_tortuosity(mask: np.ndarray, start: tuple, end: tuple):
    """
    Calculate the geometric tortuosity between two pore pixels of a binary mask.

    The pore space (non-zero pixels of `mask`) is treated as a 4-connected grid
    graph with unit edge weights. On such a graph the Dijkstra shortest-path length
    equals the breadth-first-search distance, so BFS is run directly. This avoids
    building (and searching twice) a full networkx graph on every call.

    Args:
        mask (numpy ndarray): 2D binary mask; non-zero pixels are traversable.
        start (tuple of ints): (row, col) start point coordinates.
        end (tuple of ints): (row, col) end point coordinates.

    Returns:
        numpy float64: in-pore shortest path length divided by the straight-line
            (Euclidean) distance between `start` and `end`.
        None: if there is no connection between them, if either point lies outside
            the mask or on a zero (solid) pixel, or if `start == end` (tortuosity is
            undefined for coincident points).
    """
    from collections import deque

    pores = np.asarray(mask) != 0
    rows, cols = pores.shape
    start = (int(start[0]), int(start[1]))
    end = (int(end[0]), int(end[1]))

    def _is_pore(r, c):
        # Bounds check first so negative indices never wrap around.
        return 0 <= r < rows and 0 <= c < cols and bool(pores[r, c])

    if start == end or not _is_pore(*start) or not _is_pore(*end):
        return None

    # BFS from `start`; -1 marks pixels that have not been reached yet.
    distance = np.full(pores.shape, -1, dtype=np.int64)
    distance[start] = 0
    queue = deque([start])
    length_real = None
    while queue:
        r, c = queue.popleft()
        if (r, c) == end:
            length_real = int(distance[r, c])
            break
        for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
            if _is_pore(nr, nc) and distance[nr, nc] < 0:
                distance[nr, nc] = distance[r, c] + 1
                queue.append((nr, nc))

    if length_real is None:
        return None  # `end` is not reachable through the pore space

    # Dividing by a numpy float keeps the numpy scalar return type (callers may use `.item()`).
    length_direct = np.linalg.norm(np.array(start) - np.array(end))
    return length_real / length_direct

# Set of all possible coordinates for choosing points.

# This works for all images:
def list_points(img_array: np.ndarray):
    """
    Collect the coordinates of all valid points (pixels whose value equals 1.0).

    Args:
        img_array (numpy ndarray): 2D binary image to scan.

    Returns:
        list of tuple(int, int): (row, col) coordinates of every pixel equal to 1,
            in row-major order (the same order as a nested row/column loop).
    """
    # np.argwhere is vectorized and returns indices in row-major (C) order;
    # .tolist() turns them into plain Python ints.
    return [tuple(point) for point in np.argwhere(img_array == 1).tolist()]

def iterative_tortuosity(mask: np.ndarray, n: int, valids: list, rng=None):
    """
    Estimate the mean tortuosity of `mask` by sampling random point pairs.

    Args:
        mask (numpy ndarray): binary mask passed through to `calculate_tortuosity`.
        n (int): number of random (start, end) pairs to sample.
        valids (list of tuples): candidate points, e.g. the output of `list_points`.
        rng (random.Random, optional): source of randomness. Pass a seeded
            `random.Random(seed)` for reproducible estimates. Defaults to the global
            `random` module state.

    Returns:
        float: mean tortuosity over the sampled pairs that are connected. Returns 1.0
            when fewer than two candidate points exist or no sampled pair is connected.
    """
    # TODO(review): study whether 1.0 is really the best fallback value.
    if len(valids) < 2:
        return 1.0

    sampler = random if rng is None else rng
    total = 0.0
    connected = 0
    for _ in range(n):
        # Two distinct positions: a uniform start, then a uniform end among the rest.
        # This is the same distribution as "redraw end until it differs", but it cannot
        # loop forever when `valids` contains only duplicate points.
        start, end = sampler.sample(valids, 2)
        if tuple(start) == tuple(end):
            continue  # duplicate entries in `valids`: tortuosity is undefined
        tortuosity = calculate_tortuosity(mask, start, end)
        if tortuosity is None:
            continue  # disconnected pair: excluded from the average
        total += tortuosity
        connected += 1

    if connected == 0:
        return 1.0
    # float() works for both numpy scalars and plain floats.
    return float(total / connected)

# Folders of the augmented training split, under /content/data rather than the mounted Drive.
input_aug_folder, label_aug_folder = (f'/content/data/train/{part}' for part in ('input', 'label'))

def calculate_metadata(folder_path: str):
    """
    Compute a [porosity, tortuosity, pseudo-permeability] vector for every image in a folder.

    Images are read in numeric filename order ("12.png" -> 12). The `DeadEnds` dataset
    uses the same sort key, so `metadata[i]` lines up with dataset item `i`.

    Args:
        folder_path (str): folder with the images. Every filename must start with an
            integer followed by a dot.

    Returns:
        list of torch.Tensor: one float32 tensor of shape (3,) per image holding
            (phi, tau, k).
    """
    filenames = sorted(os.listdir(folder_path), key=lambda x: int(x.split(".")[0]))
    num_imgs = len(filenames)
    pending_milestones = [20, 40, 60, 80]  # progress percentages not yet reported

    metadata = []
    for i, filename in enumerate(filenames, start=1):
        # The context manager closes the file handle once the pixels are copied into numpy.
        with Image.open(os.path.join(folder_path, filename)) as pil_img:
            img = np.array(pil_img, dtype=np.float32) / 255
        img = (img >= 0.5).astype(np.float32)

        metadata.append(_image_metadata(img))

        print(i)
        # Report every milestone crossed so far, in ascending order.
        while pending_milestones and i / num_imgs > pending_milestones[0] / 100:
            print(f"{pending_milestones.pop(0)}%")
    print("100%")
    return metadata


def _image_metadata(img: np.ndarray) -> torch.Tensor:
    """Return the float32 tensor [phi, tau, k] for one binarized (0/1) 2D image."""
    # Porosity: fraction of pore pixels. This uses the real image size instead of a
    # hard-coded 200 * 200, which was only correct for 200x200 images.
    phi = float(np.sum(img == 1) / img.size)

    # Tortuosity: estimate from 20 random pore-pixel pairs.
    tau = iterative_tortuosity(img, 20, list_points(img))

    # Pseudo-permeability from the Kozeny-Carman equation (without the constant).
    # NOTE(review): raises ZeroDivisionError for a fully porous image (phi == 1).
    k = (phi ** 3) / ((1 - phi) ** 2 * tau ** 2)

    return torch.tensor([phi, tau, k], dtype=torch.float32)

# Compute [porosity, tortuosity, pseudo-permeability] for every binary test image.
print("-----test  metadata-----")
test_metadata = calculate_metadata(folder_path=binary_test_folder)

-----test  metadata-----
1
2
3
4
5
6
7
8
9
10
11
12
13
20%
14
15
16
17
18
19
20
21
22
23
24
25
40%
26
27
28
29
30
31
32
33
34
35
36
37
60%
38
39
40
41
42
43
44
45
46
47
48
49
80%
50
51
52
53
54
55
56
57
58
59
60
100%


In [6]:
# dataset class

class DeadEnds(Dataset):
    """
    Dataset that pairs each input image with its segmentation mask and a metadata vector.

    Items are ordered by the integer prefix of the filenames in `img_dir`. The mask for
    an item is the file with the same name in `mask_dir`.
    """

    def __init__(self, img_dir, mask_dir, vector_data, img_transform=None, mask_transform=None):
        """
        Args:
            img_dir (str): Directory with the input images.
            mask_dir (str): Directory with the corresponding segmentation masks.
            vector_data (list or array): A list (or array) of vectors (each with 3 elements),
                                         one per image, in the same numeric filename order.
            img_transform (callable, optional): Optional transform to be applied on the input image.
            mask_transform (callable, optional): Optional transform to be applied on the mask.

        Raises:
            ValueError: if len(vector_data) differs from the number of images in img_dir.
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.vector_data = vector_data
        self.img_transform = img_transform
        self.mask_transform = mask_transform

        self.images = sorted(os.listdir(self.img_dir), key=lambda x: int(x.split(".")[0]))

        # vector_data is indexed by position in self.images, so a length mismatch
        # would silently pair images with the wrong metadata.
        if len(self.vector_data) != len(self.images):
            raise ValueError(
                f"vector_data has {len(self.vector_data)} entries but {self.img_dir!r} "
                f"contains {len(self.images)} images"
            )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """
        Load item `idx`.

        Returns:
            tuple: (image, vector, mask). `vector` is the metadata vector with two leading
                singleton dims added (shape (1, 1, len(vector))). image/mask are
                whatever the configured transforms return (PIL images if none).
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)

        # Open inside context managers so file handles are released immediately.
        # .copy()/.convert() load the pixels into new images before the files close.
        with Image.open(img_path) as img_file:
            image = img_file.copy()
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.img_transform:
            image = self.img_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        vec_item = self.vector_data[idx]
        if isinstance(vec_item, torch.Tensor):
            vector = vec_item.detach().clone()
        else:
            vector = torch.tensor(vec_item, dtype=torch.float32)

        return image, vector.unsqueeze(0).unsqueeze(0), mask

In [7]:
# Shared preprocessing for the binary input images and the label masks:
# resize to 160x160, then convert to a float tensor.
# NOTE(review): Resize uses torchvision's default interpolation (bilinear per its docs),
# so resized masks may contain non-binary edge values — confirm NEAREST was not intended.
mask_transform = transforms.Compose([transforms.Resize((160, 160)), transforms.ToTensor()])

test_dataset = DeadEnds(
    img_dir=binary_test_folder,
    mask_dir=label_test_folder,
    vector_data=test_metadata,
    img_transform=mask_transform,
    mask_transform=mask_transform,
)

# Ordered, one-sample batches for evaluation.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
)

len(test_dataset)

60

In [8]:
class SelfAttention(nn.Module):
    """
    Attention across the channel axis of a feature map.

    Query, key and value come from 3x3 same-padding convolutions. After the spatial
    dims are flattened, the (b, c, c) score matrix q @ k^T is divided by
    sqrt(4 * semantic) and soft-maxed over the last axis. It is then applied to v,
    and the result is reshaped back to (b, c, h, w).
    """

    def __init__(self, semantic):
        super().__init__()

        def conv3x3():
            # Overlapping embedding: channel count and spatial size are preserved.
            return nn.Conv2d(in_channels=semantic, out_channels=semantic, kernel_size=3, stride=1, padding=1)

        # Created in query -> key -> value order; attribute names are the state_dict keys.
        self.query = conv3x3()
        self.key = conv3x3()
        self.value = conv3x3()

        self.normalizer = sqrt(4 * semantic)  # score scaling constant
        self.flatten = nn.Flatten(2, 3)  # (b, c, h, w) -> (b, c, h*w)

    def forward(self, x):
        batch, channels, height, width = x.size()
        queries, keys, values = (self.flatten(proj(x)) for proj in (self.query, self.key, self.value))
        weights = F.softmax(queries @ keys.transpose(1, 2) / self.normalizer, dim=-1)
        return (weights @ values).reshape(batch, channels, height, width)

In [9]:
class DCA(nn.Module):
    """
    Conv -> BN -> ReLU -> Conv -> BN -> SelfAttention -> ReLU block.

    Both convolutions are 3x3 with stride 1 and padding 1, so the spatial size is
    preserved while the channel count goes from `ic` to `oc`.
    """

    def __init__(self, ic, oc):
        super().__init__()
        # Attribute names (and this registration order) define the state_dict keys and
        # module repr; keep them unchanged so existing checkpoints still load.
        self.conv1 = nn.Conv2d(ic, oc, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(oc)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(oc, oc, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(oc)

        self.attention = SelfAttention(semantic=oc)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        first = self.relu1(self.bn1(self.conv1(x)))
        second = self.bn2(self.conv2(first))
        return self.relu2(self.attention(second))

In [15]:
class ExpansionNet(nn.Module):
    """
    U-Net-style segmentation network built from DCA blocks and conditioned on a metadata vector.

    The image passes through a DCA head (1 -> 16 channels). The metadata vector passes
    through BatchNorm -> ConvTranspose2d -> SelfAttention -> ReLU and becomes a
    16-channel map. The two are concatenated (32 channels) and fed to a 5-stage DCA
    encoder, a conv bottleneck and a 5-stage DCA decoder with skip connections.
    The result is a single-channel map from `final_conv`; no activation is applied here.

    NOTE(review): attribute names are the state_dict keys used by `load_state_dict`,
    so renaming, removing or adding layers breaks loading existing checkpoints.
    """
    def __init__(self):
        super(ExpansionNet, self).__init__()

        # DCA-Head: image branch, 1 -> 16 channels, spatial size unchanged
        self.dcah = DCA(ic=1, oc=16)

        # Expansion-Head: turns the metadata vector into a 16-channel map.
        # Assuming vec is (B, 1, 1, 3) — TODO confirm against the data pipeline — the
        # transposed conv gives H = (1 - 1) * 1 + 160 = 160 and W = (3 - 1) * 60 + 40 = 160,
        # so it matches a 160x160 image and the two can be concatenated on channels.
        self.bn = nn.BatchNorm2d(1)
        self.expansion_transpose = nn.ConvTranspose2d(
                                    in_channels=1,
                                    out_channels=16,
                                    kernel_size=(160, 40),
                                    stride=(1, 60),
                                    padding=(0, 0),
                                    output_padding=(0, 0)
                                )
        self.expansion_attention = SelfAttention(16)

        # encoder: channels 32 -> 64 -> 128 -> 256 -> 512 -> 1024 (written as N//2)
        self.dca1 = DCA(ic=64//2, oc=128//2)
        self.dca2 = DCA(ic=128//2, oc=256//2)
        self.dca3 = DCA(ic=256//2, oc=512//2)
        self.dca4 = DCA(ic=512//2, oc=1024//2)
        self.dca5 = DCA(ic=1024//2, oc=2048//2)

        # bottleneck: 3x3 conv then 1x1 ("unity") conv, both 1024 -> 1024 channels
        self.bottom_conv = nn.Conv2d(in_channels=2048//2,
                                     out_channels=2048//2,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        # NOTE(review): this single BatchNorm is applied after BOTH bottom_conv and
        # unity_conv in forward(), so its running statistics mix two different layers.
        # That is likely unintended, but splitting it would change the state_dict keys.
        self.bottom_norm = nn.BatchNorm2d(num_features=2048//2)
        self.unity_conv = nn.Conv2d(in_channels=2048//2,
                                    out_channels=2048//2,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)

        # decoder
        # dca6 input: enc5 (1024) + bottleneck output (1024) concatenated on channels
        self.dca6 = DCA(ic = 4096//2, oc=2048//2)

        # Each transposeK (kernel 2, stride 2) doubles the spatial size and halves the
        # channels. The DCA after it takes that map concatenated with the matching
        # encoder output, hence ic = 2 * transposeK's out_channels.
        self.transpose1 = nn.ConvTranspose2d(in_channels=2048//2,
                                             out_channels=1024//2,
                                             kernel_size=2,
                                             stride=2,
                                             padding=0,
                                             output_padding=0)

        self.dca7 = DCA(ic=2048//2, oc=1024//2)

        self.transpose2 = nn.ConvTranspose2d(in_channels=1024//2,
                                             out_channels=512//2,
                                             kernel_size=2,
                                             stride=2,
                                             padding=0,
                                             output_padding=0)

        self.dca8 = DCA(ic=1024//2, oc=512//2)

        self.transpose3 = nn.ConvTranspose2d(in_channels=512//2,
                                             out_channels=256//2,
                                             kernel_size=2,
                                             stride=2, padding=0,
                                             output_padding=0)

        self.dca9 = DCA(ic=512//2, oc=256//2)

        self.transpose4 = nn.ConvTranspose2d(in_channels=256//2,
                                             out_channels=128//2,
                                             kernel_size=2,
                                             stride=2,
                                             padding=0,
                                             output_padding=0)

        self.dca10 = DCA(ic=256//2, oc=128//2)

        # 64 -> 1 channel output head
        self.final_conv = nn.Conv2d(in_channels=128//2,
                                    out_channels=1,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)

        # shared, parameter-free helpers (not in the state_dict)
        self.relu = nn.ReLU(inplace=False)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, img: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        """
        Run the network on an image batch and its metadata vectors.

        Args:
            img: image batch; presumably (B, 1, 160, 160) — TODO confirm.
            vec: metadata batch; presumably (B, 1, 1, 3) so the expansion head
                produces a 160x160 map (see the note in __init__).

        Returns:
            torch.Tensor: (B, 1, H, W) raw output of `final_conv` (no sigmoid applied).
                H and W match `img` when they are divisible by 16.
        """
        img = self.dcah(img)
        expanded = self.relu(self.expansion_attention(self.expansion_transpose(self.bn(vec))))
        # 16 image channels + 16 expansion channels = 32 = dca1's ic (64//2)
        x = torch.cat((img, expanded), dim=1)

        # encoder: enc1 keeps the input resolution; each later stage halves it (max-pool)
        enc1 = self.dca1(x)
        enc2 = self.pool(self.dca2(enc1))
        enc3 = self.pool(self.dca3(enc2))
        enc4 = self.pool(self.dca4(enc3))
        enc5 = self.pool(self.dca5(enc4))

        # bottleneck (both convs normalized by the shared bottom_norm)
        bottom1 = self.relu(self.bottom_norm(self.bottom_conv(enc5)))
        bottom2 = self.relu(self.bottom_norm(self.unity_conv(bottom1)))

        # decoder: dca6 fuses enc5 with the bottleneck at the same resolution; every later
        # stage upsamples x2 and concatenates the matching encoder output (skip connection)
        dec1 = self.dca6(torch.cat((enc5, bottom2), dim=1))
        dec2 = self.dca7(torch.cat((enc4, self.transpose1(dec1)), dim=1))
        dec3 = self.dca8(torch.cat((enc3, self.transpose2(dec2)), dim=1))
        dec4 = self.dca9(torch.cat((enc2, self.transpose3(dec3)), dim=1))
        dec5 = self.dca10(torch.cat((enc1, self.transpose4(dec4)), dim=1))

        return self.final_conv(dec5)

In [11]:
# define evaluation metrics

def calculate_iou(pred: torch.Tensor, target: torch.Tensor, threshold: float = 0.5,
                  target_threshold: float = 0.5) -> float:
    """
    Calculate Intersection over Union (IoU) for two 1-channel tensors.

    Both tensors are binarized before comparison. A target that went through an
    interpolating resize (soft edge values) is therefore scored as a hard mask, not
    mixed into a soft IoU.

    Args:
        pred (torch.Tensor): Predicted mask / probabilities (1-channel, shape HxW or BxHxW).
        target (torch.Tensor): Ground truth mask (same shape as pred). May be float, int
            or bool.
        threshold (float): Threshold to binarize the predicted mask (default 0.5).
        target_threshold (float): Threshold to binarize the ground truth (default 0.5).
            A target that is already {0, 1} is unchanged by it.

    Returns:
        float: IoU in [0, 1]; 1.0 when both masks are empty.
    """
    pred_bin = pred >= threshold
    target_bin = target.float() >= target_threshold

    intersection = torch.logical_and(pred_bin, target_bin).sum()
    union = torch.logical_or(pred_bin, target_bin).sum()

    # Both masks empty (the intersection is necessarily empty too): perfect agreement.
    if union == 0:
        return 1.0

    return (intersection.float() / union.float()).item()

def pixel_accuracy(pred_mask, true_mask):
    """
    Fraction of pixels on which the predicted and ground-truth masks agree exactly.

    Args:
        pred_mask (np.array): Predicted segmentation mask.
        true_mask (np.array): Ground truth segmentation mask; its `.size` is the denominator.

    Returns:
        float: number of element-wise equal pixels divided by ``true_mask.size``.
    """
    n_total = true_mask.size
    n_matching = np.sum(np.equal(pred_mask, true_mask))
    return n_matching / n_total

In [16]:
model_path = r'/content/drive/MyDrive/data/models/deadend-model3.pth'

model = ExpansionNet().to(device)

# map_location remaps stored tensors onto `device`. Without it, a checkpoint saved
# from a GPU run fails to load when `device` fell back to "cpu".
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference only: BatchNorm uses its running statistics

model

ExpansionNet(
  (dcah): DCA(
    (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (attention): SelfAttention(
      (query): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (key): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (value): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (flatten): Flatten(start_dim=2, end_dim=3)
    )
    (relu2): ReLU(inplace=True)
  )
  (bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (expansion_transpose): ConvTranspose2d(1, 16, kernel_size=(160, 40), stride=(1, 60))
  (expansion_attention): SelfAttention(
    (query): Conv2d(16

In [18]:
import os
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Run the model on every test sample and save input / ground truth / prediction as PNGs.
save_path = r'/content/drive/MyDrive/data/deadends/test/results3'
os.makedirs(save_path, exist_ok=True)

pred_threshold = 0.7  # cut-off applied to sigmoid(model output) to binarize the prediction

model.eval()  # set once instead of on every iteration
with torch.no_grad():  # inference only: skip building the autograd graph
    for idx, (s, v, m) in enumerate(test_dataset):
        logits = model(s.unsqueeze(0).to(device), v.unsqueeze(0).to(device)).cpu()

        # s and m come from ToTensor, so for 8-bit images they are already in [0, 1].
        # Applying a sigmoid to them (as before) mapped 0/1 to ~0.5/0.73 and saved grey
        # images instead of black/white ones.
        sample = s.squeeze().numpy()
        mask = m.squeeze().numpy()
        # Only the unbounded model output needs the sigmoid before thresholding.
        pred = (torch.sigmoid(logits) >= pred_threshold).float().squeeze().numpy()

        # The clip guards against out-of-range values wrapping around in uint8.
        sample_img = (np.clip(sample, 0.0, 1.0) * 255).astype(np.uint8)
        mask_img = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
        pred_img = (pred * 255).astype(np.uint8)

        Image.fromarray(sample_img).save(os.path.join(save_path, f'{idx}_input_image{idx}.png'))
        Image.fromarray(mask_img).save(os.path.join(save_path, f'{idx}_original_segmentation.png'))
        Image.fromarray(pred_img).save(os.path.join(save_path, f'{idx}_predicted.png'))

print(f'Imagens salvas em {save_path}')


Imagens salvas em /content/drive/MyDrive/data/deadends/test/results3
