In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.io import read_image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
from torchvision.transforms import v2
import matplotlib.pyplot as plt
from torchsummary import summary
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [None]:
def diceLoss(input,target,numOfClasses,eps = 1e-07):
  input = F.softmax(input,dim=1)

  targetOneHot = F.one_hot(target,num_classes = numOfClasses)
  targetOneHot = targetOneHot.permute(0,3,2,1).float()

  intersection = torch.sum(input * targetOneHot, (0,2,3))
  union = torch.sum(input + targetOneHot, (0,2,3))

  dice = (2 * intersection + eps) / (union + eps)
  dice = 1 - dice

  return dice.mean()

In [None]:
colorMap = {
    (60,16,152): 0,
    (132, 41, 246): 1,
    (110, 193, 228): 2,
    (254, 221, 58): 3,
    (226, 169, 41): 4,
    (155, 155, 155): 5
}

indexToColor = torch.tensor([
    [60,16,152],
    [132, 41, 246],
    [110, 193, 228],
    [254, 221, 58],
    [226, 169, 41],
    [155, 155, 155]
])
indexToColor = indexToColor.to(device)

In [None]:
def RgbToLabel(mask):
    mask = np.array(mask)
    labeledMask = np.zeros((mask.shape[1], mask.shape[2]))

    for rgb, id in colorMap.items():
        matches = np.all(mask == np.array(rgb)[:, None, None], axis=0)
        labeledMask[matches] = id
    return torch.from_numpy(labeledMask).long()

In [None]:
class customDataset(Dataset):
    def __init__(self, dataFilePath, trainDataSize, mode, imgTransform = None, maskTransform = None):
        self.imgPaths = []
        self.maskPaths = []
        self.mode = mode
        self.imgTransform = imgTransform
        self.maskTransform = maskTransform
        self.trainDataSize = trainDataSize

        for i in os.listdir(dataFilePath):
          subDir = os.path.join(dataFilePath,i)
          if os.path.isdir(subDir):
            for j in os.listdir(subDir):
              subDir2 = os.path.join(subDir,j)
              if subDir2.endswith('images'):
                for k in os.listdir(subDir2):
                  self.imgPaths.append(os.path.join(subDir2,k))


        for i in os.listdir(dataFilePath):
          subDir = os.path.join(dataFilePath,i)
          if os.path.isdir(subDir):
            for j in os.listdir(subDir):
              subDir2 = os.path.join(subDir,j)
              if subDir2.endswith('masks'):
                for k in os.listdir(subDir2):
                  self.maskPaths.append(os.path.join(subDir2,k))

        self.trainImgPaths = self.imgPaths[:self.trainDataSize]
        self.trainMaskPaths = self.maskPaths[:self.trainDataSize]
        self.testImgPaths = self.imgPaths[self.trainDataSize:]
        self.testMaskPaths = self.maskPaths[self.trainDataSize:]

    def rgbToLabel(self,mask):
        self.mask = np.array(mask)
        self.labeledMask = np.zeros((self.mask.shape[1], self.mask.shape[2]))

        for rgb, id in colorMap.items():
            self.matches = np.all(self.mask == np.array(rgb)[:, None, None], axis=0)
            self.labeledMask[self.matches] = id
        self.labeledMask = np.expand_dims(self.labeledMask,0)
        return torch.from_numpy(self.labeledMask).long()

    def __len__(self):
      if self.mode == 'train':
        return len(self.trainMaskPaths)
      elif self.mode == 'test':
        return len(self.testMaskPaths)
      else:
        return 'Invalid Mode'

    def __getitem__(self,idx):
      if self.mode == 'train':
        imgPath = self.trainImgPaths[idx]
        maskPath = self.trainMaskPaths[idx]
        image = read_image(imgPath)
        mask = read_image(maskPath)
        mask = self.rgbToLabel(mask)
        if self.imgTransform and self.maskTransform is not None:
          image = self.imgTransform(image)
          mask = self.maskTransform(mask)
        mask = mask.squeeze(0)
        return image, mask
      elif self.mode == 'test':
        imgPath = self.testImgPaths[idx]
        maskPath = self.testMaskPaths[idx]
        image = read_image(imgPath)
        mask = read_image(maskPath)
        mask = self.rgbToLabel(mask)
        if self.imgTransform and self.maskTransform is not None:
          image = self.imgTransform(image)
          mask = self.maskTransform(mask)
        mask = mask.squeeze(0)
        return image, mask
      else:
        return 'Invalid Mode'

imgTransforms = v2.Compose([
    v2.Resize((512,512)),
    v2.ToImage(),
    v2.ToDtype(torch.float32)
])
maskTransforms = v2.Compose([
    v2.Resize((512,512),interpolation = v2.InterpolationMode.NEAREST_EXACT),
])

In [None]:
def init_weights(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        torch.nn.init.kaiming_uniform_(m.weight)

In [None]:
class doubleConv(nn.Module):
  def __init__(self,in_channels,out_channels):
    super().__init__()
    self.network = nn.Sequential(
    nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1),
    nn.BatchNorm2d(out_channels),
    nn.ReLU(inplace = True),
    nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1),
    nn.BatchNorm2d(out_channels),
    nn.ReLU(inplace = True)
    )
    self.network.apply(init_weights)
  def forward(self,x):
    x = self.network(x)
    return x

class downSample(nn.Module):
  def __init__(self,in_features, out_features):
    super().__init__()
    self.conv = doubleConv(in_features,out_features)
    self.pool = nn.MaxPool2d(2,2)

  def forward(self,x):
    bpx = self.conv(x)
    apx = self.pool(bpx)
    return bpx, apx

class upSample(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.upConv = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
    self.upConv.apply(init_weights)
    self.dConv = doubleConv(in_channels, out_channels)

  def forward(self,x,prev):
    x = self.upConv(x)
    x = torch.cat([x,prev],1)
    x = self.dConv(x)
    return x


In [None]:
class UNet(nn.Module):
  def __init__(self,in_channels,numOfClasses):
    super().__init__()
    self.en1 = downSample(in_channels,64)
    self.en2 = downSample(64,128)
    self.en3 = downSample(128,256)
    self.en4 = downSample(256,512)
    self.en5 = doubleConv(512,1024)

    self.de1 = upSample(1024,512)
    self.de2 = upSample(512,256)
    self.de3 = upSample(256,128)
    self.de4 = upSample(128,64)

    self.out = nn.Conv2d(64,numOfClasses,kernel_size=1)
    self.out.apply(init_weights)

  def forward(self,x):
    self.bpx1, self.apx1 = self.en1(x)
    self.bpx2, self.apx2 = self.en2(self.apx1)
    self.bpx3, self.apx3 = self.en3(self.apx2)
    self.bpx4, self.apx4 = self.en4(self.apx3)
    self.bpx5 = self.en5(self.apx4)

    x = self.de1(self.bpx5,self.bpx4)
    x = self.de2(x, self.bpx3)
    x = self.de3(x, self.bpx2)
    x = self.de4(x, self.bpx1)

    x = self.out(x)
    return x


In [None]:
trainData = customDataset('/content/drive/MyDrive/Datasets/Semantic segmentation dataset',64,'train',imgTransforms, maskTransforms)
trainDataloader = DataLoader(trainData,batch_size = 2)

testData = customDataset('/content/drive/MyDrive/Datasets/Semantic segmentation dataset',64,'test',imgTransforms, maskTransforms)
testDataloader = DataLoader(testData,batch_size = 2)

modelWithCE = UNet(3,6)
modelWithCE.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/Semantic Segmentation Dataset Segmentation/unet_weights_with_ce.pth', map_location=torch.device('cpu' if torch.cuda.is_available() is False else 'cuda:0')))
modelWithCE = modelWithCE.to(device)

modelWithDice = UNet(3,6)
modelWithDice.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/Semantic Segmentation Dataset Segmentation/unet_weights_with_dice.pth', map_location=torch.device('cpu' if torch.cuda.is_available() is False else 'cuda:0')))
modelWithDice = modelWithDice.to(device)

lossFn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(modelWithDice.parameters(),0.0001)

In [None]:
def test(dataloader, model, loss_fn):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y, 6).item()
            correct += ((pred.argmax(1) == y).type(torch.float)/(512*512)).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
def train(dataloader, model, lossFn, optimizer):
  model.train()

  for batch, (X,y) in enumerate(dataloader):
    X = X.to(device)
    y = y.to(device)
    pred = model(X)
    loss = diceLoss(pred,y,6)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # testAcc = test(testDataloader,model,lossFn)
    if batch % 3 == 0:
      loss, current = loss.item(), batch
      print(f"loss: {loss:>7f}  [{current:>5d}]")

In [46]:
for i in range(100):

  print(f"\nEpoch {i+1}\n-------------------------------")
  print("Training\n")
  train(trainDataloader, modelWithDice, diceLoss, optimizer)
  print("\nTesting\n")
  test(testDataloader, modelWithDice, diceLoss)


Epoch 1
-------------------------------
Training

loss: 0.924233  [    0]
loss: 0.922862  [    3]
loss: 0.781357  [    6]
loss: 0.887690  [    9]
loss: 0.837813  [   12]
loss: 0.827702  [   15]
loss: 0.853504  [   18]
loss: 0.806977  [   21]
loss: 0.840928  [   24]
loss: 0.913424  [   27]
loss: 0.912075  [   30]

Testing

Test Error: 
 Accuracy: 41.9%, Avg loss: 0.861058 


Epoch 2
-------------------------------
Training

loss: 0.923512  [    0]
loss: 0.917149  [    3]
loss: 0.780050  [    6]
loss: 0.888575  [    9]
loss: 0.836093  [   12]
loss: 0.820396  [   15]
loss: 0.851434  [   18]
loss: 0.803047  [   21]
loss: 0.841029  [   24]
loss: 0.911446  [   27]
loss: 0.908437  [   30]

Testing

Test Error: 
 Accuracy: 40.4%, Avg loss: 0.861346 


Epoch 3
-------------------------------
Training

loss: 0.920063  [    0]
loss: 0.917838  [    3]
loss: 0.766878  [    6]
loss: 0.879822  [    9]
loss: 0.834998  [   12]
loss: 0.805014  [   15]
loss: 0.838215  [   18]
loss: 0.807994  [   21]
los

In [None]:
# torch.save(model.state_dict(), '/content/drive/MyDrive/Colab Notebooks/Semantic Segmentation Dataset Segmentation/unet_weights_with_dice.pth')

In [None]:
test(trainDataloader, modelWithDice, diceLoss)

Test Error: 
 Accuracy: 44.1%, Avg loss: 0.853441 



In [49]:
img = read_image('/content/drive/MyDrive/Datasets/Semantic segmentation dataset/Tile 1/images/image_part_001.jpg')
print(img.size)
img = imgTransforms(img)

rawMask = read_image('/content/drive/MyDrive/Datasets/Semantic segmentation dataset/Tile 1/masks/image_part_001.png')
rawMask = maskTransforms(rawMask)
mask = RgbToLabel(rawMask)
mask = mask.unsqueeze(0)
mask = mask.to(device)

predWithCE = modelWithCE(img.unsqueeze(0).to(device))
resWithCE = predWithCE.argmax(1)
mappedWithCE = indexToColor[resWithCE]

predWithDice = modelWithDice(img.unsqueeze(0).to(device))
resWithDice = predWithDice.argmax(1)
mappedWithDice = indexToColor[resWithDice]

print(f'Dice loss with model trained using CE: {diceLoss(predWithCE, mask.to(device), 6)}')
print(f'Dice loss with model trained using Dice: {diceLoss(predWithDice, mask.to(device), 6)}')
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
axes[0].imshow(mappedWithCE[0].cpu())
axes[0].set_title('Prediction With CE')

axes[1].imshow(rawMask.permute(1,2,0))
axes[1].set_title('Ground Truth')


axes[2].imshow(mappedWithDice[0].cpu())
axes[2].set_title('Prediction With Dice')

plt.show()

<built-in method size of Tensor object at 0x7c49346b8290>


NameError: name 'modelWithCE' is not defined