In [1]:
import torch
import torch.nn as nn
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import torch.optim as optim


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class conv_block(nn.Module):
    """Two 3x3 convolutions in a row, each followed by batch norm and ReLU.

    Both convolutions use stride 1 and padding 1, so height and width are
    unchanged; only the channel count goes from ``ch_in`` to ``ch_out``.

    NOTE: a class with the same name is defined again in a later cell of this
    notebook; whichever definition is executed last is the one in effect.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels of both convolutions.
    """

    def __init__(self,ch_in,ch_out):
        super().__init__()
        stages = []
        # First stage maps ch_in -> ch_out, second stage keeps ch_out.
        for width_in in (ch_in, ch_out):
            stages.append(
                nn.Conv2d(width_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
            )
            stages.append(nn.BatchNorm2d(ch_out))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self,x):
        """Run ``x`` through the conv/BN/ReLU stack and return the result."""
        return self.conv(x)
    

class up_conv(nn.Module):
    """Upsample by a factor of 2, then apply 3x3 conv, batch norm and ReLU.

    NOTE: a class with the same name is defined again in a later cell of this
    notebook (using a transposed convolution instead); whichever definition is
    executed last is the one in effect.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels of the convolution.
    """

    def __init__(self,ch_in,ch_out):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        refine = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        normalize = nn.BatchNorm2d(ch_out)
        activate = nn.ReLU(inplace=True)
        self.up = nn.Sequential(enlarge, refine, normalize, activate)

    def forward(self,x):
        """Return the upsampled and convolved version of ``x``."""
        return self.up(x)

In [4]:
import torch
import torch.nn as nn

class conv_block(nn.Module):
    """Double 3x3 convolution block: (Conv2d -> BatchNorm2d -> ReLU) twice.

    The convolutions use stride 1 and padding 1, so the spatial size of the
    input is kept and only the channel count changes to ``ch_out``.

    Args:
        ch_in: number of input channels.
        ch_out: number of channels produced by both convolutions.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()

        def conv3x3(width_in):
            # Size-preserving 3x3 convolution that always emits ch_out channels.
            return nn.Conv2d(width_in, ch_out, kernel_size=3, stride=1, padding=1)

        self.conv = nn.Sequential(
            conv3x3(ch_in),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            conv3x3(ch_out),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the two conv/BN/ReLU stages to ``x``."""
        return self.conv(x)

class up_conv(nn.Module):
    """Learned 2x upsampling: ConvTranspose2d (kernel 2, stride 2) -> BN -> ReLU.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(ch_in, ch_out, kernel_size=2, stride=2),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` after the transposed convolution, batch norm and ReLU."""
        return self.up(x)

class U_Net(nn.Module):
    """U-Net style encoder/decoder built from ``conv_block`` and ``up_conv``.

    The encoder applies five ``conv_block`` stages (64, 128, 256, 512, 1024
    channels) with 2x2 max pooling between them. The decoder upsamples four
    times; after each upsampling the result is padded to the size of the
    matching encoder feature map, concatenated with it along dim 1 and passed
    through another ``conv_block``. A final 1x1 convolution maps the 64
    channels to ``output_ch``. No activation is applied to the output.

    Args:
        img_ch: number of channels of the input image.
        output_ch: number of channels of the returned map.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder stages.
        self.Conv1 = conv_block(img_ch, 64)
        self.Conv2 = conv_block(64, 128)
        self.Conv3 = conv_block(128, 256)
        self.Conv4 = conv_block(256, 512)
        self.Conv5 = conv_block(512, 1024)

        # Decoder stages: each Up* halves the channels, and each Up_conv*
        # consumes the concatenation of the skip tensor and the upsampled one.
        self.Up5 = up_conv(1024, 512)
        self.Up_conv5 = conv_block(1024, 512)

        self.Up4 = up_conv(512, 256)
        self.Up_conv4 = conv_block(512, 256)

        self.Up3 = up_conv(256, 128)
        self.Up_conv3 = conv_block(256, 128)

        self.Up2 = up_conv(128, 64)
        self.Up_conv2 = conv_block(128, 64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run the encoder and decoder on ``x`` and return the 1x1-conv output."""
        pool = self.Maxpool

        # Contracting path: keep every stage's output for the skip connections.
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(pool(enc1))
        enc3 = self.Conv3(pool(enc2))
        enc4 = self.Conv4(pool(enc3))
        bottom = self.Conv5(pool(enc4))

        # Expanding path: upsample, merge with the skip tensor, then convolve.
        dec = self.Up_conv5(self.pad_and_concat(enc4, self.Up5(bottom)))
        dec = self.Up_conv4(self.pad_and_concat(enc3, self.Up4(dec)))
        dec = self.Up_conv3(self.pad_and_concat(enc2, self.Up3(dec)))
        dec = self.Up_conv2(self.pad_and_concat(enc1, self.Up2(dec)))

        return self.Conv_1x1(dec)

    def pad_and_concat(self, x1, x2):
        """Pad ``x2`` to the height/width of ``x1`` and concatenate on dim 1.

        The size difference in dims 2 and 3 is split between the two sides,
        with the extra unit (for odd differences) going to the bottom/right.
        ``x1`` comes first in the concatenated result.
        """
        gap_rows = x1.shape[2] - x2.shape[2]
        gap_cols = x1.shape[3] - x2.shape[3]

        top, left = gap_rows // 2, gap_cols // 2
        # F.pad takes the last dimension first: (left, right, top, bottom).
        padded = nn.functional.pad(
            x2, (left, gap_cols - left, top, gap_rows - top)
        )
        return torch.cat((x1, padded), dim=1)

In [8]:
class CustomImageDataset(Dataset):
    """Dataset of (image, mask) pairs read from two parallel directories.

    Files are selected by extension and then sorted by name, and item ``i`` is
    the ``i``-th image together with the ``i``-th mask. This assumes that image
    and mask file names sort into the same order (e.g. identical or
    identically numbered names) -- TODO confirm for the data in use.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the masks.
        transform: optional callable applied to the image and, separately, to
            the mask (both are PIL images in mode ``'L'`` at that point).

    Raises:
        ValueError: if either directory has no matching files, or if the two
            directories contain a different number of matching files.
    """

    # File name suffixes that are treated as images/masks.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # os.listdir returns entries in arbitrary order, so two separate
        # listings are not guaranteed to line up. Sorting both makes the
        # index -> (image, mask) pairing deterministic.
        self.image_names = sorted(
            f for f in os.listdir(image_dir) if f.endswith(self.IMAGE_EXTENSIONS)
        )
        self.mask_names = sorted(
            f for f in os.listdir(mask_dir) if f.endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

        # Debug print statements
        print(f"Found {len(self.image_names)} images in {image_dir}")
        print(f"Found {len(self.mask_names)} masks in {mask_dir}")

        if len(self.image_names) == 0:
            raise ValueError(f"No images found in {image_dir}. Please check the directory path and file extensions.")
        if len(self.mask_names) == 0:
            raise ValueError(f"No masks found in {mask_dir}. Please check the directory path and file extensions.")
        if len(self.image_names) != len(self.mask_names):
            print(f"Number of images: {len(self.image_names)}")
            print(f"Number of masks: {len(self.mask_names)}")
            raise ValueError("Number of images and masks do not match.")

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load pair ``idx`` as grayscale images and apply ``transform`` to each."""
        img_name = os.path.join(self.image_dir, self.image_names[idx])
        image = Image.open(img_name).convert('L')  # Convert to grayscale

        mask_name = os.path.join(self.mask_dir, self.mask_names[idx])
        mask = Image.open(mask_name).convert('L')  # Convert to grayscale

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

def min_max_normalize(tensor):
    """Linearly rescale ``tensor`` so its minimum maps to 0 and its maximum to 1.

    A constant tensor (max == min, e.g. an all-background mask) has no range
    to stretch; it is returned as all zeros instead of the NaNs that a 0/0
    division would produce.

    Args:
        tensor: non-empty torch tensor.

    Returns:
        A new tensor of the same shape with values in [0, 1].
    """
    min_val = tensor.min()
    max_val = tensor.max()
    span = max_val - min_val
    # Replace a zero range by 1 so the division below yields 0 rather than NaN.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    return (tensor - min_val) / safe_span

# Preprocessing applied to every image and every mask: resize, convert the PIL
# image to a tensor, then min-max normalize.
transform = transforms.Compose(
    [
        transforms.Resize((696, 520)),  # size tuple handed to Resize: (696, 520)
        transforms.ToTensor(),
        min_max_normalize,
    ]
)

# Training split: image folder, mask folder, dataset and shuffled loader.
image_dir = 'Data/train/01'
mask_dir = 'Data/train/01_MASKS'
train_dataset = CustomImageDataset(image_dir, mask_dir, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)

# Validation split (the test folders are used for validation); not shuffled.
val_image_dir = 'Data/test/02'
val_mask_dir = 'Data/test/02_MASKS'
val_dataset = CustomImageDataset(val_image_dir, val_mask_dir, transform=transform)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8)


Found 84 images in Data/train/01
Found 84 masks in Data/train/01_MASKS
Found 84 images in Data/test/02
Found 84 masks in Data/test/02_MASKS


In [6]:
def get_accuracy(SR,GT,threshold=0.5):
    """Return the fraction of elements where the binarized prediction matches GT.

    Args:
        SR: segmentation result; elements ``> threshold`` count as positive.
        GT: ground truth; elements equal to ``GT.max()`` count as positive.
            Note that an all-zero GT therefore counts every element as positive.
        threshold: cut-off applied to ``SR``.

    Returns:
        Accuracy as a Python float; 0.0 for an empty ``SR``.
    """
    # numel() works for any rank, unlike multiplying exactly four sizes.
    tensor_size = SR.numel()
    if tensor_size == 0:
        return 0.0

    SR = SR > threshold
    GT = GT == torch.max(GT)
    corr = torch.sum(SR==GT)
    acc = float(corr)/float(tensor_size)

    return acc

def get_sensitivity(SR,GT,threshold=0.5):
    """Return sensitivity (recall): TP / (TP + FN), with 1e-6 added to the denominator.

    Args:
        SR: segmentation result; elements ``> threshold`` count as positive.
        GT: ground truth; elements equal to ``GT.max()`` count as positive.
            Note that an all-zero GT therefore counts every element as positive.
        threshold: cut-off applied to ``SR``.

    Returns:
        Sensitivity as a Python float.
    """
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TP : True Positive
    # FN : False Negative
    # Combine the masks with logical operators. Adding two bool tensors gives
    # another bool tensor, so the previous ``(a + b) == 2`` test was never true
    # and the metric was always 0.
    TP = SR & GT
    FN = (~SR) & GT

    SE = float(torch.sum(TP))/(float(torch.sum(TP) + torch.sum(FN)) + 1e-6)

    return SE

def get_specificity(SR,GT,threshold=0.5):
    """Return specificity: TN / (TN + FP), with 1e-6 added to the denominator.

    Args:
        SR: segmentation result; elements ``> threshold`` count as positive.
        GT: ground truth; elements equal to ``GT.max()`` count as positive.
            Note that an all-zero GT therefore counts every element as positive.
        threshold: cut-off applied to ``SR``.

    Returns:
        Specificity as a Python float.
    """
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TN : True Negative
    # FP : False Positive
    # Combine the masks with logical operators. Adding two bool tensors gives
    # another bool tensor, so the previous ``(a + b) == 2`` test was never true
    # and the metric was always 0.
    TN = (~SR) & (~GT)
    FP = SR & (~GT)

    SP = float(torch.sum(TN))/(float(torch.sum(TN) + torch.sum(FP)) + 1e-6)

    return SP

def get_precision(SR,GT,threshold=0.5):
    """Return precision: TP / (TP + FP), with 1e-6 added to the denominator.

    Args:
        SR: segmentation result; elements ``> threshold`` count as positive.
        GT: ground truth; elements equal to ``GT.max()`` count as positive.
            Note that an all-zero GT therefore counts every element as positive.
        threshold: cut-off applied to ``SR``.

    Returns:
        Precision as a Python float.
    """
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TP : True Positive
    # FP : False Positive
    # Combine the masks with logical operators. Adding two bool tensors gives
    # another bool tensor, so the previous ``(a + b) == 2`` test was never true
    # and the metric was always 0.
    TP = SR & GT
    FP = SR & (~GT)

    PC = float(torch.sum(TP))/(float(torch.sum(TP) + torch.sum(FP)) + 1e-6)

    return PC

def get_F1(SR,GT,threshold=0.5):
    """Return the F1 score built from ``get_sensitivity`` and ``get_precision``.

    Computed as ``2 * recall * precision / (recall + precision + 1e-6)``, where
    both inputs are evaluated with the same ``threshold``.
    """
    recall = get_sensitivity(SR, GT, threshold=threshold)
    precision = get_precision(SR, GT, threshold=threshold)

    numerator = 2 * recall * precision
    denominator = recall + precision + 1e-6  # epsilon avoids 0/0 when both are 0
    return numerator / denominator

def get_JS(SR,GT,threshold=0.5):
    """Return the Jaccard similarity: |SR and GT| / (|SR or GT| + 1e-6).

    Args:
        SR: segmentation result; elements ``> threshold`` count as positive.
        GT: ground truth; elements equal to ``GT.max()`` count as positive.
            Note that an all-zero GT therefore counts every element as positive.
        threshold: cut-off applied to ``SR``.

    Returns:
        Jaccard similarity as a Python float.
    """
    # JS : Jaccard similarity
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # Adding two bool tensors gives another bool tensor, so the previous
    # ``(SR + GT) == 2`` intersection was always empty; use logical AND / OR.
    Inter = torch.sum(SR & GT)
    Union = torch.sum(SR | GT)

    JS = float(Inter)/(float(Union) + 1e-6)

    return JS

def get_DC(SR,GT,threshold=0.5):
    """Return the Dice coefficient: 2 * |SR and GT| / (|SR| + |GT| + 1e-6).

    Args:
        SR: segmentation result; elements ``> threshold`` count as positive.
        GT: ground truth; elements equal to ``GT.max()`` count as positive.
            Note that an all-zero GT therefore counts every element as positive.
        threshold: cut-off applied to ``SR``.

    Returns:
        Dice coefficient as a Python float.
    """
    # DC : Dice Coefficient
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # Adding two bool tensors gives another bool tensor, so the previous
    # ``(SR + GT) == 2`` intersection was always empty; use logical AND.
    Inter = torch.sum(SR & GT)
    DC = float(2*Inter)/(float(torch.sum(SR)+torch.sum(GT)) + 1e-6)

    return DC

In [9]:
# Model, optimizer and loss for single-channel input and single-channel output.
unet = U_Net(img_ch=1,output_ch=1)
lr=1e-4
optimizer = optim.Adam(unet.parameters(), lr=lr)
criterion = torch.nn.BCELoss()  # expects probabilities, hence the sigmoid below
num_epochs = 1

for epoch in range(num_epochs):
    epoch_loss = 0  # sum of the per-batch losses of this epoch
    acc = 0.    # Accuracy
    SE = 0.     # Sensitivity (Recall)
    SP = 0.     # Specificity
    PC = 0.     # Precision
    F1 = 0.     # F1 Score
    JS = 0.     # Jaccard Similarity
    DC = 0.     # Dice Coefficient
    length = 0  # number of samples seen this epoch

    for i, (images, GT) in enumerate(train_loader):
        # GT : Ground Truth
        # images = images.to(device)
        # GT = GT.to(device)
        # SR : Segmentation Result (raw network output, i.e. logits)
        SR = unet(images)
        SR_probs = F.sigmoid(SR)
        SR_flat = SR_probs.view(SR_probs.size(0),-1)
        GT_flat = GT.view(GT.size(0),-1)
        loss = criterion(SR_flat,GT_flat)
        epoch_loss += loss.item()

        # Backprop + optimize
        unet.zero_grad()
        loss.backward()
        optimizer.step()

        # The metric helpers threshold their first argument at 0.5, so they
        # must receive probabilities rather than the raw logits in SR.
        # Each helper returns one value per batch; weight it by the batch size
        # so that dividing by the sample count below yields a per-sample mean
        # even when the last batch is smaller.
        batch_count = images.size(0)
        probs = SR_probs.detach()
        acc += get_accuracy(probs,GT) * batch_count
        SE += get_sensitivity(probs,GT) * batch_count
        SP += get_specificity(probs,GT) * batch_count
        PC += get_precision(probs,GT) * batch_count
        F1 += get_F1(probs,GT) * batch_count
        JS += get_JS(probs,GT) * batch_count
        DC += get_DC(probs,GT) * batch_count
        length += batch_count

    # Average into the same names that are printed below; guard against an
    # empty loader so the division cannot fail.
    length = max(length, 1)
    acc = acc/length
    SE = SE/length
    SP = SP/length
    PC = PC/length
    F1 = F1/length
    JS = JS/length
    DC = DC/length

    # Print the log info
    print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
        epoch+1, num_epochs,
        epoch_loss,
        acc,SE,SP,PC,F1,JS,DC))

    # Decay learning rate (disabled; `num_epochs_decay` is not defined in the
    # visible cells and would have to be set before enabling this).
    # if (epoch+1) > (num_epochs - num_epochs_decay):
    #     lr -= (lr / float(num_epochs_decay))
    #     for param_group in optimizer.param_groups:
    #         param_group['lr'] = lr
    #     print ('Decay learning rate to lr: {}.'.format(lr))


: 