In [124]:
import pandas as pd
import numpy as np
import os
import json
import cv2
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as utils
import torchvision.transforms.functional as tf
import torch.nn.functional as f
import torchvision
import torchmetrics as metrics
from lightning.pytorch.loggers import TensorBoardLogger
from PIL import Image, ImageColor
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchinfo import summary
from torch.utils.data import DataLoader, Dataset, random_split
from pathlib import Path

In [112]:
colors = [
    (255, 0, 0),    
    (0, 255, 0),    
    (0, 0, 255),    
    (255, 255, 0),  
    (0, 255, 255),  
    (255, 0, 255),  
    (128, 0, 0),    
    (0, 128, 0),    
    (0, 0, 128),    
    (128, 128, 0),  
    (0, 128, 128),  
    (128, 0, 128),  
    (128, 128, 128),
    (192, 192, 192),
    (255, 165, 0),  
    (210, 105, 30), 
    (255, 69, 0),   
    (0, 128, 128),  
    (139, 0, 0),    
    (0, 139, 0),    
    (0, 0, 139),    
    (255, 215, 0),  
    (0, 255, 0),    
    (0, 0, 255),    
    (255, 20, 147), 
    (0, 128, 0),    
    (0, 0, 128),    
    (255, 140, 0),  
    (128, 0, 128),  
    (128, 128, 0)   
]

# Data preps(System)

In [119]:
dataDir = Path('EBHI-SEG')
# dataDict = {'imagePath' : [],
#             'labelPath' : [],
#             'className' : []}

classNames = [classes for classes in os.listdir(dataDir) if os.path.isdir(os.path.join(dataDir, classes))]
imageList = []
labelList = []
classList = []

for className in classNames:
    classDir = os.path.join(dataDir, className)
    imageDir = os.path.join(classDir, 'image')
    labelDir = os.path.join(classDir, 'label')
    imagePath = [file for file in os.listdir(imageDir)]
    labelPath = [file for file in os.listdir(labelDir)]
    # for img, lbl in zip(imagePath, labelPath):
    #     if str(img).lower() == str(lbl).lower():
    #         imageList.append(os.path.join(imageDir, img))
    #         labelList.append(os.path.join(labelDir, lbl))
    #         classList.append(className)
    for img in imagePath:
        if img in labelPath:
            imageList.append(os.path.join(imageDir, img))
            labelList.append(os.path.join(labelDir, img))
            classList.append(className)

In [120]:
df = pd.DataFrame({'image' : imageList, 'label' : labelList, 'class' : classList})

In [121]:
df

Unnamed: 0,image,label,class
0,EBHI-SEG/Low-grade IN/image/GTXC2014165-2-400-...,EBHI-SEG/Low-grade IN/label/GTXC2014165-2-400-...,Low-grade IN
1,EBHI-SEG/Low-grade IN/image/GTXC2015407-1-400-...,EBHI-SEG/Low-grade IN/label/GTXC2015407-1-400-...,Low-grade IN
2,EBHI-SEG/Low-grade IN/image/GTxc2012481-1-400-...,EBHI-SEG/Low-grade IN/label/GTxc2012481-1-400-...,Low-grade IN
3,EBHI-SEG/Low-grade IN/image/GTxc2014132-1-400-...,EBHI-SEG/Low-grade IN/label/GTxc2014132-1-400-...,Low-grade IN
4,EBHI-SEG/Low-grade IN/image/GTxc2012967-1-400-...,EBHI-SEG/Low-grade IN/label/GTxc2012967-1-400-...,Low-grade IN
...,...,...,...
2221,EBHI-SEG/Serrated adenoma/image/GTDC2102188-2-...,EBHI-SEG/Serrated adenoma/label/GTDC2102188-2-...,Serrated adenoma
2222,EBHI-SEG/Serrated adenoma/image/GTXC2014129-2-...,EBHI-SEG/Serrated adenoma/label/GTXC2014129-2-...,Serrated adenoma
2223,EBHI-SEG/Serrated adenoma/image/GT2016855-1-40...,EBHI-SEG/Serrated adenoma/label/GT2016855-1-40...,Serrated adenoma
2224,EBHI-SEG/Serrated adenoma/image/GTXC2014129-2-...,EBHI-SEG/Serrated adenoma/label/GTXC2014129-2-...,Serrated adenoma


In [48]:
for folder in dataDir.iterdir():
    if folder.is_dir():
        for folder2 in folder.iterdir():
            if folder2.is_dir():
                print(folder.name)
                print(folder2.name)
                print()

Low-grade IN
label
637
Low-grade IN
image
639
Adenocarcinoma
label
795
Adenocarcinoma
image
795
High-grade IN
label
186
High-grade IN
image
186
Normal
label
76
Normal
image
76
Polyp
label
474
Polyp
image
474
Serrated adenoma
label
58
Serrated adenoma
image
58


In [131]:
df.to_csv('imageLabel.csv', index=False)

# Data preps

In [145]:
df = pd.read_csv('imageLabel.csv')
df.drop(df.columns[0], axis=1, inplace=True)
df

Unnamed: 0,image,label,class
0,EBHI-SEG/Low-grade IN/image/GTXC2014165-2-400-...,EBHI-SEG/Low-grade IN/label/GTXC2014165-2-400-...,Low-grade IN
1,EBHI-SEG/Low-grade IN/image/GTXC2015407-1-400-...,EBHI-SEG/Low-grade IN/label/GTXC2015407-1-400-...,Low-grade IN
2,EBHI-SEG/Low-grade IN/image/GTxc2012481-1-400-...,EBHI-SEG/Low-grade IN/label/GTxc2012481-1-400-...,Low-grade IN
3,EBHI-SEG/Low-grade IN/image/GTxc2014132-1-400-...,EBHI-SEG/Low-grade IN/label/GTxc2014132-1-400-...,Low-grade IN
4,EBHI-SEG/Low-grade IN/image/GTxc2012967-1-400-...,EBHI-SEG/Low-grade IN/label/GTxc2012967-1-400-...,Low-grade IN
...,...,...,...
2221,EBHI-SEG/Serrated adenoma/image/GTDC2102188-2-...,EBHI-SEG/Serrated adenoma/label/GTDC2102188-2-...,Serrated adenoma
2222,EBHI-SEG/Serrated adenoma/image/GTXC2014129-2-...,EBHI-SEG/Serrated adenoma/label/GTXC2014129-2-...,Serrated adenoma
2223,EBHI-SEG/Serrated adenoma/image/GT2016855-1-40...,EBHI-SEG/Serrated adenoma/label/GT2016855-1-40...,Serrated adenoma
2224,EBHI-SEG/Serrated adenoma/image/GTXC2014129-2-...,EBHI-SEG/Serrated adenoma/label/GTXC2014129-2-...,Serrated adenoma


In [180]:
classes = {
    'Normal' : 1,
    'Polyp' : 2,
    'Low-grade IN' : 4,
    'High-grade IN' : 5,
    'Serrated adenoma' : 3,
    'Adenocarcinoma' : 6
}

In [182]:
# img = Image.open(df.iloc[0,1]).convert('L')
# # print(img.shape)
# # print(img)
# # print(np.unique(img))
# imgArr = np.array(img)
# # print(np.unique(imgArr))
# binary = (imgArr > 0).astype(np.uint8) * 255
# print(np.unique(binary))
# # img.show()
# # print(imgArr.shape)

# img = cv2.imread(df.iloc[0,1])
# img = np.array(img)
# binary = (img > 0).astype(np.uint8) * 255
# plt.imshow(binary)

colorMap = np.array(colors, dtype=np.uint8)
img = Image.open(df.iloc[0,0])
lbl = np.array(Image.open(df.iloc[0,1]))
lblBin = (lbl > 0).astype(np.uint8) * classes[df.iloc[0,2]]
# lblBin = Image.fromarray(lblBin).convert('RGB')
coloredBin = Image.fromarray(colorMap[lblBin]).convert('RGB')
overlay = Image.blend(img, coloredBin, alpha=.6)
overlay.show()

In [178]:
df['class'].unique()

array(['Low-grade IN', 'Adenocarcinoma', 'High-grade IN', 'Normal',
       'Polyp', 'Serrated adenoma'], dtype=object)

In [None]:
class customDataset(Dataset):
    def __init__(self, imageList, labelList, classList, classDict, transforms=None):
        assert len(imageList) == len(labelList) and len(imageList) == len(classList) and len(labelList) == len(classList), '3 of the list are not the same length'
        self.imageList = imageList
        self.labelList = labelList
        self.classList = classList
        self.classDict = classDict
        self.transforms = transforms
    def __len__(self):
        return len(self.imageList)
    def __getitem__(self, index):
        image = np.array(Image.open(self.imageList[index]))
        mask = np.array(Image.open(self.labelList[index]))
        mask = (mask > 0).astype(np.uint8) * self.classDict[self.classList[index]]
        if self.transforms:
            transformed = self.transforms(image=image, mask=mask)
            imgAug = transformed['image'].contiguous()
            maskAug = transformed['mask'].contiguous()
            return imgAug, maskAug
        image = torch.as_tensor(image).float().contiguous()
        mask = torch.as_tensor(mask).long().contiguous()
        return image, mask

class dataModule(pl.LightningDataModule):
    def __init__(self, imageList, labelList, classList, classDict, batchSize, trainSize, valSize):
        super().__init__()
        self.imageList = imageList
        self.labelList = labelList
        self.classList = classList
        self.batchSize = batchSize
        assert (trainSize+valSize) >= 1, 'sum of train size and validation size must be equal to 1'
        self.trainSize = trainSize
        self.valSize = valSize
        self.classDict = classDict
    def prepare_data(self):
        #No data download, so we pass this
        pass
    def _getTransform(self,train=True):
        if train:
            return A.Compose([
                A.OneOf([
                    A.VerticalFlip(),
                    A.HorizontalFlip(),
                    A.RandomRotate90(),
                    A.GaussianBlur(),
                    A.RandomContrast()
                ], p=.5),
                ToTensorV2()
            ])
        return A.Compose([
            ToTensorV2()
        ])
    def setup(self, stage: str):
        trainTransform = self._getTransform()
        valTransform = self._getTransform(False)
        # trainDataset = customDataset(imageList=self.imageList, labelList=self.labelList, classList=self.classList, classDict=self.classDict, transforms=)