In [78]:
import torch
from torch.utils.data import Dataset
import cv2
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import random_split, DataLoader
from torchvision import transforms
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss
import math
import os
from datetime import datetime
from time import strftime

In [None]:
class Unet2(nn.Module):
    def __init__(self):
        super().__init__()
        # First layer
        self.conv1 = nn.Conv2d(3, 64, 3, padding = 1)
        self.conv2 = nn.Conv2d(64, 16, 3, padding = 1)
        self.batchn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(kernel_size= (2, 2), stride = 2)
        
        # Second layer
        self.conv3 = nn.Conv2d(16, 64, 3, padding = 1)
        self.conv4 = nn.Conv2d(64, 32, 3, padding = 1)
        self.batchn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(kernel_size= (2, 2), stride = 2)
        self.drop2 = nn.Dropout(0.5)

        # Third layer
        self.conv5 = nn.Conv2d(32, 32, 3, padding = 1)
        self.conv6 = nn.Conv2d(32, 64, 3, padding = 1)
        self.batchn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(kernel_size= (2, 2), stride = 2)

        # Fourth layer
        self.conv7 = nn.Conv2d(64, 256, 3, padding = 1)
        self.conv8 = nn.Conv2d(256, 128, 3, padding = 1)
        self.batchn4 = nn.BatchNorm2d(128)
        self.pool4 = nn.MaxPool2d(kernel_size= (2, 2), stride = 2)
        self.drop4 = nn.Dropout(0.5)

        # Fifth layer
        self.conv9 = nn.Conv2d(128, 512, 3, padding = 1)
        self.conv10 = nn.Conv2d(512, 256, 3, padding = 1)
        self.batchn5 = nn.BatchNorm2d(256)
        # Upsamling 1
        self.upsample1 = nn.ConvTranspose2d(256, 128, kernel_size=(2,2), stride=2)

        # Sixth layer
        self.conv11 = nn.Conv2d(256, 512, 3, padding = 1)
        self.conv12 = nn.Conv2d(512, 128, 3, padding = 1)
        self.batchn6 = nn.BatchNorm2d(128)
        # Upsamling 2
        self.upsample2 = nn.ConvTranspose2d(128, 64, kernel_size=(2,2), stride=2)

        # Seventh layer
        self.conv13 = nn.Conv2d(128, 256, 3, padding = 1)
        self.conv14 = nn.Conv2d(256, 64, 3, padding = 1)
        self.batchn7 = nn.BatchNorm2d(64)
        # Upsamling 3
        self.upsample3 = nn.ConvTranspose2d(64, 32, kernel_size=(2,2), stride=2)

        # Eighth layer
        self.conv15 = nn.Conv2d(64, 64, 3, padding = 1)
        self.conv16 = nn.Conv2d(64, 32, 3, padding = 1)
        self.batchn8 = nn.BatchNorm2d(32)
        # Upsamling 3
        self.upsample4 = nn.ConvTranspose2d(32, 16, kernel_size=(2,2), stride=2)

        # Ninth layer
        self.conv17 = nn.Conv2d(32, 64, 3, padding = 1)
        self.conv18 = nn.Conv2d(64, 16, 3, padding = 1)
        self.batchn9 = nn.BatchNorm2d(16)

        # Last layer
        self.conv19 = nn.Conv2d(16, 1, 1, padding=0)

        # We choose the activation function ReLU
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        # First layer
        c1 = self.activation(self.batchn1(self.conv2(self.conv1(x))))
        p1 = self.pool1(c1)
#         print('c1', c1.shape)
#         print('p1', p1.shape)
        
        # Second layer
        c2 = self.activation(self.batchn2(self.conv4(self.conv3(p1))))
        p2 = self.drop2(self.pool2(c2))
#         print('c2', c2.shape)
#         print('p2', p2.shape)

        # Third layer
        c3 = self.activation(self.batchn3(self.conv6(self.conv5(p2))))
        p3 = self.pool3(c3) 
#         print('c3', c3.shape)
#         print('p3', p3.shape) 

        # Fourth layer
        c4 = self.activation(self.batchn4(self.conv8(self.conv7(p3))))
        p4 = self.drop4(self.pool3(c4))
#         print('c4', c4.shape)
#         print('p4', p4.shape)

        # Fifth layer
        c5 = self.activation(self.batchn5(self.conv10(self.conv9(p4))))
#         print('c5', c5.shape)
        # Upsampling
        u6 = self.upsample1(c5)
#         print('u6', u6.shape)
        # Adding Skip Connection 
        u6 = torch.cat([u6, c4], dim = 1)
        
        # Sixth layer
        c6 = self.activation(self.batchn6(self.conv12(self.conv11(u6))))
#         print('c6', c6.shape)
        # Upsampling
        u7 = self.upsample2(c6)
        # Adding Skip Connection 
#         print('u7', u7.shape)
        u7 = torch.cat([u7, c3], dim = 1)

        # Seventh layer
        c7 = self.activation(self.batchn7(self.conv14(self.conv13(u7))))
        # Upsampling
        u8 = self.upsample3(c7)
        # Adding Skip Connection 
        u8 = torch.cat([u8, c2], dim = 1)

        # Eighth layer
        c8 = self.activation(self.batchn8(self.conv16(self.conv15(u8))))
        # Upsampling
        u9 = self.upsample4(c8)
        # Adding Skip Connection 
        u9 = torch.cat([u9, c1], dim = 1)

        # Ninth layer
        c9 = self.activation(self.batchn9(self.conv18(self.conv17(u9))))

        # Final output because we need values between 0 and 1
        output = torch.sigmoid((self.conv19(c9)))

        return output
     def crop(self, encFeatures, x):
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)
        return encFeatures

In [48]:
class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        downChannels = (3, 16, 16, 32, 32, 64, 64, 128, 128, 256, 256)
        upChannels = (256, 128, 128, 64, 64, 32, 32, 16, 16)
        upSamplesChannels = (256, 128, 64, 32, 16)
        finalChannel = 1

        self.downConvs = []
        for i in range(len(downChannels) - 1):
            self.downConvs.append(nn.Conv2d(downChannels[i], downChannels[i+1], (3,3), padding = 1))
        self.maxPool = nn.MaxPool2d(kernel_size=(2,2), stride=2)
        self.relu = nn.ReLU()

        self.upConvs = []
        for i in range(len(upChannels) - 1):
            self.upConvs.append(nn.Conv2d(upChannels[i], upChannels[i+1], (3,3), padding = 1))

        self.upSamples = []
        for i in range(len(upSamplesChannels) -1):
            self.upSamples.append(nn.ConvTranspose2d(upSamplesChannels[i], upSamplesChannels[i+1], kernel_size=(2,2), stride=2))

        self.finalConv = nn.Conv2d(upChannels[-1], finalChannel, (3,3), padding = 1)
        
    def forward(self, input):
        output = self.downConvs[0](input)
        output = self.downConvs[1](output)
        intermediates = []
        for i in range(2, len(self.downConvs)):
            if(i%2==0):
                intermediates.append(output)
                output = self.maxPool(output)
            output = self.downConvs[i](output)
            print(output.shape)
        
        for i in range(len(self.upSamples)):
            output = self.upSamples[i](output)
            intermed = self.crop(intermediates[-(i+1)], output)
            output = torch.cat([output, intermed], dim=1)
            output = self.upConvs[2*i](output)
            output = self.upConvs[2*i+1](output)
            print(output.shape)

        output = self.finalConv(output)
        return output
    
    def crop(self, encFeatures, x):
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)
        return encFeatures

In [32]:
class SegmentationDataset(Dataset):
    def __init__(self, imagePaths, maskPaths, transforms):
        # store the image and mask filepaths, and augmentation
        # transforms
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms
    def __len__(self):
        # return the number of total samples contained in the dataset
        return len(self.imagePaths)
    def __getitem__(self, idx):
        # grab the image path from the current index
        imagePath = self.imagePaths[idx]
        # load the image from disk, swap its channels from BGR to RGB,
        # and read the associated mask from disk in grayscale mode
        image = cv2.imread(imagePath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.maskPaths[idx], 0)
        # check to see if we are applying any transformations
        if self.transforms is not None:
            # apply the transformations to both image and its mask
            image = self.transforms(image)
            mask = self.transforms(mask)
        # return a tuple of the image and its mask
        return (image, mask)

In [34]:
if torch.cuda.is_available():  
    dev = "cuda:0" 
else:  
    dev = "cpu" 

device = torch.device(dev)


In [58]:
def train_test_val_split(imagePath, maskPath):
    # Loading images
    baseImages = os.listdir(imagePath)
    baseMasks = os.listdir(maskPath)
    images = [os.path.join(imagePath, image) for image in baseImages]
    masks = [os.path.join(imagePath, mask) for mask in baseMasks]
    print(len(images))
    
    split = [0.8, 0.1, 0.1]
    splitLength = [math.floor(fraction*len(images)) for fraction in split]
    split = random_split(range(len(images)),splitLength, generator=torch.Generator().manual_seed(8))
    trainImages = [images[i] for i in split[0]]
    trainMasks = [images[i] for i in split[0]]
    testImages = [images[i] for i in split[1]]
    testMasks = [images[i] for i in split[1]]
    valImages = [images[i] for i in split[2]]
    valMasks = [images[i] for i in split[2]]

    trainSize = len(trainImages)
    testSize = len(testImages)
    valSize = len(valImages)

    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((128, 128))])
    trainDs = SegmentationDataset(trainImages, trainMasks, transform)
    testDs = SegmentationDataset(testImages, testMasks, transform)
    valDs = SegmentationDataset(valImages, valMasks, transform)
    
    return trainDs, testDs, valDs

In [60]:
imagePath = "./data/images"
maskPath = "./data/masks"
trainDs, testDs, valDs = train_test_val_split(imagePath, maskPath)
trainSize = len(trainDs)
testSize = len(testDs)
valSize = len(valDs)

4000


In [90]:
def train(model):
    batchSize = 32
    lr = 0.005
    epochs = 40
    
    path = "./models"
    time = datetime.now().strftime("%H-%M-%S")
    folder = os.path.join(path, time)
    os.mkdir(folder)
    checkpoint = 5
    
    trainLoader = DataLoader(trainDs, shuffle=True, batch_size=batchSize)
    testLoader = DataLoader(testDs, shuffle=True, batch_size=batchSize)
    valLoader = DataLoader(valDs, shuffle=True, batch_size=batchSize)

    lossFunc = BCEWithLogitsLoss()
    opt = Adam(model.parameters(), lr=lr)

    trainLosses = []
    valLosses = []

    for epoch in range(epochs):
        print(f"epoch: {epoch}/{epochs}")
        model.train()
        trainLoss = 0
        valLoss = 0
        trainSteps = 0
        testSteps = 0
        for (x,y) in trainLoader:
            print(f"batch: {trainSteps}/{math.floor(trainSize/batchSize)}")
            trainSteps +=1
            pred = model(x)
#             yc = model.crop(y, pred)
            loss = lossFunc(pred, yc)

            opt.zero_grad()
            loss.backward()
            opt.step()

            trainLoss+=loss
            
        with torch.no_grad():
            model.eval()
            for (x, y) in valLoader:
                testSteps+=1
                pred = model(x)
                yc = model.crop(y, pred)
                valLoss += lossFunc(pred, yc)

        trainLosses.append(trainLoss.item()/trainSteps)
        valLosses.append(valLoss.item()/testSteps)
        
        if(epoch%checkpoint == 0):
            torch.save(model, os.path.join(folder, f'epoch-{epoch}.pt'))
        return trainLosses, valLosses

In [91]:
model = Unet2()
train(model)

epoch: 0/40
batch: 0/100
c1 torch.Size([32, 16, 128, 128])
p1 torch.Size([32, 16, 64, 64])
c2 torch.Size([32, 32, 64, 64])
p2 torch.Size([32, 32, 32, 32])
c3 torch.Size([32, 64, 32, 32])
p3 torch.Size([32, 64, 16, 16])
c4 torch.Size([32, 128, 16, 16])
p4 torch.Size([32, 128, 8, 8])
c5 torch.Size([32, 256, 8, 8])
u6 torch.Size([32, 128, 16, 16])
c6 torch.Size([32, 128, 16, 16])
u7 torch.Size([32, 64, 32, 32])


AttributeError: 'Unet2' object has no attribute 'crop'

In [86]:
datetime.now().strftime("%H-%M-%S")

'15-25-10'