<a href="https://colab.research.google.com/github/arjunparmar/VIRTUON/blob/main/Prashant/Pytorch/TOM_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab VM; the dataset archive is copied from it below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Copy the cp-vton-plus archive from the shared Drive onto the VM's local disk (faster I/O than reading from Drive).
!cp /content/drive/Shareddrives/Virtuon/Pytorch/cp-vton-plus.zip /content/

In [10]:
!unzip -qq cp-vton-plus.zip -d /content/

In [19]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import torch
import time
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision import models
import os
import os.path as osp
import json

In [5]:
class CPDataset(data.Dataset):
    """Person / cloth pair dataset for the CP-VTON+ GMM and TOM stages.

    Pairs are read from ``<all_root>/<data_path>/<mode>_pairs.txt``: one
    whitespace-separated "person_image cloth_image" pair per line, with blank
    lines skipped. For each item, the images, parsing map, mask and pose
    keypoints are read from ``<all_root>/<data_path>/<mode>/<subdir>/``.

    Args:
        stage: "GMM" reads the flat cloth ('cloth' / 'cloth-mask', keyed by the
            cloth name) and loads ``<all_root>/grid.png``. Any other value
            (e.g. "TOM") reads 'warp-cloth' / 'warp-mask', keyed by the person
            image name.
        all_root: project root directory.
        data_path: data directory under ``all_root``.
        mode: split name (e.g. "train"); selects the pairs file and the sub-directory.
        radius: half side, in pixels, of the square drawn at each pose keypoint.
        img_height, img_width: size of the shape and pose maps.
    """

    # LIP labels
    # [(0, 0, 0),       # 0=Background
    #  (128, 0, 0),     # 1=Hat
    #  (255, 0, 0),     # 2=Hair
    #  (0, 85, 0),      # 3=Glove
    #  (170, 0, 51),    # 4=SunGlasses
    #  (255, 85, 0),    # 5=UpperClothes
    #  (0, 0, 85),      # 6=Dress
    #  (0, 119, 221),   # 7=Coat
    #  (85, 85, 0),     # 8=Socks
    #  (0, 85, 85),     # 9=Pants
    #  (85, 51, 0),     # 10=Jumpsuits
    #  (52, 86, 128),   # 11=Scarf
    #  (0, 128, 0),     # 12=Skirt
    #  (0, 0, 255),     # 13=Face
    #  (51, 170, 221),  # 14=LeftArm
    #  (0, 255, 255),   # 15=RightArm
    #  (85, 255, 170),  # 16=LeftLeg
    #  (170, 255, 85),  # 17=RightLeg
    #  (255, 255, 0),   # 18=LeftShoe
    #  (255, 170, 0),   # 19=RightShoe
    #  (170, 170, 50)   # 20=Skin/Neck/Chest (added by dataset_neck_skin_correction.py)
    # ]

    # Pixels with these labels are copied from the person image into 'head'
    # (and hence into 'agnostic'); all other pixels there are set to 1 (white).
    HEAD_LABELS_GMM = (1, 4, 13)
    HEAD_LABELS_TOM = (1, 2, 4, 9, 12, 13, 16, 17)
    # Pixels with these labels form 'parse_cloth' / 'parse_cloth_mask'.
    CLOTH_LABELS = (5, 6, 7)

    def __init__(self, stage, all_root="cp-vton-plus", data_path="data", mode="train", radius=5, img_height=256, img_width=192):
        super(CPDataset, self).__init__()

        self.root = all_root
        self.data_root = osp.join(all_root, data_path)
        self.datamode = mode
        self.stage = stage
        self.data_list = "".join([mode, "_pairs.txt"])
        self.fine_height = img_height
        self.fine_width = img_width
        self.radius = radius
        self.data_path = osp.join(all_root, data_path, mode)

        # RGB image -> float tensor in [-1, 1]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Single-channel image -> float tensor in [-1, 1]. Mean/std are passed as
        # one-element sequences, the documented form for Normalize (`(0.5)` is a bare float).
        self.transform_1 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        # 2- and 3-channel variants. This class does not use them; kept for external callers.
        self.transform_2 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5), (0.5, 0.5))
        ])
        self.transform_3 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.im_names, self.c_names = self._read_pairs(osp.join(self.data_root, self.data_list))

    @staticmethod
    def _read_pairs(pairs_path):
        """Return ``(im_names, c_names)`` read from a pairs file, skipping blank lines.

        Raises:
            ValueError: if a non-blank line does not have exactly two fields.
        """
        im_names = []
        c_names = []
        with open(pairs_path, 'r') as f:
            for lineno, line in enumerate(f, 1):
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. an extra empty line at the end of the file).
                    continue
                if len(fields) != 2:
                    raise ValueError('%s:%d: expected "<image> <cloth>", got %r'
                                     % (pairs_path, lineno, line.strip()))
                im_names.append(fields[0])
                c_names.append(fields[1])
        return im_names, c_names

    @staticmethod
    def _label_mask(parse_array, labels):
        """Return a float32 {0, 1} mask of the pixels whose parse label is in ``labels``."""
        return np.isin(parse_array, labels).astype(np.float32)

    def _load_pose(self, im_name):
        """Rasterize the pose keypoints stored for ``im_name``.

        Reads ``pose/<name>_keypoints.json`` and uses ``['people'][0]['pose_keypoints']``
        as a flat list of triples, of which only x and y are used. Keypoints with
        x <= 1 or y <= 1 are treated as missing and left blank.

        Returns:
            (pose_map, im_pose): ``pose_map`` has one channel per keypoint, each with a
            white square of half side ``self.radius`` at its keypoint. ``im_pose`` is a
            single channel holding all the squares. Both are normalized to [-1, 1].

        Raises:
            ValueError: if the pose file lists no people.
        """
        pose_name = im_name.replace('.jpg', '_keypoints.json')
        pose_path = osp.join(self.data_path, 'pose', pose_name)
        with open(pose_path, 'r') as f:
            pose_label = json.load(f)
        people = pose_label.get('people') or []
        if not people:
            raise ValueError('no person found in pose file %s' % pose_path)
        pose_data = np.array(people[0]['pose_keypoints']).reshape([-1, 3])

        point_num = pose_data.shape[0]
        pose_map = torch.zeros(point_num, self.fine_height, self.fine_width)
        r = self.radius

        im_pose = Image.new('L', (self.fine_width, self.fine_height))
        pose_draw = ImageDraw.Draw(im_pose)

        for i in range(point_num):
            one_map = Image.new('L', (self.fine_width, self.fine_height))
            draw = ImageDraw.Draw(one_map)
            pointx = pose_data[i, 0]
            pointy = pose_data[i, 1]
            if pointx > 1 and pointy > 1:
                box = (pointx - r, pointy - r, pointx + r, pointy + r)
                draw.rectangle(box, 'white', 'white')
                pose_draw.rectangle(box, 'white', 'white')
            pose_map[i] = self.transform_1(one_map)[0]

        return pose_map, self.transform_1(im_pose)

    def name(self):
        return "CPDataset"

    def __getitem__(self, index):
        """Load and preprocess pair ``index``; returns the dict described inline below."""
        c_name = self.c_names[index]
        im_name = self.im_names[index]

        # Cloth and its mask: the flat cloth for GMM, the warped cloth otherwise.
        if self.stage == "GMM":
            c = Image.open(osp.join(self.data_path, 'cloth', c_name)).convert('RGB')
            cm = Image.open(osp.join(self.data_path, 'cloth-mask', c_name)).convert('L')
        else:
            c = Image.open(osp.join(self.data_path, 'warp-cloth', im_name)).convert('RGB')
            cm = Image.open(osp.join(self.data_path, 'warp-mask', im_name)).convert('L')

        c = self.transform(c)  # [-1, 1]
        cm_array = (np.array(cm) >= 128).astype(np.float32)  # binarize at mid-gray
        cm = torch.from_numpy(cm_array).unsqueeze(0)  # (1, H, W), values {0, 1}

        # Person image. `.convert('RGB')` guards against RGBA/grayscale files,
        # which would not match the 3-channel Normalize.
        im = self.transform(Image.open(osp.join(self.data_path, 'image', im_name)).convert('RGB'))

        # Parsing map (LIP labels) and the person mask.
        parse_name = im_name.replace('.jpg', '.png')
        parse_array = np.array(
            Image.open(osp.join(self.data_path, 'image-parse-new', parse_name)).convert('L'))
        mask_array = np.array(
            Image.open(osp.join(self.data_path, 'image-mask', parse_name)).convert('L'))

        parse_shape = (mask_array > 0).astype(np.float32)
        head_labels = self.HEAD_LABELS_GMM if self.stage == 'GMM' else self.HEAD_LABELS_TOM
        parse_head = self._label_mask(parse_array, head_labels)
        parse_cloth = self._label_mask(parse_array, self.CLOTH_LABELS)

        # Body shape: downsampling 16x and upsampling back blurs the silhouette
        # ('shape'); 'shape_ori' keeps the full-resolution version.
        parse_shape_ori = Image.fromarray((parse_shape * 255).astype(np.uint8))
        parse_shape = parse_shape_ori.resize((self.fine_width // 16, self.fine_height // 16), Image.BILINEAR)
        parse_shape = parse_shape.resize((self.fine_width, self.fine_height), Image.BILINEAR)
        parse_shape_ori = parse_shape_ori.resize((self.fine_width, self.fine_height), Image.BILINEAR)
        shape_ori = self.transform_1(parse_shape_ori)  # [-1, 1]
        shape = self.transform_1(parse_shape)  # [-1, 1]

        phead = torch.from_numpy(parse_head)
        pcm = torch.from_numpy(parse_cloth)

        # Upper cloth / kept regions cut out of the person image; 1 (white) elsewhere.
        im_c = im * pcm + (1 - pcm)
        im_h = im * phead + (1 - phead)

        pose_map, im_pose = self._load_pose(im_name)

        # Cloth-agnostic person representation: blurred shape + kept regions + pose.
        agnostic = torch.cat([shape, im_h, pose_map], 0)

        if self.stage == 'GMM':
            im_g = self.transform(Image.open(osp.join(self.root, 'grid.png')).convert('RGB'))
        else:
            im_g = ''

        pcm = pcm.unsqueeze(0)  # (1, H, W)

        result = {
            'c_name': c_name,
            'im_name': im_name,
            'cloth': c,
            'cloth_mask': cm,
            'image': im,
            'agnostic': agnostic,
            'parse_cloth': im_c,
            'shape': shape,
            'head': im_h,
            'pose_image': im_pose,
            'grid_image': im_g,
            'parse_cloth_mask': pcm,
            'shape_ori': shape_ori,
        }
        return result

    def __len__(self):
        return len(self.im_names)


class CPDataLoader(object):
    """Endless batch source wrapping ``torch.utils.data.DataLoader``.

    ``next_batch()`` restarts the underlying loader whenever a pass over the data
    ends, so a step-based training loop can call it indefinitely.

    Args:
        dataset: a map-style dataset.
        shuffle: if True, each pass draws samples in random order; if False, in index order.
        batch: batch size.
        workers: number of DataLoader worker processes.
        pin_memory: forwarded to DataLoader (default True, as before).
    """

    def __init__(self, dataset, shuffle=True, batch=4, workers=4, pin_memory=True):
        super(CPDataLoader, self).__init__()

        # The sampler alone controls ordering. The previous code passed
        # shuffle=(sampler is None), which made shuffle=False still shuffle.
        train_sampler = torch.utils.data.sampler.RandomSampler(dataset) if shuffle else None

        self.data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch, shuffle=False,
            num_workers=workers, pin_memory=pin_memory, sampler=train_sampler
        )
        self.dataset = dataset
        self.data_iter = iter(self.data_loader)

    def next_batch(self):
        """Return the next batch, starting a new pass over the data when the current one ends."""
        try:
            return next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)
            return next(self.data_iter)


In [6]:
class Vgg19(nn.Module):
    """ImageNet-pretrained VGG19 feature extractor, frozen unless ``requires_grad``.

    ``vgg19().features[0:30]`` is split into ``slice1`` .. ``slice5`` at indices
    2, 7, 12 and 21. ``forward`` returns the activation after each slice. In the
    standard VGG19 layout these are relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1.
    """

    # Exclusive end index in vgg19().features of slice1 .. slice5.
    _SLICE_ENDS = (2, 7, 12, 21, 30)

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = self._pretrained_features()
        start = 0
        for n, end in enumerate(self._SLICE_ENDS, start=1):
            block = torch.nn.Sequential()
            for x in range(start, end):
                # Keep the original layer index as the module name so state_dict keys are unchanged.
                block.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, 'slice%d' % n, block)
            start = end
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    @staticmethod
    def _pretrained_features():
        """Return ``vgg19().features`` with ImageNet weights.

        Uses the ``weights=`` API where torchvision provides it (0.13+), since
        ``pretrained=True`` is deprecated there. Falls back to ``pretrained=True``
        on older versions. Both select the same IMAGENET1K_V1 weights.
        """
        weights_enum = getattr(models, 'VGG19_Weights', None)
        if weights_enum is not None:
            return models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
        return models.vgg19(pretrained=True).features

    def forward(self, X):
        """Return the list of the five slice outputs, shallowest first."""
        out = []
        h = X
        for layer in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = layer(h)
            out.append(h)
        return out


class VGGLoss(nn.Module):
    """Perceptual loss: weighted L1 distance between Vgg19 features of two images.

    Args:
        layids: indices (0-4) of the Vgg19 outputs to compare. None compares all of them.

    The per-layer weights are 1/32, 1/16, 1/8, 1/4 and 1, shallowest to deepest.
    Target features are detached, so gradients flow only through ``x``.
    """

    def __init__(self, layids=None):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19()
        # Move the frozen VGG to the GPU when one exists (the training code feeds CUDA
        # tensors). On CPU-only machines, stay on CPU instead of raising.
        if torch.cuda.is_available():
            self.vgg.cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
        self.layids = layids

    def forward(self, x, y):
        """Return the weighted sum of L1 distances between the selected feature maps."""
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        # Resolve the default locally instead of overwriting self.layids during forward.
        layids = self.layids if self.layids is not None else range(len(x_vgg))
        loss = 0
        for i in layids:
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss



In [7]:
class UnetGenerator(nn.Module):
    """U-Net made of nested UnetSkipConnectionBlock modules.

    Blocks are created from the innermost outwards:
      * one innermost block at ``ngf * 8`` channels;
      * ``num_downs - 5`` middle blocks at ``ngf * 8`` (the only ones given ``use_dropout``);
      * three blocks stepping the width down: ngf*8 -> ngf*4 -> ngf*2 -> ngf;
      * the outermost block, mapping ``input_nc`` to ``output_nc`` channels.
    Each block downsamples once (stride-2 conv), so there are ``num_downs``
    downsamplings in total. Spatial sizes presumably need to be divisible by
    ``2 ** num_downs`` for the skip concatenations to line up.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None,
            norm_layer=norm_layer, innermost=True)

        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8, input_nc=None, submodule=block,
                norm_layer=norm_layer, use_dropout=use_dropout)

        # Width halves at each step outwards: (4, 8), (2, 4), (1, 2) times ngf.
        for mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(
                ngf * mult, ngf * mult * 2, input_nc=None, submodule=block,
                norm_layer=norm_layer)

        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=block,
            outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        # The outermost block returns its output without a skip concatenation.
        return self.model(input)

In [8]:
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample, run ``submodule``, upsample, concatenate the skip.

    Args:
        outer_nc: channels this block outputs (before the skip concat).
        inner_nc: channels produced by the downsampling conv.
        input_nc: channels of the block's input. Defaults to ``outer_nc``.
        submodule: the next-inner block. Required unless ``innermost``.
        outermost: no skip concat and no activation before the down conv.
        innermost: no submodule; the up path consumes ``inner_nc`` channels directly.
        norm_layer: normalization class (or ``functools.partial`` of one).
        use_dropout: append Dropout(0.5) (middle blocks only).

    Non-outermost blocks return ``cat([x, f(x)], dim=1)``, i.e. input channels + ``outer_nc``.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        import functools

        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if submodule is None and not innermost:
            raise ValueError("submodule is required unless innermost=True")

        # Convs get a bias only when the norm layer is InstanceNorm2d, including
        # when it is given as functools.partial(nn.InstanceNorm2d, ...).
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        # In-place activations: see the note in forward().
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        # align_corners=False is PyTorch's default; stating it silences the UserWarning.
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        if outermost:
            upconv = nn.Conv2d(inner_nc * 2, outer_nc,
                               kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downconv]
            up = [uprelu, upsample, upconv, upnorm]
            model = down + [submodule] + up
        elif innermost:
            # No submodule, so no skip channels to absorb: input is inner_nc.
            upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3,
                               stride=1, padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upsample, upconv, upnorm]
            model = down + up
        else:
            # Submodule returns cat(skip, output) -> inner_nc * 2 channels.
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3,
                               stride=1, padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upsample, upconv, upnorm]
            model = down + [submodule] + up
            if use_dropout:
                model = model + [nn.Dropout(0.5)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # NOTE: for non-outermost blocks the first layer of self.model is an in-place
        # LeakyReLU. It modifies `x` before torch.cat reads it, so the skip carries
        # leaky-ReLU'd activations. This is behavior of the existing architecture.
        return torch.cat([x, self.model(x)], 1)


In [17]:
def train_tom(train_loader, model, lr=0.0001, keep_step=100000, decay_step=100000,
              display_count=1000, save_count=None, checkpoint_dir=None):
    """Train the CP-VTON+ Try-On Module (TOM) on CUDA.

    The model receives ``cat([agnostic, cloth, cloth_mask], dim=1)``. Its output is
    split into 3 rendered-person channels plus the remaining channel(s) as the
    composition mask, so it is expected to output 4 channels. The try-on image is
    ``cloth * mask + rendered * (1 - mask)``. The loss is L1 + VGG perceptual loss
    against the person image, plus L1 between the mask and the parsed-cloth mask.

    Args:
        train_loader: object whose ``next_batch()`` returns a dict with the keys
            'image', 'agnostic', 'cloth', 'cloth_mask' and 'parse_cloth_mask'.
        model: the generator. It is moved to CUDA and trained in place.
        lr: Adam learning rate.
        keep_step: steps at constant ``lr``.
        decay_step: further steps over which ``lr`` decays linearly towards 0.
        display_count: print losses every this many steps (0/None disables).
        save_count: if set together with ``checkpoint_dir``, save the model's
            ``state_dict`` every this many steps as ``step_XXXXXX.pth``.
        checkpoint_dir: directory for the periodic checkpoints.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer: constant lr for keep_step steps, then linear decay over decay_step steps
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - keep_step) / float(decay_step + 1))

    for step in range(keep_step + decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        pcm = inputs['parse_cloth_mask'].cuda()

        # CP-VTON+ feeds the cloth mask too (CP-VTON used only agnostic + cloth).
        outputs = model(torch.cat([agnostic, c, cm], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        # torch.tanh / torch.sigmoid: `F` (torch.nn.functional) is not imported in this notebook.
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        # CP-VTON+ supervises the mask with the parsed cloth mask (CP-VTON used the cloth mask).
        loss_mask = criterionMask(m_composite, pcm)
        loss = loss_l1 + loss_vgg + loss_mask

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Advance the LR schedule once per iteration; without this the decay never happens.
        scheduler.step()

        if display_count and (step + 1) % display_count == 0:
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f'
                  % (step + 1, t, loss.item(), loss_l1.item(),
                     loss_vgg.item(), loss_mask.item()), flush=True)

        if save_count and checkpoint_dir and (step + 1) % save_count == 0:
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(),
                       os.path.join(checkpoint_dir, 'step_%06d.pth' % (step + 1)))

    return model

In [None]:
def main(data_root="cp-vton-plus", checkpoint_path="checkpoints/TOM/tom_final.pth"):
    """Train the CP-VTON+ Try-On Module (TOM) and save the final weights.

    Args:
        data_root: project/dataset root, passed to CPDataset as ``all_root``.
        checkpoint_path: file the trained model's ``state_dict`` is written to.
            On Colab the VM's local disk is typically not persistent; point this
            under /content/drive to keep the weights.
    """
    # create dataset
    train_dataset = CPDataset("TOM", data_root)

    # create dataloader
    train_loader = CPDataLoader(train_dataset)

    # create model. Input channels: agnostic (shape 1 + head 3 + one per pose keypoint,
    # i.e. 22 assuming 18 keypoints) + cloth 3 + cloth mask 1 = 26 for CP-VTON+
    # (CP-VTON used 25, without the cloth mask). Output: 3 rendered + 1 mask channel.
    model = UnetGenerator(26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON+
    train_tom(train_loader, model)

    # train_tom trains `model` in place; persist the learned weights.
    checkpoint_dir = osp.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)


if __name__ == "__main__":
    main()