In [None]:
#import packages
import numpy as np
import random
import torch
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
import torchvision.transforms.functional as tf
from torch.utils.data import DataLoader
from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.optim import lr_scheduler
import torchvision
import matplotlib.pyplot as plt

In [None]:
def seed_everything(SEED=42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        SEED: Integer seed handed to every generator.

    Side effects:
        Also switches ``torch.backends.cudnn.benchmark`` on.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(SEED)
    # Keep True if all the inputs have the same size.
    # NOTE(review): benchmark mode is usually switched off when run-to-run
    # determinism matters — confirm the speed/reproducibility trade-off is intended.
    torch.backends.cudnn.benchmark = True
# Seed every random generator up front so a fresh "Run All" starts from the same state.
SEED = 42
seed_everything(SEED)

In [None]:
# Custom dataset classes for the LIP train / validation / test splits
class LipTrainDataset(data.Dataset):
    """Training split of the LIP data as (image, binary mask) pairs.

    Image names are read from ``TrainVal_pose_annotations/lip_train_set.csv``
    below ``file_path``; the first comma-separated field of a row is the image
    file name. The mask is looked up under the part of that name before its
    first ``.``, with a ``.png`` extension. Only pairs where both files can be
    opened by PIL are kept.

    Args:
        file_path: Root directory of the dataset.
        transform: Optional callable ``transform(image, lable)`` returning the
            transformed ``(image, lable)`` pair.
        table: Optional 256-entry lookup table given to ``Image.point`` to
            binarise the mask. Defaults to ``[0, 1, 1, ..., 1]`` (value 0 stays
            0, every other value becomes 1).
        max_samples: Non-negative number of CSV rows to scan, or ``None`` for
            all rows. Defaults to 1000 because the whole set is large; fewer
            rows in the CSV is fine.
    """

    def __init__(self, file_path=None, transform=None, table=None, max_samples=1000):
        self.file_path = file_path
        self.transform = transform
        self.table = list(table) if table is not None else [0] + [1] * 255
        self.max_samples = max_samples
        self.init_all_data(file_path)

    def init_all_data(self, file_path):
        """Fill ``self.images`` / ``self.lables`` with readable, sorted pairs.

        ``file_path`` is kept for interface compatibility; the root that is
        actually read is ``self.file_path``.
        """
        csv_path = os.path.join(self.file_path, 'TrainVal_pose_annotations', 'lip_train_set.csv')
        # `with` closes the CSV handle as soon as the rows are read.
        with open(csv_path, 'r') as handle:
            rows = [line.rstrip('\n') for line in handle]

        # Slicing never overruns, unlike indexing a fixed range of rows.
        limit = getattr(self, 'max_samples', 1000)
        if limit is not None:
            rows = rows[:limit]

        pairs = []
        for row in rows:
            image_name = row.split(',')[0]
            if not image_name:
                continue  # blank line in the CSV
            lable_name = image_name.split('.')[0] + '.png'
            image_path = os.path.join(self.file_path, 'TrainVal_images', 'TrainVal_images',
                                      'train_images', image_name)
            lable_path = os.path.join(self.file_path, 'TrainVal_parsing_annotations',
                                      'TrainVal_parsing_annotations', 'TrainVal_parsing_annotations',
                                      'train_segmentations', lable_name)
            # Keep the sample only when both the image and its mask can be opened.
            if self.is_valid_image(image_path) and self.is_valid_image(lable_path):
                pairs.append((image_path, lable_path))

        # Sort the pairs together so an image can never drift away from its mask.
        pairs.sort()
        self.images = [image_path for image_path, _ in pairs]
        self.lables = [lable_path for _, lable_path in pairs]

    def is_valid_image(self, img_path):
        """Return True when PIL can open ``img_path``, otherwise False."""
        try:
            # The context manager releases the file handle immediately.
            with Image.open(img_path):
                return True
        except Exception:
            # `Exception`, not a bare except, so Ctrl-C / SystemExit still propagate.
            return False

    def __getitem__(self, idx):
        """Return the ``idx``-th (RGB image, binarised mask) pair, transformed if configured."""
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(self.lables[idx]) as raw_lable:
            lable = raw_lable.point(self.table, '1')
        if self.transform:
            image, lable = self.transform(image, lable)
        return image, lable

    def __len__(self):
        """Number of usable (image, mask) pairs."""
        return len(self.images)
    
    
class LipValDataset(data.Dataset):
    """Validation split of the LIP data as (image, binary mask) pairs.

    Image names are read from ``TrainVal_pose_annotations/lip_val_set.csv``
    below ``file_path``; the first comma-separated field of a row is the image
    file name. The mask is looked up under the part of that name before its
    first ``.``, with a ``.png`` extension. Only pairs where both files can be
    opened by PIL are kept.

    Args:
        file_path: Root directory of the dataset.
        transform: Optional callable ``transform(image, lable)`` returning the
            transformed ``(image, lable)`` pair.
        transform2: Accepted for interface compatibility; not used.
        table: Optional 256-entry lookup table given to ``Image.point`` to
            binarise the mask. Defaults to ``[0, 1, 1, ..., 1]``.
        max_samples: Non-negative number of CSV rows to scan, or ``None`` for
            all rows. Defaults to 3000; fewer rows in the CSV is fine.
    """

    def __init__(self, file_path=None, transform=None, transform2=None, table=None, max_samples=3000):
        self.file_path = file_path
        self.transform = transform
        self.table = list(table) if table is not None else [0] + [1] * 255
        self.max_samples = max_samples
        self.init_all_data(file_path)

    def init_all_data(self, file_path):
        """Fill ``self.images`` / ``self.lables`` with readable, sorted pairs.

        ``file_path`` is kept for interface compatibility; the root that is
        actually read is ``self.file_path``.
        """
        csv_path = os.path.join(self.file_path, 'TrainVal_pose_annotations', 'lip_val_set.csv')
        # `with` closes the CSV handle as soon as the rows are read.
        with open(csv_path, 'r') as handle:
            rows = [line.rstrip('\n') for line in handle]

        # Slicing never overruns, unlike indexing a fixed range of rows.
        limit = getattr(self, 'max_samples', 3000)
        if limit is not None:
            rows = rows[:limit]

        pairs = []
        for row in rows:
            image_name = row.split(',')[0]
            if not image_name:
                continue  # blank line in the CSV
            lable_name = image_name.split('.')[0] + '.png'
            image_path = os.path.join(self.file_path, 'TrainVal_images', 'TrainVal_images',
                                      'val_images', image_name)
            lable_path = os.path.join(self.file_path, 'TrainVal_parsing_annotations',
                                      'TrainVal_parsing_annotations', 'TrainVal_parsing_annotations',
                                      'val_segmentations', lable_name)
            if self.is_valid_image(image_path) and self.is_valid_image(lable_path):
                pairs.append((image_path, lable_path))

        # Sort the pairs together so an image can never drift away from its mask.
        pairs.sort()
        self.images = [image_path for image_path, _ in pairs]
        self.lables = [lable_path for _, lable_path in pairs]

    def is_valid_image(self, img_path):
        """Return True when PIL can open ``img_path``, otherwise False."""
        try:
            # The context manager releases the file handle immediately.
            with Image.open(img_path):
                return True
        except Exception:
            # `Exception`, not a bare except, so Ctrl-C / SystemExit still propagate.
            return False

    def __getitem__(self, idx):
        """Return the ``idx``-th (RGB image, binarised mask) pair, transformed if configured."""
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(self.lables[idx]) as raw_lable:
            lable = raw_lable.point(self.table, '1')
        if self.transform:
            image, lable = self.transform(image, lable)
        return image, lable

    def __len__(self):
        """Number of usable (image, mask) pairs."""
        return len(self.images)
    
class LipTestDataset(data.Dataset):
    """Test split of the LIP data: images only, no masks.

    Image ids are read line by line from ``Testing_images/test_id.txt`` below
    ``file_path``; each id plus ``.jpg`` names a file under
    ``Testing_images/Testing_images/testing_images``. Ids whose file cannot be
    opened by PIL are dropped. The order of the id file is preserved.

    Args:
        file_path: Root directory of the dataset.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, file_path=None, transform=None):
        self.file_path = file_path
        self.transform = transform
        self.init_all_data(file_path)

    def init_all_data(self, file_path):
        """Fill ``self.images`` with the paths of all readable test images.

        ``file_path`` is kept for interface compatibility; the root that is
        actually read is ``self.file_path``.
        """
        id_file = os.path.join(self.file_path, 'Testing_images', 'test_id.txt')
        # `with` closes the id file as soon as the lines are read.
        with open(id_file, 'r') as handle:
            image_ids = [line.rstrip('\n') for line in handle]

        self.images = []
        for image_id in image_ids:
            image_path = os.path.join(self.file_path, 'Testing_images', 'Testing_images',
                                      'testing_images', image_id + '.jpg')
            if self.is_valid_image(image_path):
                self.images.append(image_path)

    def is_valid_image(self, img_path):
        """Return True when PIL can open ``img_path``, otherwise False."""
        try:
            # The context manager releases the file handle immediately.
            with Image.open(img_path):
                return True
        except Exception:
            # `Exception`, not a bare except, so Ctrl-C / SystemExit still propagate.
            return False

    def __getitem__(self, idx):
        """Return the ``idx``-th image as RGB, transformed if configured."""
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        """Number of usable test images."""
        return len(self.images)

In [None]:
#customize transformation
def mytransform1(image,lable):
    """Joint transform with random flips for an (image, mask) PIL pair.

    Both inputs are resized to 224x224 and the mask is pushed through a
    0 -> 0 / non-zero -> 1 lookup table into mode '1'. Horizontal and vertical
    flips are then each applied with probability 0.5 to image and mask
    together. Contrast and brightness of the image only are scaled by factors
    drawn uniformly from [0.6, 1.5]. Finally both are converted to tensors and
    the image is normalised per channel.

    Returns:
        ``(image_tensor, lable_tensor)``.
    """
    target_size = (224, 224)
    image = tf.resize(image, target_size)
    lable = tf.resize(lable, target_size)

    # Re-binarise after resizing: entry 0 of the table is 0, all others are 1.
    binarise = [0] + [1] * 255
    lable = lable.point(binarise, '1')

    # Flips must hit image and mask together to keep them aligned.
    for flip in (tf.hflip, tf.vflip):
        if random.random() > 0.5:
            image, lable = flip(image), flip(lable)

    # Photometric jitter is applied to the image only.
    contrast_factor = random.uniform(0.6, 1.5)
    image = tf.adjust_contrast(image, contrast_factor)
    brightness_factor = random.uniform(0.6, 1.5)
    image = tf.adjust_brightness(image, brightness_factor)

    image_tensor = tf.normalize(tf.to_tensor(image), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    lable_tensor = tf.to_tensor(lable)
    return image_tensor, lable_tensor
def mytransform2(image,lable):
    """Joint transform without flips for an (image, mask) PIL pair.

    Both inputs are resized to 224x224 and the mask is pushed through a
    0 -> 0 / non-zero -> 1 lookup table into mode '1'. Contrast and brightness
    of the image are scaled by factors drawn uniformly from [0.6, 1.5]. Both
    are then converted to tensors and the image is normalised per channel.

    Returns:
        ``(image_tensor, lable_tensor)``.
    """
    size = (224, 224)
    image, lable = tf.resize(image, size), tf.resize(lable, size)

    lookup = [0] + [1] * 255
    lable = lable.point(lookup, '1')

    # NOTE(review): this transform is the one handed to the validation dataset,
    # yet it still draws random contrast/brightness factors — confirm intended.
    image = tf.adjust_contrast(image, random.uniform(0.6, 1.5))
    image = tf.adjust_brightness(image, random.uniform(0.6, 1.5))

    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    return tf.normalize(tf.to_tensor(image), mean, std), tf.to_tensor(lable)

In [None]:
# The dataset constructors take the joint (image, mask) transform through their
# `transform` keyword; they have no `transform1` parameter.
traindataset = LipTrainDataset('../input/singleperson', transform=mytransform1)
train_loader = DataLoader(traindataset, batch_size=16, shuffle=True, num_workers=2)
valdataset = LipValDataset('../input/singleperson', transform=mytransform2)
val_loader = DataLoader(valdataset, batch_size=16, shuffle=True, num_workers=2)

In [None]:
#Unet
import torch.nn as nn

from math import sqrt
class Double_Conv_Block(nn.Module):
    """Two stacked 3x3 convolution -> batch-norm -> ReLU stages.

    Stride 1 with padding 1 keeps the spatial size, so no cropping is needed;
    only the channel count changes (``input_channel`` -> ``output_channel``).
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        stages = []
        channels_in = input_channel
        for _ in range(2):
            stages.append(nn.Conv2d(channels_in, output_channel,
                                    kernel_size=3, stride=1, padding=1, bias=True))
            stages.append(nn.BatchNorm2d(output_channel))
            stages.append(nn.ReLU(inplace=True))
            # The second convolution maps output_channel -> output_channel.
            channels_in = output_channel
        self.conv = nn.Sequential(*stages)

    def forward(self, input_channel):
        """Apply both stages to the feature map passed as ``input_channel``."""
        features = self.conv(input_channel)
        return features

class Up_Conv_Block(nn.Module):
    """Up-sampling stage: 2x upsample, then 3x3 convolution -> batch-norm -> ReLU.

    ``nn.Upsample`` is used with its default interpolation mode; the
    convolution changes the channel count from ``input_channel`` to
    ``output_channel``.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(input_channel, output_channel,
                         kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(output_channel)
        activation = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, activation)

    def forward(self, input_channel):
        """Upsample and convolve the feature map passed as ``input_channel``."""
        upsampled = self.up(input_channel)
        return upsampled


class U_Net(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages.

    The encoder widens the channels 64 -> 128 -> 256 -> 512 -> 1024 with a 2x2
    max-pool between stages. Each decoder stage upsamples, concatenates the
    matching encoder feature map along the channel axis and applies a double
    convolution. A final 1x1 convolution followed by a sigmoid produces
    ``output_ch`` maps with values in (0, 1).

    Height and width of the input should be divisible by 16 so that the
    upsampled maps match the skip connections they are concatenated with.

    Args:
        img_ch: Number of input channels (3 for RGB).
        output_ch: Number of output channels (1 for this project).
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.Conv1 = Double_Conv_Block(input_channel=img_ch, output_channel=64)
        self.Conv2 = Double_Conv_Block(input_channel=64, output_channel=128)
        self.Conv3 = Double_Conv_Block(input_channel=128, output_channel=256)
        self.Conv4 = Double_Conv_Block(input_channel=256, output_channel=512)
        self.Conv5 = Double_Conv_Block(input_channel=512, output_channel=1024)

        # Decoder: every Up_Conv_BlockN sees skip + upsampled channels, hence
        # twice the output width of the preceding UpN.
        self.Up5 = Up_Conv_Block(input_channel=1024, output_channel=512)
        self.Up_Conv_Block5 = Double_Conv_Block(input_channel=1024, output_channel=512)

        self.Up4 = Up_Conv_Block(input_channel=512, output_channel=256)
        self.Up_Conv_Block4 = Double_Conv_Block(input_channel=512, output_channel=256)

        self.Up3 = Up_Conv_Block(input_channel=256, output_channel=128)
        self.Up_Conv_Block3 = Double_Conv_Block(input_channel=256, output_channel=128)

        self.Up2 = Up_Conv_Block(input_channel=128, output_channel=64)
        self.Up_Conv_Block2 = Double_Conv_Block(input_channel=128, output_channel=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise all convolution and batch-norm parameters in place."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                # Zero-mean normal with std sqrt(2 / (kh * kw * out_channels)).
                nn.init.normal_(module.weight, 0, sqrt(2. / fan_out))
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return the sigmoid-activated prediction for the image batch ``x``."""
        # Contracting path: keep every stage's output for the skip connections.
        x1 = self.Conv1(x)
        x2 = self.Conv2(self.Maxpool(x1))
        x3 = self.Conv3(self.Maxpool(x2))
        x4 = self.Conv4(self.Maxpool(x3))
        x5 = self.Conv5(self.Maxpool(x4))

        # Expanding path: upsample, concatenate the skip (skip first), fuse.
        decoder_stages = (
            (self.Up5, self.Up_Conv_Block5, x4),
            (self.Up4, self.Up_Conv_Block4, x3),
            (self.Up3, self.Up_Conv_Block3, x2),
            (self.Up2, self.Up_Conv_Block2, x1),
        )
        decoded = x5
        for upsample, fuse, skip in decoder_stages:
            decoded = fuse(torch.cat((skip, upsample(decoded)), dim=1))

        return torch.sigmoid(self.Conv_1x1(decoded))

In [None]:
# Put the network on the GPU when one is available; otherwise keep it on the CPU
# instead of failing on a machine without CUDA.
model = U_Net().to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

In [None]:
#loss
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    ``loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``

    ``weight`` and ``size_average`` are accepted by the constructor but not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the scalar Dice loss.

        Args:
            inputs: Predicted values (the formula expects values in [0, 1]).
            targets: Ground-truth tensor with the same number of elements.
            smooth: Constant added to numerator and denominator so the ratio
                stays defined when both tensors sum to zero.
        """
        # reshape (not view) flattens non-contiguous tensors too, e.g. after a
        # transpose or slice.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

class DiceBCELoss(nn.Module):
    """Sum of the soft Dice loss and the mean binary cross-entropy.

    ``loss = BCE(p, t) + 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``

    ``weight`` and ``size_average`` are accepted by the constructor but not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the scalar Dice + BCE loss.

        Args:
            inputs: Predicted probabilities in [0, 1] (required by
                ``F.binary_cross_entropy``).
            targets: Ground-truth tensor with the same number of elements.
            smooth: Constant added to the Dice numerator and denominator.
        """
        # reshape (not view) flattens non-contiguous tensors too.
        inputs = inputs.reshape(-1)
        # binary_cross_entropy needs the target in the prediction's dtype; this
        # is a no-op when the dtypes already agree.
        targets = targets.reshape(-1).to(inputs.dtype)
        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss

        return Dice_BCE

In [None]:
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The loss classes take no model argument. Alternatives that were tried here:
# nn.MSELoss() and DiceLoss().
criterion = DiceBCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
def iou(pred,target):
    """Sum of per-image intersection-over-union between two mask batches.

    Only channel 0 of each image is compared, and a pixel counts as foreground
    when its value equals 1.

    Args:
        pred: 4-D tensor of predictions, indexed as ``[image, channel, row, col]``.
        target: 4-D tensor of ground-truth masks with the same layout.

    Returns:
        The IoU values of all images added up as a Python float (not the
        mean — divide by the number of images to get one).
    """
    def _image_iou(index):
        pred_mask = pred[index, 0, :, :] == 1
        target_mask = target[index, 0, :, :] == 1
        # Predicted-foreground pixels that lie inside the target foreground.
        overlap = pred_mask[target_mask].sum()
        combined = pred_mask.sum() + target_mask.sum() - overlap
        # max(..., 1) keeps the division defined when both masks are empty.
        return float(overlap) / float(max(combined, 1))

    return sum((_image_iou(index) for index in range(pred.size(0))), 0.0)

In [None]:
EPOCH = 40
train_num = len(traindataset)
val_num = len(valdataset)
losses = []      # mean training loss per epoch
val_accs = []    # mean validation IoU per epoch
train_accs = []  # mean training IoU per epoch
for epoch in range(EPOCH):
    # train
    model.train()
    running_loss = 0.0
    running_acc = 0.0
    num_batches = len(train_loader)
    # The loop variable is `batch` rather than `data`, so the imported
    # `torch.utils.data` module (used by the dataset classes) is not shadowed.
    for step, batch in enumerate(train_loader, start=0):
        images, labels = batch
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        probs = model(images)  # the network ends in a sigmoid: values in (0, 1)
        loss = criterion(probs, labels)

        loss.backward()
        optimizer.step()

        # statistics
        running_loss += loss.item()
        # Threshold on the device; detach so the metric stays out of autograd.
        preds = (probs.detach() > 0.5).float()
        running_acc += iou(preds, labels)

        # print train process
        rate = (step + 1) / num_batches
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss.item()), end="")
    print()
    # iou() returns a per-batch sum, so dividing by the dataset size gives the mean.
    train_accurate = running_acc / max(train_num, 1)
    train_accs.append(train_accurate)
    # Average over the number of batches (not the last batch index).
    epoch_loss = running_loss / max(num_batches, 1)
    losses.append(epoch_loss)

    # validate
    torch.cuda.empty_cache()
    model.eval()
    acc = 0.0  # accumulated IoU over the validation set
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data
            outputs = model(val_images.to(device))
            preds = (outputs > 0.5).float()

            # Keep the most recent batch around for the visualisation cell.
            last_eval = {'image': val_images, 'lable': val_labels, 'output': outputs, 'pred': preds}
            acc += iou(preds, val_labels.to(device))
    val_accurate = acc / max(val_num, 1)
    val_accs.append(val_accurate)

    print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f train_accuracy: %.3f   ' %
          (epoch + 1, epoch_loss, val_accurate, train_accs[-1]))

print('Finished Training')

In [None]:
def visualize(dic):
    """Show every batch in ``dic`` as an image grid titled with its key.

    Args:
        dic: Mapping from a title to a 4-D tensor batch. At most the first 8
            entries of each batch are drawn, 4 per row. The batch stored under
            the key ``'image'`` is de-normalised with the mean/std used by the
            input transforms before display.
    """
    for title, batch in dic.items():
        grid = torchvision.utils.make_grid(batch[:8, :, :, :], nrow=4)
        # detach + cpu works for tensors on any device; no trip through the GPU
        # is needed (and none is possible on a CPU-only machine).
        grid = grid.detach().cpu().numpy().transpose((1, 2, 0))
        if title == 'image':
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            grid = std * grid + mean
        grid = np.clip(grid, 0, 1)
        plt.imshow(grid)
        plt.title(title)
        plt.pause(0.001)

In [None]:
# Show the last validation batch kept by the training loop: input images,
# ground-truth masks, raw network outputs and thresholded predictions.
visualize(last_eval)

In [None]:
# Training loss per epoch (x axis is the zero-based epoch index).
x = np.arange(len(losses))
fig, ax = plt.subplots()
ax.plot(x, losses)
ax.set_title('DiceBCE Loss ')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Mean IoU per epoch for the training and validation sets
# (x axis is the zero-based epoch index).
x = np.arange(len(train_accs))
fig, ax = plt.subplots()
for series, series_label in ((train_accs, 'train'), (val_accs, 'validate')):
    ax.plot(x, series, label=series_label)
ax.set_title('DiceBCE Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()