In [None]:
#cd to the script folder
from google.colab import drive
import os
# Mount Google Drive inside the Colab VM (force_remount refreshes a stale mount).
drive.mount('/content/drive', force_remount=True)
# Use an absolute path under the mount point so this cell can be re-run safely:
# a relative path only resolves while the working directory is still /content.
os.chdir('/content/drive/My Drive/U-NetExample-master')#Need to change the address to where your script is
os.listdir('.')

# Initialization

In [2]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch

class TrainDataset(Dataset):
    """Paired image / mask dataset read from two directories of ``.tif`` files.

    Sample ``idx`` is loaded from ``<img_dir>/<idx+1>.tif`` and
    ``<label_dir>/<idx+1>.tif``, i.e. files are addressed by 1-based number,
    not by directory listing order.

    Args:
        img_dir: Directory containing the input images.
        label_dir: Directory containing the matching label masks.
    """

    def __init__(self, img_dir, label_dir):
        self.img_dir = img_dir
        self.label_dir = label_dir
        # Count only the .tif samples. Stray entries (for example a
        # ``.ipynb_checkpoints`` folder) would otherwise inflate __len__ and
        # make the loader request files that do not exist.
        self.name_list = [name for name in os.listdir(img_dir) if name.endswith('.tif')]

    def __len__(self):
        """Return the number of ``.tif`` images found in ``img_dir``."""
        return len(self.name_list)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(img_tensor, label_tensor)``: the image as a float tensor
            in (H, W, 3) layout, normalized to zero mean / unit std, and the
            label as a long tensor of 0/1 values in (H, W) layout.
        """
        # Load image, forcing three channels (H, W, 3).
        img_path = os.path.join(self.img_dir, f'{idx+1}.tif')
        with Image.open(img_path) as img_file:
            img = np.array(img_file.convert('RGB'), dtype='float32')

        # Per-image normalization; skip the division for a constant image so
        # the result is all zeros rather than NaN.
        img -= img.mean()
        std = img.std()
        if std > 0:
            img /= std

        # Load label as a single-channel image scaled to [0, 1].
        label_path = os.path.join(self.label_dir, f'{idx+1}.tif')
        with Image.open(label_path) as label_file:
            label = np.array(label_file.convert('L'), dtype='float32') / 255.0
        # Inverting labels as in original code: bright pixels -> 0, dark -> 1.
        label = np.where(label > 0.5, 0, 1)

        # Convert to tensors
        img_tensor = torch.from_numpy(img).float()
        label_tensor = torch.from_numpy(label).long()

        return img_tensor, label_tensor

# Build the training set from the image / label folders and wrap it in a
# shuffling loader that yields two samples per batch.
train_dataset = TrainDataset(img_dir='data/train', label_dir='data/label')
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)


In [3]:
# Initialization of the model; here we use ResNet18 as the encoder
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

def double_conv(in_channels, out_channels):
    """Build two stacked (3x3 conv -> batch norm -> ReLU) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 with stride 1 preserves the spatial size.

    Returns:
        An ``nn.Sequential`` holding the six layers in order.
    """
    stages = []
    # Same stage twice; only the number of input channels differs.
    for stage_in in (in_channels, out_channels):
        stages += [
            nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*stages)

def up_conv(in_channels, out_channels):
    """Return a 2x2, stride-2 transposed convolution (doubles height and width)."""
    upsampler = nn.ConvTranspose2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=2,
        stride=2,
    )
    return upsampler


class UNet(nn.Module):
    """U-Net style segmentation network with a ResNet-18 encoder.

    The encoder stages of a torchvision ResNet-18 produce feature maps that are
    upsampled by transposed convolutions and fused with the matching encoder
    output through channel-wise concatenation (skip connections).

    Args:
        n_classes: Number of output channels (one raw logit map per class).
        pretrained: When True (default, same as before) the encoder starts from
            ImageNet weights. Pass False to skip the download, e.g. when a
            trained checkpoint is loaded right afterwards.

    Note:
        Input height and width are expected to be multiples of 32 so that the
        upsampled maps line up with the skip connections.
    """

    def __init__(self, n_classes=2, pretrained=True):
        super().__init__()
        # torchvision >= 0.13 replaced the boolean ``pretrained`` flag with a
        # ``weights`` argument; support both so old and new installs work.
        weights_enum = getattr(models, 'ResNet18_Weights', None)
        if weights_enum is not None:
            self.encoder = models.resnet18(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        else:
            self.encoder = models.resnet18(pretrained=pretrained)
        self.encoder_layers = list(self.encoder.children())

        # Encoder stages, sliced from the ResNet-18 child modules.
        self.block1 = nn.Sequential(*self.encoder_layers[:3])
        self.block2 = nn.Sequential(*self.encoder_layers[3:5])
        self.block3 = self.encoder_layers[5]
        self.block4 = self.encoder_layers[6]
        self.block5 = self.encoder_layers[7]

        # Decoder: each step upsamples, concatenates the skip, then refines.
        self.up_conv6 = up_conv(512, 512)
        self.conv6 = double_conv(512 + 256, 512)
        self.up_conv7 = up_conv(512, 256)
        self.conv7 = double_conv(256 + 128, 256)
        self.up_conv8 = up_conv(256, 128)
        self.conv8 = double_conv(128 + 64, 128)
        self.up_conv9 = up_conv(128, 64)
        self.conv9 = double_conv(64 + 64, 64)
        # Final upsampling has no skip connection; a 1x1 conv maps to classes.
        self.up_conv10 = up_conv(64, 32)
        self.conv10 = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the network on an (N, 3, H, W) batch and return raw class logits.

        No softmax is applied here; callers apply it where needed.
        """
        # Encoder path: keep every stage output for the skip connections.
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)

        # Decoder path.
        x = self.up_conv6(block5)
        x = torch.cat([x, block4], dim=1)
        x = self.conv6(x)

        x = self.up_conv7(x)
        x = torch.cat([x, block3], dim=1)
        x = self.conv7(x)

        x = self.up_conv8(x)
        x = torch.cat([x, block2], dim=1)
        x = self.conv8(x)

        x = self.up_conv9(x)
        x = torch.cat([x, block1], dim=1)
        x = self.conv9(x)

        x = self.up_conv10(x)
        x = self.conv10(x)

        return x

# Instantiate the model with two output classes (one output channel per class).
model = UNet(n_classes=2)




In [4]:
import torch.optim as optim

def flatten(tensor):
    """Collapse every axis except the channel axis.

    The channel axis (dim 1) is moved to the front and all remaining axes are
    merged into one, e.g. ``(N, C, D, H, W) -> (C, N * D * H * W)``.
    """
    n_channels = tensor.size(1)
    # Swapping the first two axes puts channels first: (N, C, ...) -> (C, N, ...)
    channel_first = tensor.transpose(0, 1)
    # A contiguous layout is required before the axes can be merged with view().
    return channel_first.contiguous().view(n_channels, -1)

class MyCriterion(nn.Module):
    """Generalized Dice loss computed from raw logits and integer class labels.

    Each class is weighted by the inverse square of its target volume
    (see https://arxiv.org/pdf/1707.03237.pdf).
    """

    def __init__(self):
        super(MyCriterion, self).__init__()
        # Lower bound used to keep the weights and the denominator away from 0.
        self.epsilon = 1e-6

    def dice_loss(self, inputs, targets):
        """Return the generalized Dice loss as a scalar tensor.

        Args:
            inputs: Raw logits shaped ``(N, C, *spatial)``.
            targets: Class indices shaped ``(N, *spatial)`` or
                ``(N, 1, *spatial)``; any dtype that casts to integer indices.
        """
        probs = F.softmax(inputs, dim=1)

        # Drop the channel axis only when targets really carry one. An
        # unconditional squeeze(1) would also collapse a height of 1.
        if targets.dim() == inputs.dim() and targets.size(1) == 1:
            targets = targets.squeeze(1)

        # one_hot requires integer indices and appends the class axis last;
        # move that axis to position 1 so it lines up with ``probs``.
        targets = F.one_hot(targets.long(), num_classes=probs.shape[1]).float()
        class_axis = targets.dim() - 1
        targets = targets.permute(0, class_axis, *range(1, class_axis))

        prediction = flatten(probs)  # flatten all dimensions except channel/class
        target = flatten(targets)

        if prediction.size(0) == 1:
            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
            # put foreground and background voxels in separate channels
            prediction = torch.cat((prediction, 1 - prediction), dim=0)
            target = torch.cat((target, 1 - target), dim=0)

        # Per-class weights: inverse squared target volume, kept out of autograd.
        volume = target.sum(-1)
        w_l = (1 / (volume * volume).clamp(min=self.epsilon)).detach()

        intersect = (prediction * target).sum(-1) * w_l

        denominator = (prediction + target).sum(-1)
        denominator = (denominator * w_l).clamp(min=self.epsilon)

        return 1 - (2 * (intersect.sum() / denominator.sum()))

    def forward(self, pred, target):
        """Compute the loss for logits ``pred`` against labels ``target``."""
        return self.dice_loss(pred, target)

# Loss function and optimizer: Dice-based criterion, and Adam over every
# model parameter with a learning rate of 1e-4.
criterion = MyCriterion()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Pixel-wise accuracy metric used during training
def binary_accuracy(preds, targets):
    """Fraction of pixels whose thresholded class-1 probability equals the target.

    ``preds`` holds raw logits with the class axis at dim 1; a pixel is
    predicted as class 1 when its softmax probability for class 1 exceeds 0.5.
    """
    class1_prob = torch.softmax(preds, dim=1)[:, 1, :, :]
    predicted = class1_prob.gt(0.5).float()
    hits = predicted.eq(targets).float()
    return hits.sum() / hits.numel()


# Training

In [4]:
from tqdm import tqdm
num_epochs = 15

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Create the checkpoint directory once, not on every epoch.
os.makedirs('checkpoints', exist_ok=True)

# Guard against an empty loader so the per-epoch averages cannot divide by 0.
num_batches = max(len(train_loader), 1)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    epoch_acc = 0.0

    for imgs, labels in tqdm(train_loader):
        # The dataset yields NHWC images; the network expects NCHW.
        imgs = imgs.to(device).permute(0, 3, 1, 2)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accuracy is only a metric: compute it on detached outputs so no
        # autograd graph is built for it.
        with torch.no_grad():
            acc = binary_accuracy(outputs.detach(), labels)

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    epoch_loss /= num_batches
    epoch_acc /= num_batches

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    # Save a checkpoint every 5 epochs, plus the final weights as "latest".
    if (epoch+1) % 5 == 0:
        torch.save(model.state_dict(), f'./checkpoints/model_epoch{epoch+1}.pth')
    if (epoch+1) == num_epochs:
        torch.save(model.state_dict(), './checkpoints/model_latest.pth')


100%|██████████| 250/250 [00:09<00:00, 26.73it/s]


Epoch 1/15, Loss: 0.3767, Accuracy: 0.9490


100%|██████████| 250/250 [00:08<00:00, 30.99it/s]


Epoch 2/15, Loss: 0.0756, Accuracy: 0.9903


100%|██████████| 250/250 [00:08<00:00, 31.21it/s]


Epoch 3/15, Loss: 0.0374, Accuracy: 0.9922


100%|██████████| 250/250 [00:08<00:00, 30.88it/s]


Epoch 4/15, Loss: 0.0322, Accuracy: 0.9927


100%|██████████| 250/250 [00:08<00:00, 30.97it/s]


Epoch 5/15, Loss: 0.0235, Accuracy: 0.9937


100%|██████████| 250/250 [00:08<00:00, 30.93it/s]


Epoch 6/15, Loss: 0.0207, Accuracy: 0.9940


100%|██████████| 250/250 [00:08<00:00, 30.88it/s]


Epoch 7/15, Loss: 0.0190, Accuracy: 0.9945


100%|██████████| 250/250 [00:08<00:00, 30.87it/s]


Epoch 8/15, Loss: 0.0210, Accuracy: 0.9945


100%|██████████| 250/250 [00:08<00:00, 31.01it/s]


Epoch 9/15, Loss: 0.0174, Accuracy: 0.9948


100%|██████████| 250/250 [00:08<00:00, 31.13it/s]


Epoch 10/15, Loss: 0.0147, Accuracy: 0.9956


100%|██████████| 250/250 [00:08<00:00, 30.71it/s]


Epoch 11/15, Loss: 0.0178, Accuracy: 0.9954


100%|██████████| 250/250 [00:08<00:00, 31.01it/s]


Epoch 12/15, Loss: 0.0129, Accuracy: 0.9961


100%|██████████| 250/250 [00:08<00:00, 30.79it/s]


Epoch 13/15, Loss: 0.0122, Accuracy: 0.9963


100%|██████████| 250/250 [00:08<00:00, 30.99it/s]


Epoch 14/15, Loss: 0.0111, Accuracy: 0.9965


100%|██████████| 250/250 [00:08<00:00, 30.88it/s]


Epoch 15/15, Loss: 0.0107, Accuracy: 0.9967


# Validation

In [5]:
class valDataset(Dataset):
    """Image-only dataset for inference on numbered ``.tif`` files.

    Sample ``idx`` is loaded from ``<test_dir>/<idx + start_index>.tif``.

    Args:
        test_dir: Directory containing the images.
        start_index: Number of the first file; defaults to 501, the value that
            was previously hard-coded.
    """

    def __init__(self, test_dir, start_index=501):
        self.test_dir = test_dir
        self.start_index = start_index
        # Count only the .tif samples so stray entries (for example a
        # ``.ipynb_checkpoints`` folder) do not inflate __len__.
        self.name_list = [name for name in os.listdir(test_dir) if name.endswith('.tif')]

    def __len__(self):
        """Return the number of ``.tif`` images found in ``test_dir``."""
        return len(self.name_list)

    def __getitem__(self, idx):
        """Return image ``idx`` as a normalized float tensor in (H, W, 3) layout."""
        img_path = os.path.join(self.test_dir, f'{idx+self.start_index}.tif')
        # Force three channels (H, W, 3).
        with Image.open(img_path) as img_file:
            img = np.array(img_file.convert('RGB'), dtype='float32')

        # Per-image normalization; skip the division for a constant image so
        # the result is all zeros rather than NaN.
        img -= img.mean()
        std = img.std()
        if std > 0:
            img /= std

        img_tensor = torch.from_numpy(img).float()
        return img_tensor

# Load test data
val_dataset = valDataset('data/val')
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Load the trained model
model = UNet()
model.load_state_dict(torch.load('./checkpoints/model_latest.pth', map_location=device))
model.to(device)
model.eval()

# Perform predictions
from torchvision.utils import save_image
os.makedirs('data/val_prediction', exist_ok=True)
with torch.no_grad():
    for idx, img in enumerate(tqdm(val_loader)):
        # The dataset yields NHWC images; the network expects NCHW.
        img = img.to(device).permute(0, 3, 1, 2)
        probs = torch.softmax(model(img), dim=1)

        # Keep only the class-1 channel (the slice preserves the channel axis)
        # and threshold it into a 0/1 mask. Saving both softmax channels would
        # write a two-channel image instead of a mask.
        mask = (probs[:, 1:2] > 0.5).float().cpu()

        # save_image expects values in [0, 1] and scales to 0-255 itself,
        # so the mask must not be multiplied by 255 beforehand.
        save_image(mask, f'data/val_prediction/{idx}.png')
        if (idx+1) % 100 == 0:
            print(f'Done: {idx+1}/{len(val_loader)} images')


100%|██████████| 20/20 [00:00<00:00, 60.58it/s]


# Testing

In [5]:
import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm

# Define input and output directories
In_dir = './data/test/'
Out_dir = './data/prediction/'
os.makedirs(Out_dir, exist_ok=True)
model_name = './checkpoints/model_latest.pth'
# True: zero-mean / unit-std using whole-image statistics; False: scale to [0, 1].
normalization = True # Set this based on your needs

# Load the model
model = UNet(n_classes=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load(model_name, map_location=device))
model.to(device)
model.eval()
print("Network Loaded")

# Set parameters
img_size = 1230 #1146  (nominal image side, used for the summary print below)
to_size = 512          # side of the square tile fed to the network
step = 200             # stride between neighbouring tiles


def tile_starts(length, tile=to_size, stride=step):
    """Return the start offsets of tiles covering ``length`` pixels.

    Offsets advance by ``stride``; a final tile is aligned to the far edge so
    the whole axis is covered.

    Raises:
        ValueError: If ``length`` is smaller than ``tile``.
    """
    if length < tile:
        raise ValueError(f'image side {length} is smaller than the tile size {tile}')
    last = length - tile
    return list(range(0, last, stride)) + [last]


x_list = tile_starts(img_size)
y_list = tile_starts(img_size)
n = len(x_list) * len(y_list)
print('Crop to N images:', n)

# Get list of images (skip sub-directories such as .ipynb_checkpoints)
nameListTest = [
    filename for filename in os.listdir(In_dir)
    if os.path.isfile(os.path.join(In_dir, filename))
]
# Swap only the real extension for .png (a plain str.replace would also touch
# '.tif' inside the stem and turn '.tiff' into '.pngf').
output_list = [os.path.splitext(filename)[0] + '.png' for filename in nameListTest]
for i in tqdm(range(len(nameListTest) - 1, -1, -1)):
    # Load the image as float32 with three channels (H, W, 3)
    img_path = os.path.join(In_dir, nameListTest[i])
    with Image.open(img_path) as img_file:
        img_np = np.array(img_file.convert('RGB'), dtype='float32')

    # Tile positions follow the actual image size rather than a fixed constant.
    height, width = img_np.shape[:2]
    row_starts = tile_starts(height)
    col_starts = tile_starts(width)

    # Normalize once per image instead of recomputing the statistics per tile.
    if normalization:
        img_std = img_np.std()
        if img_std == 0:
            img_std = 1.0  # constant image: avoid dividing by zero
        img_norm = (img_np - img_np.mean()) / img_std
    else:
        img_norm = img_np / 255.0

    # Sum of binary tile votes per pixel, and how many tiles covered each pixel.
    img_pred = torch.zeros((height, width), dtype=torch.float32, device=device)
    img_count = torch.zeros((height, width), dtype=torch.float32, device=device)

    with torch.no_grad():
        for x_mark in row_starts:
            for y_mark in col_starts:
                tile_np = np.ascontiguousarray(
                    img_norm[x_mark:x_mark + to_size, y_mark:y_mark + to_size]
                )
                # (H, W, 3) -> (1, 3, H, W)
                tile_tensor = torch.from_numpy(tile_np).permute(2, 0, 1).unsqueeze(0).to(device)
                # Class-1 probability, thresholded to a binary vote.
                prob = torch.softmax(model(tile_tensor), dim=1)[0, 1]
                img_pred[x_mark:x_mark + to_size, y_mark:y_mark + to_size] += (prob > 0.5).float()
                img_count[x_mark:x_mark + to_size, y_mark:y_mark + to_size] += 1

    # Average the overlapping votes (every pixel is covered by at least one tile)
    img_pred /= img_count
    img_pred_np = img_pred.cpu().numpy()

    # Stretch to the full 0-1 range as before; a constant prediction is left
    # unchanged, since min-max scaling would divide by zero.
    pred_min = img_pred_np.min()
    pred_max = img_pred_np.max()
    if pred_max > pred_min:
        img_pred_np = (img_pred_np - pred_min) / (pred_max - pred_min)
    img_pred_np = (img_pred_np*255).astype(np.uint8)

    # Save the image
    img_out = Image.fromarray(img_pred_np)
    out_path = os.path.join(Out_dir, output_list[i])
    img_out.save(out_path)


Network Loaded
Crop to N images: 25


100%|██████████| 306/306 [01:17<00:00,  3.97it/s]
