In [5]:
import os
import random
import numpy as np
import torch
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler
!pip install torchinfo
from torchinfo import summary

ModuleNotFoundError: No module named 'cv2'

In [6]:
# Colab-only: mount Google Drive so the dataset and checkpoint paths under /content/drive resolve.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

ModuleNotFoundError: No module named 'google.colab'

In [None]:
class PennFudanDataset(Dataset):
    """Penn-Fudan pedestrian dataset yielding ``(image, binary_mask)`` uint8 tensors.

    ``root`` must contain ``PNGImages/`` and ``PedMasks/``; an image and its mask
    are paired by their position in the sorted directory listings.

    Args:
        root: Dataset root directory.
        train: If truthy, ``__getitem__`` applies a random left-right flip
            (probability 1/2) to the image and its mask together.
        size: ``(width, height)`` every image and mask is resized to.

    Raises:
        ValueError: If the number of image files and mask files differ, since
            index-based pairing would then silently mismatch them.
    """

    def __init__(self, root, train = True, size=(128, 128)):
        self.root = root
        self.train = train
        self.size = size
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))
        if len(self.imgs) != len(self.masks):
            raise ValueError(
                f"found {len(self.imgs)} images but {len(self.masks)} masks under {root!r}"
            )

    def augment(self, image, flipCode):
        """Return ``cv2.flip(image, flipCode)``.

        ``__getitem__`` only passes ``flipCode=1`` (a left-right flip in OpenCV's
        convention) and applies it to the image and the mask alike.
        """
        return cv2.flip(image, flipCode)

    def __getitem__(self, idx):
        """Return ``(img, mask)``: img is CHW uint8, mask is HW uint8 with values {0, 1}."""
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])

        img = Image.open(img_path).convert("RGB").resize(self.size)
        # NEAREST keeps the mask's discrete instance ids intact; an interpolating
        # filter would invent in-between values along object borders.
        mask = Image.open(mask_path).resize(self.size, Image.NEAREST)

        img = np.array(img)
        mask = np.array(mask)
        # Collapse every non-zero instance id into a single foreground class.
        mask[mask > 0] = 1

        if self.train:
            # 2 is a "no flip" sentinel, so each sample is flipped with probability 1/2.
            flipCode = random.choice([1, 2])
            if flipCode != 2:
                img = self.augment(img, flipCode)
                mask = self.augment(mask, flipCode)

        # HWC -> CHW, the channel-first layout nn.Conv2d consumes.
        img = np.transpose(img, (2, 0, 1))

        img = torch.as_tensor(img, dtype=torch.uint8)
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        return img, mask

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.imgs)

In [None]:

# Two views of the same PennFudanPed images: `dataset` applies random flips
# (train=True), `dataset2` returns the images unchanged (train=False).
dataset = PennFudanDataset('/content/drive/MyDrive/PennFudanPed')
dataset2 = PennFudanDataset('/content/drive/MyDrive/PennFudanPed',False)

# Split by *image*, not by ConcatDataset index. The concatenation below holds every
# image twice (at index i and i + n_images), so splitting its indices directly lets
# the same picture land in both train and val/test, leaking test images into training.
n_images = len(dataset)
train_base, test_base = train_test_split(range(n_images), test_size=0.1, random_state=42)
train_base, val_base = train_test_split(train_base, test_size=0.1, random_state=42)

dataset = ConcatDataset([dataset, dataset2])
# Training draws both the augmented copy (first half) and the plain copy (second half);
# validation and test use only the plain copy, so their metrics are deterministic.
train_index = list(train_base) + [i + n_images for i in train_base]
val_index = [i + n_images for i in val_base]
test_index = [i + n_images for i in test_base]
train_sampler = SubsetRandomSampler(train_index)
val_sampler = SubsetRandomSampler(val_index)
test_sampler = SubsetRandomSampler(test_index)



In [None]:
# One DataLoader per split. All three read the same ConcatDataset and differ only in their
# sampler. `vaild_loader` keeps its original spelling because the training cell refers to it.
train_loader, vaild_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=8, sampler=split_sampler, num_workers=0)
    for split_sampler in (train_sampler, val_sampler, test_sampler)
)

In [None]:
# Soft Dice loss is used as the training objective.
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed from raw logits.

    The logits go through a sigmoid. Both tensors are then flattened over the whole
    batch, and the loss is ``1 - dice`` with
    ``dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.

    Args:
        weight: Unused; accepted only for API compatibility.
        size_average: Unused; accepted only for API compatibility.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the soft Dice loss as a 0-dim tensor.

        Args:
            inputs: Raw logits (no sigmoid applied yet), any shape.
            targets: Binary ground truth with the same number of elements as ``inputs``.
            smooth: Additive term that keeps the ratio defined for empty masks.
        """
        # Inputs are logits, so squash them to [0, 1] here. Drop this line if the model
        # already ends with a sigmoid.
        probs = torch.sigmoid(inputs)

        # reshape (not view) also accepts non-contiguous tensors such as transposed
        # or sliced predictions.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        return 1 - dice

In [None]:
class conv_block(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 with padding 1 keeps H and W unchanged; channels go ``in_ch -> out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # The first stage maps in_ch -> out_ch, the second out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # A single Sequential named `conv` keeps state_dict keys (conv.0.weight, ...)
        # compatible with saved checkpoints.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class up_conv(nn.Module):
    """Decoder upsampling stage: 2x ``nn.Upsample`` then 3x3 conv -> BatchNorm -> ReLU.

    Used instead of a transposed convolution. Doubles H and W; channels go ``in_ch -> out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(out_ch)
        act = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, act)

    def forward(self, x):
        return self.up(x)


class U_Net(nn.Module):
    """U-Net for segmentation that returns raw per-pixel logits.

    Structure:
    - Encoder: five ``conv_block`` stages separated by four 2x2 max-pools.
    - Decoder: four stages. Each one upsamples with ``up_conv``, concatenates the
      same-resolution encoder features (skip connection), and fuses them with a
      ``conv_block``.
    - Head: a 1x1 conv maps to ``out_ch`` channels.

    No sigmoid is applied, so pair the model with a loss that expects logits.

    Args:
        in_ch: Input channels (3 for RGB).
        out_ch: Output channels (logits per pixel).
        n1: Width of the first encoder stage; depth k uses ``n1 * 2**k`` filters.
            Defaults to 64, the previously hard-coded width.

    Input H and W must be divisible by 16, otherwise the skip-connection
    concatenations see mismatched spatial sizes.
    """

    def __init__(self, in_ch=3, out_ch=1, n1=64):
        super(U_Net, self).__init__()

        # Channel widths per depth: n1, 2*n1, 4*n1, 8*n1 and 16*n1 (bottleneck).
        filters = [n1 * 2 ** depth for depth in range(5)]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(in_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        # Each Up_conv block sees the upsampled features plus the skip features,
        # hence its input width is twice its output width.
        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Map an (N, in_ch, H, W) batch to (N, out_ch, H, W) logits."""
        # Encoder: halve H and W, then widen channels.
        e1 = self.Conv1(x)
        e2 = self.Conv2(self.Maxpool1(e1))
        e3 = self.Conv3(self.Maxpool2(e2))
        e4 = self.Conv4(self.Maxpool3(e3))
        e5 = self.Conv5(self.Maxpool4(e4))

        # Decoder: upsample, concatenate the matching encoder output, fuse.
        d5 = self.Up_conv5(torch.cat((e4, self.Up5(e5)), dim=1))
        d4 = self.Up_conv4(torch.cat((e3, self.Up4(d5)), dim=1))
        d3 = self.Up_conv3(torch.cat((e2, self.Up3(d4)), dim=1))
        d2 = self.Up_conv2(torch.cat((e1, self.Up2(d3)), dim=1))

        return self.Conv(d2)

In [None]:
# Layer-by-layer summary for one batch shaped like the loaders' output:
# 8 images (loader batch size), 3 channels, 128x128 (dataset resize target).
example_input_size = (8, 3, 128, 128)  # (batch, channels, height, width)
model = U_Net()
summary(model, input_size=example_input_size)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(model, train_loader, valid_loader, num_epochs, lr):
    """Train ``model`` and checkpoint it whenever the validation Dice does not get worse.

    Relies on three notebook globals defined in the training cell:
    - ``optimizer``: bound to ``model.parameters()``.
    - ``scheduler``: stepped once per epoch, after that epoch's optimizer steps.
    - ``loss_func``: the Dice loss on logits.

    ``lr`` is accepted for backward compatibility but unused; the learning rate
    comes from ``optimizer``/``scheduler``.

    When the validation Dice is >= the best value so far, the weights are saved to
    ``test{epoch}.pth`` in the current working directory.

    Returns:
        ``(train_losses, valid_losses, valid_dices)``: per-epoch lists of floats.
        Losses are per-sample means over the samples the loader actually yielded,
        and dice is ``1 - valid_loss``.
    """
    best_acc = 0
    model.to(device)

    train_losses = []
    valid_losses = []
    valid_dices = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
        print('-' * 10)

        model.train()
        train_loss = 0.0
        n_train = 0

        for images, masks in tqdm(train_loader):
            images = images.to(device).float()
            masks = masks.to(device).float()

            optimizer.zero_grad()
            pred = model(images)
            loss = loss_func(pred, masks)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch figure is a per-sample mean.
            train_loss += loss.item() * images.size(0)
            n_train += images.size(0)

        # Divide by the samples actually drawn. With a SubsetRandomSampler,
        # len(loader.dataset) is the whole ConcatDataset, not this split.
        train_loss /= max(n_train, 1)
        train_losses.append(train_loss)
        print(f'Train loss: {train_loss:.4f}')

        model.eval()
        valid_loss = 0.0
        n_valid = 0

        with torch.no_grad():
            for images, masks in tqdm(valid_loader):
                images = images.to(device).float()
                masks = masks.to(device).float()
                pred = model(images)
                loss = loss_func(pred, masks)
                valid_loss += loss.item() * images.size(0)
                n_valid += images.size(0)

        valid_loss /= max(n_valid, 1)
        valid_losses.append(valid_loss)
        # loss_func is a Dice loss (1 - dice), so the Dice score is its complement.
        valid_dice = 1 - valid_loss
        valid_dices.append(valid_dice)

        if valid_dice >= best_acc:
            best_acc = valid_dice
            torch.save(model.state_dict(), 'test{}.pth'.format(epoch))

        print(f'Valid loss: {valid_loss:.4f} - Dice: {valid_dice:.4f}')

        # Advance the LR schedule after this epoch's optimizer steps, which is the order
        # PyTorch expects. This replaces the deprecated step(epoch)/get_lr() calls.
        scheduler.step()

    return train_losses, valid_losses, valid_dices

In [None]:
# Fresh model and optimisation setup for the 40-epoch training run.
model = U_Net()
initial_lr = 0.001
num_epochs = 40
optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
loss_func = DiceLoss()
# NOTE(review): with T_max = 1e10 the cosine curve barely moves over 40 epochs, so the
# learning rate stays at ~initial_lr. Use the epoch count as T_max for real annealing.
MAX_STEP = int(1e10)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, MAX_STEP, eta_min=1e-5)
train_losses, valid_losses, valid_dices = train(
    model,
    train_loader=train_loader,
    valid_loader=vaild_loader,
    num_epochs=num_epochs,
    lr=initial_lr,
)

In [None]:
# One point per completed epoch, numbered from 1. The previous np.linspace(0, 40, 40)
# placed points at non-integer epochs and assumed exactly 40 epochs.
epochs = np.arange(1, len(train_losses) + 1)
plt.plot(epochs, train_losses, 's-', color='r', label="train_losses")
plt.plot(epochs, valid_losses, 'o-', color='g', label="valid_losses")
plt.plot(epochs, valid_dices, 'o-', color='b', label="valid_dice")
plt.title("Training curves")
plt.xlabel("epoch number")
plt.ylabel("Dice loss / Dice score")
plt.legend(loc="best")
plt.show()


Here we can see that the training loss keeps going down, while the validation loss decreases smoothly during the first 10 epochs and starts to fluctuate afterwards. We can also notice that the validation loss barely decreases after epoch 30 and even becomes slightly larger, which indicates overfitting after about 30 epochs.

In [None]:
# Evaluate the saved checkpoint on the held-out test split, one image per batch.
# The test split is the shared `dataset` restricted by `test_sampler` (there is no
# separate `test_dataset`).
test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=test_sampler, num_workers=0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = U_Net()
# NOTE(review): train() writes 'test{epoch}.pth' to the working directory; this path
# assumes the epoch-19 checkpoint was copied to Drive. Confirm before re-running.
net.load_state_dict(torch.load('/content/drive/MyDrive/test19.pth', map_location=device))
net.to(device)
net.eval()
criterion = DiceLoss()
test_dice = 0.0
n_test = 0
with torch.no_grad():
  for images, masks in tqdm(test_loader):
    images = images.to(device).float()
    masks = masks.to(device).float()
    # Score the loaded checkpoint (`net`), not the in-memory `model` from training.
    pred = net(images)
    DICE_loss = criterion(pred, masks)
    test_dice += DICE_loss.item() * images.size(0)
    n_test += images.size(0)
# Average over the test samples actually drawn, not the whole ConcatDataset.
test_Dice = test_dice / max(n_test, 1)
print('the average dice on testset is {}'.format(1-test_Dice))

In [None]:
# Show the first three samples of one training batch: input, ground truth, prediction.
image, masks = next(iter(train_loader))
image = image.to(device).float()
net.eval()
with torch.no_grad():
  predict = net(image)

# (title template, colormap) for the three panels of each figure.
panel_specs = (
  ('the train image{}', None),
  ('the train mask{}', 'rainbow'),
  ('the predicted mask{}', 'rainbow'),
)
for i in range(3):
  img = np.uint8(image[i].squeeze().permute(1, 2, 0).cpu().numpy())
  mask = masks[i].squeeze().numpy()
  probs = torch.sigmoid(predict[i].squeeze().cpu()).numpy()
  # Threshold sigmoid probabilities into a hard 0/1 mask.
  pred = np.where(probs >= 0.5, 1, 0)
  plt.figure()
  for col, (panel, (title_tpl, cmap)) in enumerate(zip((img, mask, pred), panel_specs), start=1):
    plt.subplot(1, 3, col)
    plt.imshow(panel, cmap=cmap)
    plt.title(title_tpl.format(i))
    plt.axis('off')
  plt.show()

In [None]:
# Show one test image next to its ground-truth mask and the thresholded prediction.
img, mask = next(iter(test_loader))
img = img.to(device).float()
net.eval()
with torch.no_grad():
  pred = net(img)
img = np.uint8(img.squeeze().permute(1, 2, 0).cpu().numpy())
pred = np.where(torch.sigmoid(pred.squeeze().cpu()).numpy() >= 0.5, 1, 0)
mask = mask.squeeze().numpy()

plt.figure()
panels = (
  (img, 'the test image', None),
  (mask, 'the test mask', 'rainbow'),
  (pred, 'the predicted mask', 'rainbow'),
)
for col, (panel, panel_title, cmap) in enumerate(panels, start=1):
  plt.subplot(1, 3, col)
  plt.imshow(panel, cmap=cmap)
  plt.title(panel_title)
  plt.axis('off')
plt.show()

Here we can see that my model does a good job of capturing the overall contour and position of the person in this test-set image, except for the person's head, which indicates quite good generalization capability.

In [None]:

# Run the trained network on a photo that is not part of PennFudanPed.
net.eval()
with torch.no_grad():
  # convert("RGB") guarantees 3 channels even for grayscale, palette or RGBA files;
  # the (2, 0, 1) transpose and the model's 3-channel first conv both require that.
  img = Image.open('/content/drive/MyDrive/bettles.jpg').convert("RGB")
  img = img.resize((128, 128))
  img = np.array(img)
  # HWC -> CHW, then add a batch dimension of 1.
  img = np.transpose(img, (2, 0, 1))
  img = torch.as_tensor(img, dtype=torch.uint8).unsqueeze(0)
  pred = net(img.to(device).float())
img = img.squeeze().permute(1,2,0).cpu().numpy()
img = np.uint8(img)
# Threshold sigmoid probabilities into a hard 0/1 mask.
pred = torch.sigmoid(pred.squeeze().cpu()).numpy()
pred = np.where(pred>=0.5,1,0)
plt.figure()
plt.subplot(1,2,1)
plt.imshow(img)
plt.title('the out of set image')
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(pred,cmap='rainbow')
plt.title('the predicted mask')
plt.axis('off')
plt.show()


My model is still able to segment the Beatles from the background, but the human shapes are not very smooth and the mask contains some extra spurious regions. In conclusion, this model's generalization capability on out-of-distribution data is not excellent.